26 changes: 25 additions & 1 deletion src/daemon/agent-model.ts
@@ -1,5 +1,7 @@
import type { Api, Model } from "@mariozechner/pi-ai";
import type { Api, Context, Model } from "@mariozechner/pi-ai";
import { getModel } from "@mariozechner/pi-ai";
import { resolveOpenAiModel } from "../llm/providers/models.js";
import { resolveOpenAiClientConfig } from "../llm/providers/openai.js";
import { createSyntheticModel } from "../llm/providers/shared.js";
import { buildAutoModelAttempts, envHasKey } from "../model-auto.js";
import { parseCliUserModelId } from "../run/env.js";
@@ -27,6 +29,8 @@ const REQUIRED_ENV_BY_PROVIDER: Record<string, string> = {
zai: "Z_AI_API_KEY",
};

const TEXT_ONLY_CONTEXT: Context = { messages: [] };

function parseProviderModelId(modelId: string): { provider: string; model: string } {
const trimmed = modelId.trim();
const slash = trimmed.indexOf("/");
@@ -184,6 +188,7 @@ export async function resolveAgentModel({
config,
configPath,
configForCli,
openaiUseChatCompletions,
apiKey,
openrouterApiKey,
anthropicApiKey,
@@ -238,6 +243,25 @@

const applyBaseUrlOverride = (provider: string, modelId: string) => {
const baseUrl = providerBaseUrlMap[provider] ?? null;
if (provider === "openai") {
const openaiConfig = resolveOpenAiClientConfig({
apiKeys: {
openaiApiKey: apiKeys.openaiApiKey,
openrouterApiKey: apiKeys.openrouterApiKey,
},
openaiBaseUrlOverride: baseUrl,
forceChatCompletions: openaiUseChatCompletions,
allowProcessEnvBaseUrlFallback: false,
});
return {
provider,
model: resolveOpenAiModel({
modelId,
context: TEXT_ONLY_CONTEXT,
openaiConfig,
}),
};
}
const providerForPiAi = provider === "nvidia" ? "openai" : provider;
return {
provider,
6 changes: 5 additions & 1 deletion src/llm/providers/openai.ts
@@ -17,17 +17,21 @@ export type OpenAiClientConfigInput = {
forceOpenRouter?: boolean;
openaiBaseUrlOverride?: string | null;
forceChatCompletions?: boolean;
allowProcessEnvBaseUrlFallback?: boolean;
};

export function resolveOpenAiClientConfig({
apiKeys,
forceOpenRouter,
openaiBaseUrlOverride,
forceChatCompletions,
allowProcessEnvBaseUrlFallback = true,
}: OpenAiClientConfigInput): OpenAiClientConfig {
const baseUrlRaw =
openaiBaseUrlOverride ??
(typeof process !== "undefined" ? process.env.OPENAI_BASE_URL : undefined);
(allowProcessEnvBaseUrlFallback && typeof process !== "undefined"
? process.env.OPENAI_BASE_URL
: undefined);
const baseUrl = normalizeBaseUrl(baseUrlRaw);
const isOpenRouterViaBaseUrl = baseUrl ? isOpenRouterBaseUrl(baseUrl) : false;
const hasOpenRouterKey = apiKeys.openrouterApiKey != null;
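Note (not part of the diff): a minimal sketch of what the new allowProcessEnvBaseUrlFallback flag changes. The call shapes mirror resolveOpenAiClientConfig as modified above; the env value, the API key, and the exact apiKeys shape are illustrative assumptions.

import { resolveOpenAiClientConfig } from "../src/llm/providers/openai.js";

// Ambient process-level base URL, e.g. left over from an unrelated shell session.
process.env.OPENAI_BASE_URL = "https://ambient.example/v1";

// Default behavior (flag omitted, i.e. true): with no explicit override, the
// ambient OPENAI_BASE_URL still feeds into the resolved client config.
const ambient = resolveOpenAiClientConfig({
  apiKeys: { openaiApiKey: "sk-test", openrouterApiKey: null },
});

// Daemon path (see src/daemon/agent-model.ts above): the agent supplies its own
// env snapshot via openaiBaseUrlOverride, so the process-level fallback is
// disabled and a missing override resolves to the default OpenAI endpoint.
const isolated = resolveOpenAiClientConfig({
  apiKeys: { openaiApiKey: "sk-test", openrouterApiKey: null },
  openaiBaseUrlOverride: null,
  allowProcessEnvBaseUrlFallback: false,
});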
138 changes: 130 additions & 8 deletions tests/daemon.agent.test.ts
@@ -1,15 +1,17 @@
import { mkdtempSync, writeFileSync, chmodSync } from "node:fs";
import { chmodSync, mkdirSync, mkdtempSync, writeFileSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import type { AssistantMessage, Tool } from "@mariozechner/pi-ai";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { completeAgentResponse } from "../src/daemon/agent.js";
import { completeAgentResponse, streamAgentResponse } from "../src/daemon/agent.js";
import { runCliModel } from "../src/llm/cli.js";
import * as modelAuto from "../src/model-auto.js";
import { makeTextDeltaStream } from "./helpers/pi-ai-mock.js";

const { mockCompleteSimple, mockGetModel } = vi.hoisted(() => ({
const { mockCompleteSimple, mockGetModel, mockStreamSimple } = vi.hoisted(() => ({
mockCompleteSimple: vi.fn(),
mockGetModel: vi.fn(),
mockStreamSimple: vi.fn(),
}));

vi.mock("../src/llm/cli.js", async (importOriginal) => {
@@ -24,13 +26,18 @@ vi.mock("@mariozechner/pi-ai", () => {
return {
completeSimple: mockCompleteSimple,
getModel: mockGetModel,
streamSimple: mockStreamSimple,
};
});

const buildAssistant = (provider: string, model: string): AssistantMessage => ({
const buildAssistant = (
provider: string,
model: string,
api: "openai-completions" | "openai-responses" = "openai-completions",
): AssistantMessage => ({
role: "assistant",
content: [{ type: "text", text: "ok" }],
api: "openai-completions",
api,
provider,
model,
usage: {
@@ -49,7 +56,7 @@ const makeModel = (provider: string, modelId: string) => ({
id: modelId,
name: modelId,
provider,
api: "openai-completions" as const,
api: (provider === "openai" ? "openai-responses" : "openai-completions") as const,
baseUrl: "https://example.com",
reasoning: false,
input: ["text"],
@@ -71,13 +78,22 @@ const makeFakeCliBin = (binary: string) => {
beforeEach(() => {
mockCompleteSimple.mockReset();
mockGetModel.mockReset();
mockStreamSimple.mockReset();
vi.mocked(runCliModel).mockReset();
vi.mocked(runCliModel).mockResolvedValue({ text: "cli agent", usage: null, costUsd: null });
mockGetModel.mockImplementation((provider: string, modelId: string) =>
makeModel(provider, modelId),
);
mockCompleteSimple.mockImplementation(async (model: { provider: string; id: string }) =>
buildAssistant(model.provider, model.id),
mockCompleteSimple.mockImplementation(
async (model: {
provider: string;
id: string;
api?: "openai-completions" | "openai-responses";
}) => buildAssistant(model.provider, model.id, model.api),
);
mockStreamSimple.mockImplementation(
(model: { provider: string; id: string; api?: "openai-completions" | "openai-responses" }) =>
makeTextDeltaStream(["ok"], buildAssistant(model.provider, model.id, model.api)),
);
});

@@ -116,6 +132,112 @@ describe("daemon/agent", () => {
expect(options.apiKey).toBe("sk-openai");
});

it("forces chat completions for streaming agent responses via OPENAI_USE_CHAT_COMPLETIONS", async () => {
const home = makeTempHome();
const chunks: string[] = [];

await streamAgentResponse({
env: {
HOME: home,
OPENAI_API_KEY: "sk-openai",
OPENAI_USE_CHAT_COMPLETIONS: "1",
},
pageUrl: "https://example.com",
pageTitle: null,
pageContent: "Hello world",
messages: [{ role: "user", content: "Hi" }],
modelOverride: "openai/gpt-5-mini",
tools: [],
automationEnabled: false,
onChunk: (text) => chunks.push(text),
onAssistant: () => {},
});

const model = mockStreamSimple.mock.calls[0]?.[0] as { api?: string };
expect(model.api).toBe("openai-completions");
expect(chunks.join("")).toBe("ok");
});

it("forces chat completions for agent responses via config", async () => {
const home = makeTempHome();
const configDir = join(home, ".summarize");
mkdirSync(configDir, { recursive: true });
writeFileSync(
join(configDir, "config.json"),
JSON.stringify({ openai: { useChatCompletions: true } }),
"utf8",
);

await completeAgentResponse({
env: { HOME: home, OPENAI_API_KEY: "sk-openai" },
pageUrl: "https://example.com",
pageTitle: null,
pageContent: "Hello world",
messages: [{ role: "user", content: "Hi" }],
modelOverride: "openai/gpt-5-mini",
tools: [],
automationEnabled: false,
});

const model = mockCompleteSimple.mock.calls[0]?.[0] as { api?: string };
expect(model.api).toBe("openai-completions");
});

it("uses chat completions for custom OpenAI-compatible base URLs", async () => {
const home = makeTempHome();

await completeAgentResponse({
env: {
HOME: home,
OPENAI_API_KEY: "sk-openai",
OPENAI_BASE_URL: "http://127.0.0.1:1234/v1",
},
pageUrl: "https://example.com",
pageTitle: null,
pageContent: "Hello world",
messages: [{ role: "user", content: "Hi" }],
modelOverride: "openai/gpt-5-mini",
tools: [],
automationEnabled: false,
});

const model = mockCompleteSimple.mock.calls[0]?.[0] as { api?: string; baseUrl?: string };
expect(model.api).toBe("openai-completions");
expect(model.baseUrl).toBe("http://127.0.0.1:1234/v1");
});

it("ignores ambient process OPENAI_BASE_URL when agent env snapshot does not set it", async () => {
const home = makeTempHome();
const originalProcessBaseUrl = process.env.OPENAI_BASE_URL;
process.env.OPENAI_BASE_URL = "https://ambient.example/v1";

try {
await completeAgentResponse({
env: {
HOME: home,
OPENAI_API_KEY: "sk-openai",
},
pageUrl: "https://example.com",
pageTitle: null,
pageContent: "Hello world",
messages: [{ role: "user", content: "Hi" }],
modelOverride: "openai/gpt-5-mini",
tools: [],
automationEnabled: false,
});
} finally {
if (typeof originalProcessBaseUrl === "string") {
process.env.OPENAI_BASE_URL = originalProcessBaseUrl;
} else {
delete process.env.OPENAI_BASE_URL;
}
}

const model = mockCompleteSimple.mock.calls[0]?.[0] as { api?: string; baseUrl?: string };
expect(model.api).toBe("openai-responses");
expect(model.baseUrl).toBe("https://example.com");
});

it("throws a helpful error when openrouter key is missing", async () => {
const home = makeTempHome();
await expect(
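For reference, the new tests exercise two opt-in paths for forcing the chat-completions API. A condensed sketch follows; the values and the config location mirror the tests above and are illustrative, not newly documented surface area.

import { mkdirSync, writeFileSync } from "node:fs";
import { join } from "node:path";

// Path 1: env toggle on the agent's env snapshot.
const env = {
  HOME: "/home/me",
  OPENAI_API_KEY: "sk-...",
  OPENAI_USE_CHAT_COMPLETIONS: "1",
};

// Path 2: per-user config file at $HOME/.summarize/config.json.
const configDir = join("/home/me", ".summarize");
mkdirSync(configDir, { recursive: true });
writeFileSync(
  join(configDir, "config.json"),
  JSON.stringify({ openai: { useChatCompletions: true } }),
  "utf8",
);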