diff --git a/lib/chat/__tests__/handleChatGenerate.test.ts b/lib/chat/__tests__/handleChatGenerate.test.ts index 04b0a05d..1c03bf5f 100644 --- a/lib/chat/__tests__/handleChatGenerate.test.ts +++ b/lib/chat/__tests__/handleChatGenerate.test.ts @@ -26,6 +26,10 @@ vi.mock("@/lib/chat/setupChatRequest", () => ({ setupChatRequest: vi.fn(), })); +vi.mock("@/lib/chat/handleChatCompletion", () => ({ + handleChatCompletion: vi.fn(), +})); + vi.mock("ai", () => ({ generateText: vi.fn(), })); @@ -33,12 +37,14 @@ vi.mock("ai", () => ({ import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId"; import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId"; import { setupChatRequest } from "@/lib/chat/setupChatRequest"; +import { handleChatCompletion } from "@/lib/chat/handleChatCompletion"; import { generateText } from "ai"; import { handleChatGenerate } from "../handleChatGenerate"; const mockGetApiKeyAccountId = vi.mocked(getApiKeyAccountId); const mockValidateOverrideAccountId = vi.mocked(validateOverrideAccountId); const mockSetupChatRequest = vi.mocked(setupChatRequest); +const mockHandleChatCompletion = vi.mocked(handleChatCompletion); const mockGenerateText = vi.mocked(generateText); // Helper to create mock NextRequest @@ -58,6 +64,7 @@ function createMockRequest( describe("handleChatGenerate", () => { beforeEach(() => { vi.clearAllMocks(); + mockHandleChatCompletion.mockResolvedValue(); }); afterEach(() => { @@ -336,4 +343,169 @@ describe("handleChatGenerate", () => { ); }); }); + + describe("chat completion handling", () => { + it("calls handleChatCompletion with body and constructed UIMessage", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + mockSetupChatRequest.mockResolvedValue({ + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + mockGenerateText.mockResolvedValue({ + text: "Hello! How can I help you?", + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + response: { messages: [], headers: {}, body: null }, + } as any); + + const messages = [{ id: "msg-1", role: "user", parts: [{ type: "text", text: "Hi" }] }]; + const request = createMockRequest( + { messages, roomId: "room-123", artistId: "artist-456" }, + { "x-api-key": "valid-key" }, + ); + + await handleChatGenerate(request as any); + + // Verify handleChatCompletion was called + expect(mockHandleChatCompletion).toHaveBeenCalledTimes(1); + + // Verify the body contains the correct fields + expect(mockHandleChatCompletion).toHaveBeenCalledWith( + expect.objectContaining({ + messages, + roomId: "room-123", + artistId: "artist-456", + accountId: "account-123", + }), + expect.arrayContaining([ + expect.objectContaining({ + role: "assistant", + parts: expect.arrayContaining([ + expect.objectContaining({ + type: "text", + text: "Hello! How can I help you?", + }), + ]), + }), + ]), + ); + }); + + it("constructs UIMessage with correct structure from generateText result", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + mockSetupChatRequest.mockResolvedValue({ + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + mockGenerateText.mockResolvedValue({ + text: "Generated response text", + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + response: { messages: [], headers: {}, body: null }, + } as any); + + const request = createMockRequest( + { prompt: "Hello" }, + { "x-api-key": "valid-key" }, + ); + + await handleChatGenerate(request as any); + + // Get the UIMessage that was passed to handleChatCompletion + const [, responseMessages] = mockHandleChatCompletion.mock.calls[0]; + + expect(responseMessages).toHaveLength(1); + expect(responseMessages[0]).toMatchObject({ + id: expect.any(String), + role: "assistant", + parts: [ + { + type: "text", + text: "Generated response text", + }, + ], + }); + }); + + it("does not throw when handleChatCompletion fails", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + mockHandleChatCompletion.mockRejectedValue(new Error("Completion failed")); + + mockSetupChatRequest.mockResolvedValue({ + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + mockGenerateText.mockResolvedValue({ + text: "Response", + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + response: { messages: [], headers: {}, body: null }, + } as any); + + const request = createMockRequest( + { prompt: "Hello" }, + { "x-api-key": "valid-key" }, + ); + + // Should not throw even if handleChatCompletion fails + const result = await handleChatGenerate(request as any); + + expect(result.status).toBe(200); + const json = await result.json(); + expect(json.text).toBe("Response"); + }); + + it("handles empty text from generateText result", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + mockSetupChatRequest.mockResolvedValue({ + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + mockGenerateText.mockResolvedValue({ + text: "", + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 0 }, + response: { messages: [], headers: {}, body: null }, + } as any); + + const request = createMockRequest( + { prompt: "Hello" }, + { "x-api-key": "valid-key" }, + ); + + await handleChatGenerate(request as any); + + // Get the UIMessage that was passed to handleChatCompletion + const [, responseMessages] = mockHandleChatCompletion.mock.calls[0]; + + expect((responseMessages[0].parts[0] as { text: string }).text).toBe(""); + }); + }); }); diff --git a/lib/chat/__tests__/handleChatStream.test.ts b/lib/chat/__tests__/handleChatStream.test.ts index b78918e3..4528117d 100644 --- a/lib/chat/__tests__/handleChatStream.test.ts +++ b/lib/chat/__tests__/handleChatStream.test.ts @@ -26,6 +26,10 @@ vi.mock("@/lib/chat/setupChatRequest", () => ({ setupChatRequest: vi.fn(), })); +vi.mock("@/lib/chat/handleChatCompletion", () => ({ + handleChatCompletion: vi.fn(), +})); + vi.mock("ai", () => ({ createUIMessageStream: vi.fn(), createUIMessageStreamResponse: vi.fn(), @@ -34,12 +38,14 @@ vi.mock("ai", () => ({ import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId"; import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId"; import { setupChatRequest } from "@/lib/chat/setupChatRequest"; +import { handleChatCompletion } from "@/lib/chat/handleChatCompletion"; import { createUIMessageStream, createUIMessageStreamResponse } from "ai"; import { handleChatStream } from "../handleChatStream"; const mockGetApiKeyAccountId = vi.mocked(getApiKeyAccountId); const mockValidateOverrideAccountId = vi.mocked(validateOverrideAccountId); const mockSetupChatRequest = vi.mocked(setupChatRequest); +const mockHandleChatCompletion = vi.mocked(handleChatCompletion); const mockCreateUIMessageStream = vi.mocked(createUIMessageStream); const mockCreateUIMessageStreamResponse = vi.mocked(createUIMessageStreamResponse); @@ -60,6 +66,7 @@ function createMockRequest( describe("handleChatStream", () => { beforeEach(() => { vi.clearAllMocks(); + mockHandleChatCompletion.mockResolvedValue(); }); afterEach(() => { @@ -294,4 +301,145 @@ describe("handleChatStream", () => { ); }); }); + + describe("chat completion handling", () => { + it("passes onFinish callback to createUIMessageStream", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + const mockAgent = { + stream: vi.fn().mockResolvedValue({ + toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + }), + tools: {}, + }; + + mockSetupChatRequest.mockResolvedValue({ + agent: mockAgent, + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + const mockStream = new ReadableStream(); + mockCreateUIMessageStream.mockReturnValue(mockStream); + mockCreateUIMessageStreamResponse.mockReturnValue(new Response(mockStream)); + + const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" }); + + await handleChatStream(request as any); + + // Verify onFinish callback was passed to createUIMessageStream + expect(mockCreateUIMessageStream).toHaveBeenCalledWith( + expect.objectContaining({ + onFinish: expect.any(Function), + }), + ); + }); + + it("calls handleChatCompletion with body and messages when onFinish is triggered", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + const mockAgent = { + stream: vi.fn().mockResolvedValue({ + toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + }), + tools: {}, + }; + + mockSetupChatRequest.mockResolvedValue({ + agent: mockAgent, + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + const mockStream = new ReadableStream(); + mockCreateUIMessageStream.mockReturnValue(mockStream); + mockCreateUIMessageStreamResponse.mockReturnValue(new Response(mockStream)); + + const messages = [{ id: "msg-1", role: "user", parts: [{ type: "text", text: "Hi" }] }]; + const request = createMockRequest( + { messages, roomId: "room-123", artistId: "artist-456" }, + { "x-api-key": "valid-key" }, + ); + + await handleChatStream(request as any); + + // Get the onFinish callback that was passed to createUIMessageStream + const createUIMessageStreamCall = mockCreateUIMessageStream.mock.calls[0][0] as { + onFinish: (params: { messages: unknown[] }) => void; + }; + const onFinishCallback = createUIMessageStreamCall.onFinish; + + // Simulate onFinish being called with response messages + const responseMessages = [ + { id: "resp-1", role: "assistant", parts: [{ type: "text", text: "Hello!" }] }, + ]; + onFinishCallback({ messages: responseMessages }); + + // Verify handleChatCompletion was called with correct arguments + expect(mockHandleChatCompletion).toHaveBeenCalledWith( + expect.objectContaining({ + messages, + roomId: "room-123", + artistId: "artist-456", + accountId: "account-123", + }), + responseMessages, + ); + }); + + it("does not throw when handleChatCompletion fails", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + mockHandleChatCompletion.mockRejectedValue(new Error("Completion failed")); + + const mockAgent = { + stream: vi.fn().mockResolvedValue({ + toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + }), + tools: {}, + }; + + mockSetupChatRequest.mockResolvedValue({ + agent: mockAgent, + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + const mockStream = new ReadableStream(); + mockCreateUIMessageStream.mockReturnValue(mockStream); + mockCreateUIMessageStreamResponse.mockReturnValue(new Response(mockStream)); + + const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" }); + + // Should not throw even if handleChatCompletion fails + const result = await handleChatStream(request as any); + expect(result).toBeInstanceOf(Response); + + // Trigger onFinish to ensure error is caught gracefully + const createUIMessageStreamCall = mockCreateUIMessageStream.mock.calls[0][0] as { + onFinish: (params: { messages: unknown[] }) => void; + }; + const onFinishCallback = createUIMessageStreamCall.onFinish; + + // This should not throw + expect(() => onFinishCallback({ messages: [] })).not.toThrow(); + }); + }); }); diff --git a/lib/chat/handleChatGenerate.ts b/lib/chat/handleChatGenerate.ts index d708bcff..99ead14c 100644 --- a/lib/chat/handleChatGenerate.ts +++ b/lib/chat/handleChatGenerate.ts @@ -1,8 +1,10 @@ import { NextRequest, NextResponse } from "next/server"; -import { generateText } from "ai"; +import { generateText, UIMessage } from "ai"; import { validateChatRequest } from "./validateChatRequest"; import { setupChatRequest } from "./setupChatRequest"; +import { handleChatCompletion } from "./handleChatCompletion"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import generateUUID from "@/lib/uuid/generateUUID"; /** * Handles a non-streaming chat generate request. @@ -28,8 +30,20 @@ export async function handleChatGenerate(request: NextRequest): Promise { + console.error("Failed to handle chat completion:", e); + }); return NextResponse.json( { diff --git a/lib/chat/handleChatStream.ts b/lib/chat/handleChatStream.ts index 396a66ec..1f37b4f3 100644 --- a/lib/chat/handleChatStream.ts +++ b/lib/chat/handleChatStream.ts @@ -2,6 +2,7 @@ import { NextRequest, NextResponse } from "next/server"; import { createUIMessageStream, createUIMessageStreamResponse } from "ai"; import { validateChatRequest } from "./validateChatRequest"; import { setupChatRequest } from "./setupChatRequest"; +import { handleChatCompletion } from "./handleChatCompletion"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; import generateUUID from "@/lib/uuid/generateUUID"; @@ -34,8 +35,6 @@ export async function handleChatStream(request: NextRequest): Promise const { writer } = options; const result = await agent.stream(chatConfig); writer.merge(result.toUIMessageStream()); - // Note: Credit handling and chat completion handling will be added - // as part of the handleChatCredits and handleChatCompletion migrations }, onError: (e) => { console.error("/api/chat onError:", e); @@ -44,6 +43,11 @@ export async function handleChatStream(request: NextRequest): Promise message: e instanceof Error ? e.message : "Unknown error", }); }, + onFinish: ({ messages }) => { + void handleChatCompletion(body, messages).catch((e) => { + console.error("Failed to handle chat completion:", e); + }); + }, }); return createUIMessageStreamResponse({ stream });