diff --git a/assets/oh-my-opencode.schema.json b/assets/oh-my-opencode.schema.json index a8710453c5..c51a496c14 100644 --- a/assets/oh-my-opencode.schema.json +++ b/assets/oh-my-opencode.schema.json @@ -3702,6 +3702,62 @@ }, "additionalProperties": false }, + "model_scheduler": { + "type": "object", + "properties": { + "enabled": { + "type": "boolean" + }, + "interval_minutes": { + "type": "integer", + "minimum": 1, + "maximum": 1440 + }, + "mode": { + "type": "string", + "enum": [ + "observe", + "dry-run", + "active" + ] + }, + "preflight_on_session_created": { + "type": "boolean" + }, + "failure_threshold": { + "type": "integer", + "minimum": 1, + "maximum": 10 + }, + "recovery_threshold": { + "type": "integer", + "minimum": 1, + "maximum": 10 + }, + "agent_cooldown_minutes": { + "type": "integer", + "minimum": 0, + "maximum": 1440 + }, + "protect_manual_routing": { + "type": "boolean" + }, + "probe_enabled": { + "type": "boolean" + }, + "probe_timeout_ms": { + "type": "integer", + "minimum": 1000, + "maximum": 300000 + }, + "probe_max_latency_ms": { + "type": "integer", + "minimum": 100, + "maximum": 300000 + } + }, + "additionalProperties": false + }, "babysitting": { "type": "object", "properties": { diff --git a/bun.lock b/bun.lock index e99fefc76c..041ae4074d 100644 --- a/bun.lock +++ b/bun.lock @@ -29,17 +29,17 @@ "typescript": "^5.7.3", }, "optionalDependencies": { - "oh-my-opencode-darwin-arm64": "3.10.0", - "oh-my-opencode-darwin-x64": "3.10.0", - "oh-my-opencode-darwin-x64-baseline": "3.10.0", - "oh-my-opencode-linux-arm64": "3.10.0", - "oh-my-opencode-linux-arm64-musl": "3.10.0", - "oh-my-opencode-linux-x64": "3.10.0", - "oh-my-opencode-linux-x64-baseline": "3.10.0", - "oh-my-opencode-linux-x64-musl": "3.10.0", - "oh-my-opencode-linux-x64-musl-baseline": "3.10.0", - "oh-my-opencode-windows-x64": "3.10.0", - "oh-my-opencode-windows-x64-baseline": "3.10.0", + "oh-my-opencode-darwin-arm64": "3.10.1", + "oh-my-opencode-darwin-x64": "3.10.1", + 
"oh-my-opencode-darwin-x64-baseline": "3.10.1", + "oh-my-opencode-linux-arm64": "3.10.1", + "oh-my-opencode-linux-arm64-musl": "3.10.1", + "oh-my-opencode-linux-x64": "3.10.1", + "oh-my-opencode-linux-x64-baseline": "3.10.1", + "oh-my-opencode-linux-x64-musl": "3.10.1", + "oh-my-opencode-linux-x64-musl-baseline": "3.10.1", + "oh-my-opencode-windows-x64": "3.10.1", + "oh-my-opencode-windows-x64-baseline": "3.10.1", }, }, }, @@ -238,28 +238,6 @@ "object-inspect": ["object-inspect@1.13.4", "", {}, "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew=="], - "oh-my-opencode-darwin-arm64": ["oh-my-opencode-darwin-arm64@3.10.0", "", { "os": "darwin", "cpu": "arm64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-KQ1Nva4eU03WIaQI8BiEgizYJAeddUIaC8dmks0Ug/2EkH6VyNj41+shI58HFGN9Jlg9Fd6MxpOW92S3JUHjOw=="], - - "oh-my-opencode-darwin-x64": ["oh-my-opencode-darwin-x64@3.10.0", "", { "os": "darwin", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-PydZ6wKyLZzikSZA3Q89zKZwFyg0Ouqd/S6zDsf1zzpUWT1t5EcpBtYFwuscD7L4hdkIEFm8wxnnBkz5i6BEiA=="], - - "oh-my-opencode-darwin-x64-baseline": ["oh-my-opencode-darwin-x64-baseline@3.10.0", "", { "os": "darwin", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-yOaVd0E1qspT2xP/BMJaJ/rpFTwkOh9U/SAk6uOuxHld6dZGI9e2Oq8F3pSD16xHnnpaz4VzadtT6HkvPdtBYg=="], - - "oh-my-opencode-linux-arm64": ["oh-my-opencode-linux-arm64@3.10.0", "", { "os": "linux", "cpu": "arm64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-pLzcPMuzBb1tpVgqMilv7QdsE2xTMLCWT3b807mzjt0302fZTfm6emwymCG25RamHdq7+mI2B0rN7hjvbymFog=="], - - "oh-my-opencode-linux-arm64-musl": ["oh-my-opencode-linux-arm64-musl@3.10.0", "", { "os": "linux", "cpu": "arm64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-ca61zr+X8q0ipO2x72qU+4R6Dsr168OM9aXI6xDHbrr0l3XZlRO8xuwQidch1vE5QRv2/IJT10KjAFInCERDug=="], - - "oh-my-opencode-linux-x64": 
["oh-my-opencode-linux-x64@3.10.0", "", { "os": "linux", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-m0Ys8Vnl8jUNRE5/aIseNOF1H57/W77xh3vkyBVfnjzHwQdEUWZz3IdoHaEWIFgIP2+fsNXRHqpx7Pbtuhxo6Q=="], - - "oh-my-opencode-linux-x64-baseline": ["oh-my-opencode-linux-x64-baseline@3.10.0", "", { "os": "linux", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-a6OhfqMXhOTq1On8YHRRlVsNtMx84kgNAnStk/sY1Dw0kXU68QK4tWXVF+wNdiRG3egeM2SvjhJ5RhWlr3CCNQ=="], - - "oh-my-opencode-linux-x64-musl": ["oh-my-opencode-linux-x64-musl@3.10.0", "", { "os": "linux", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-lZkoEWwmrlVoZKewHNslUmQ2D6eWi1YqsoZMTd3qRj8V4XI6TDZHxg86hw4oxZ/EnKO4un+r83tb09JAAb1nNQ=="], - - "oh-my-opencode-linux-x64-musl-baseline": ["oh-my-opencode-linux-x64-musl-baseline@3.10.0", "", { "os": "linux", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-UqArUpatMuen8+hZhMSbScaSmJlcwkEtf/IzDN1iYO0CttvhyYMUmm3el/1gWTAcaGNDFNkGmTli5WNYhnm2lA=="], - - "oh-my-opencode-windows-x64": ["oh-my-opencode-windows-x64@3.10.0", "", { "os": "win32", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode.exe" } }, "sha512-BivOu1+Yty9N6VSmNzmxROZqjQKu3ImWjooKZDfczvYLDQmZV104QcOKV6bmdOCpHrqQ7cvdbygmeiJeRoYShg=="], - - "oh-my-opencode-windows-x64-baseline": ["oh-my-opencode-windows-x64-baseline@3.10.0", "", { "os": "win32", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode.exe" } }, "sha512-BBv+dNPuh9LEuqXUJLXNsvi3vL30zS1qcJuzlq/s8rYHry+VvEVXCRcMm5Vo0CVna8bUZf5U8MDkGDHOAiTeEw=="], - "on-finished": ["on-finished@2.4.1", "", { "dependencies": { "ee-first": "1.1.1" } }, "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg=="], "once": ["once@1.4.0", "", { "dependencies": { "wrappy": "1" } }, "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w=="], diff --git a/src/config/index.ts 
b/src/config/index.ts index 2f7f985783..995a42f3a7 100644 --- a/src/config/index.ts +++ b/src/config/index.ts @@ -9,6 +9,8 @@ export type { McpName, AgentName, HookName, + ModelSchedulerConfig, + ModelSchedulerMode, BuiltinCommandName, SisyphusAgentConfig, ExperimentalConfig, diff --git a/src/config/schema.ts b/src/config/schema.ts index 0d2c590ba2..5f5b5c31a4 100644 --- a/src/config/schema.ts +++ b/src/config/schema.ts @@ -12,6 +12,7 @@ export * from "./schema/experimental" export * from "./schema/fallback-models" export * from "./schema/git-master" export * from "./schema/hooks" +export * from "./schema/model-scheduler" export * from "./schema/notification" export * from "./schema/oh-my-opencode-config" export * from "./schema/ralph-loop" diff --git a/src/config/schema/model-scheduler.test.ts b/src/config/schema/model-scheduler.test.ts new file mode 100644 index 0000000000..098eac2958 --- /dev/null +++ b/src/config/schema/model-scheduler.test.ts @@ -0,0 +1,49 @@ +import { describe, expect, test } from "bun:test" +import { ZodError } from "zod/v4" +import { ModelSchedulerConfigSchema } from "./model-scheduler" + +describe("ModelSchedulerConfigSchema", () => { + test("parses valid scheduler config", () => { + const result = ModelSchedulerConfigSchema.parse({ + enabled: true, + interval_minutes: 60, + mode: "active", + preflight_on_session_created: true, + failure_threshold: 2, + recovery_threshold: 2, + agent_cooldown_minutes: 180, + protect_manual_routing: true, + probe_enabled: true, + probe_timeout_ms: 15000, + probe_max_latency_ms: 8000, + }) + + expect(result.mode).toBe("active") + expect(result.interval_minutes).toBe(60) + expect(result.probe_enabled).toBe(true) + }) + + test("rejects invalid interval", () => { + let thrownError: unknown + + try { + ModelSchedulerConfigSchema.parse({ interval_minutes: 0 }) + } catch (error) { + thrownError = error + } + + expect(thrownError).toBeInstanceOf(ZodError) + }) + + test("rejects invalid probe timeout", () => { + let 
thrownError: unknown + + try { + ModelSchedulerConfigSchema.parse({ probe_timeout_ms: 999 }) + } catch (error) { + thrownError = error + } + + expect(thrownError).toBeInstanceOf(ZodError) + }) +}) diff --git a/src/config/schema/model-scheduler.ts b/src/config/schema/model-scheduler.ts new file mode 100644 index 0000000000..37be2f2d67 --- /dev/null +++ b/src/config/schema/model-scheduler.ts @@ -0,0 +1,20 @@ +import { z } from "zod" + +export const ModelSchedulerModeSchema = z.enum(["observe", "dry-run", "active"]) + +export const ModelSchedulerConfigSchema = z.object({ + enabled: z.boolean().optional(), + interval_minutes: z.number().int().min(1).max(24 * 60).optional(), + mode: ModelSchedulerModeSchema.optional(), + preflight_on_session_created: z.boolean().optional(), + failure_threshold: z.number().int().min(1).max(10).optional(), + recovery_threshold: z.number().int().min(1).max(10).optional(), + agent_cooldown_minutes: z.number().int().min(0).max(24 * 60).optional(), + protect_manual_routing: z.boolean().optional(), + probe_enabled: z.boolean().optional(), + probe_timeout_ms: z.number().int().min(1000).max(300000).optional(), + probe_max_latency_ms: z.number().int().min(100).max(300000).optional(), +}) + +export type ModelSchedulerMode = z.infer +export type ModelSchedulerConfig = z.infer diff --git a/src/config/schema/oh-my-opencode-config.ts b/src/config/schema/oh-my-opencode-config.ts index d24bbef4d5..ddd9207b0a 100644 --- a/src/config/schema/oh-my-opencode-config.ts +++ b/src/config/schema/oh-my-opencode-config.ts @@ -12,6 +12,7 @@ import { BuiltinCommandNameSchema } from "./commands" import { ExperimentalConfigSchema } from "./experimental" import { GitMasterConfigSchema } from "./git-master" import { NotificationConfigSchema } from "./notification" +import { ModelSchedulerConfigSchema } from "./model-scheduler" import { RalphLoopConfigSchema } from "./ralph-loop" import { RuntimeFallbackConfigSchema } from "./runtime-fallback" import { SkillsConfigSchema 
} from "./skills" @@ -55,6 +56,7 @@ export const OhMyOpenCodeConfigSchema = z.object({ runtime_fallback: z.union([z.boolean(), RuntimeFallbackConfigSchema]).optional(), background_task: BackgroundTaskConfigSchema.optional(), notification: NotificationConfigSchema.optional(), + model_scheduler: ModelSchedulerConfigSchema.optional(), babysitting: BabysittingConfigSchema.optional(), git_master: GitMasterConfigSchema.optional(), browser_automation_engine: BrowserAutomationConfigSchema.optional(), diff --git a/src/features/model-scheduler/candidate-models.ts b/src/features/model-scheduler/candidate-models.ts new file mode 100644 index 0000000000..0ecdef1f38 --- /dev/null +++ b/src/features/model-scheduler/candidate-models.ts @@ -0,0 +1,59 @@ +import { + AGENT_MODEL_REQUIREMENTS, + CATEGORY_MODEL_REQUIREMENTS, + fuzzyMatchModel, +} from "../../shared" +import type { RoutingEntry, RoutingTargetKind } from "./types" + +function normalizeKey(value: string): string { + return value.trim().toLowerCase().replace(/[^a-z0-9]+/g, "-").replace(/^-+|-+$/g, "") +} + +function resolveCandidate(model: string | null | undefined, availableModels: Set): string[] { + if (!model) return [] + + for (const availableModel of availableModels) { + if (availableModel.toLowerCase() === model.trim().toLowerCase()) { + return [availableModel] + } + } + + return [] +} + +export function collectCandidateModels(args: { + kind: RoutingTargetKind + key: string + routingEntry?: RoutingEntry | null + currentModel: string | null + availableModels: Set +}): string[] { + const resolvedCandidates = new Set() + const pushResolved = (model: string | null | undefined) => { + for (const resolved of resolveCandidate(model, args.availableModels)) { + resolvedCandidates.add(resolved) + } + } + + pushResolved(args.currentModel) + for (const fallback of args.routingEntry?.fallback ?? []) { + pushResolved(fallback) + } + + const requirements = args.kind === "agent" + ? 
AGENT_MODEL_REQUIREMENTS[normalizeKey(args.key)] + : CATEGORY_MODEL_REQUIREMENTS[normalizeKey(args.key)] + + for (const fallbackEntry of requirements?.fallbackChain ?? []) { + const matchedModel = fuzzyMatchModel( + fallbackEntry.model, + args.availableModels, + fallbackEntry.providers, + ) + if (matchedModel) { + resolvedCandidates.add(matchedModel) + } + } + + return Array.from(resolvedCandidates) +} diff --git a/src/features/model-scheduler/constants.ts b/src/features/model-scheduler/constants.ts new file mode 100644 index 0000000000..264f11f5ac --- /dev/null +++ b/src/features/model-scheduler/constants.ts @@ -0,0 +1,17 @@ +export const MODEL_HEALTH_FILE = "model-health.json" +export const MODEL_SCHEDULER_AUDIT_FILE = "scheduler-audit.jsonl" +export const MODEL_ROUTING_FILE = "model-routing.json" + +export const DEFAULT_MODEL_SCHEDULER_CONFIG = { + enabled: true, + interval_minutes: 60, + mode: "active", + preflight_on_session_created: true, + failure_threshold: 1, + recovery_threshold: 1, + agent_cooldown_minutes: 180, + protect_manual_routing: true, + probe_enabled: true, + probe_timeout_ms: 15000, + probe_max_latency_ms: 8000, +} as const diff --git a/src/features/model-scheduler/health-store.ts b/src/features/model-scheduler/health-store.ts new file mode 100644 index 0000000000..c935c98499 --- /dev/null +++ b/src/features/model-scheduler/health-store.ts @@ -0,0 +1,60 @@ +import { existsSync, mkdirSync, readFileSync, renameSync, unlinkSync, writeFileSync } from "node:fs" +import { dirname, join } from "node:path" +import { getOmoOpenCodeCacheDir } from "../../shared" +import { MODEL_HEALTH_FILE, MODEL_SCHEDULER_AUDIT_FILE } from "./constants" +import type { ModelHealthSnapshot, ModelSchedulerAuditEntry } from "./types" + +function ensureParentDir(filePath: string): void { + const parentDir = dirname(filePath) + if (!existsSync(parentDir)) { + mkdirSync(parentDir, { recursive: true }) + } +} + +function writeJsonAtomic(filePath: string, data: unknown): void { 
+ ensureParentDir(filePath) + const tempPath = `${filePath}.tmp.${Date.now()}` + + try { + writeFileSync(tempPath, JSON.stringify(data, null, 2), "utf-8") + renameSync(tempPath, filePath) + } catch (error) { + if (existsSync(tempPath)) { + unlinkSync(tempPath) + } + throw error + } +} + +export function getModelHealthFilePath(): string { + return join(getOmoOpenCodeCacheDir(), MODEL_HEALTH_FILE) +} + +export function getModelSchedulerAuditFilePath(): string { + return join(getOmoOpenCodeCacheDir(), MODEL_SCHEDULER_AUDIT_FILE) +} + +export function readModelHealthSnapshot(): ModelHealthSnapshot | null { + const filePath = getModelHealthFilePath() + if (!existsSync(filePath)) return null + + try { + const raw = readFileSync(filePath, "utf-8") + return JSON.parse(raw) as ModelHealthSnapshot + } catch { + return null + } +} + +export function writeModelHealthSnapshot(snapshot: ModelHealthSnapshot): void { + writeJsonAtomic(getModelHealthFilePath(), snapshot) +} + +export function appendModelSchedulerAuditEntry(entry: ModelSchedulerAuditEntry): void { + const filePath = getModelSchedulerAuditFilePath() + ensureParentDir(filePath) + writeFileSync(filePath, `${JSON.stringify(entry)}\n`, { + encoding: "utf-8", + flag: "a", + }) +} diff --git a/src/features/model-scheduler/index.ts b/src/features/model-scheduler/index.ts new file mode 100644 index 0000000000..2cd7a94712 --- /dev/null +++ b/src/features/model-scheduler/index.ts @@ -0,0 +1,8 @@ +export * from "./constants" +export * from "./candidate-models" +export * from "./health-store" +export * from "./model-probe" +export * from "./routing-store" +export * from "./scheduler" +export * from "./selector" +export * from "./types" diff --git a/src/features/model-scheduler/model-probe.ts b/src/features/model-scheduler/model-probe.ts new file mode 100644 index 0000000000..680f82c1f8 --- /dev/null +++ b/src/features/model-scheduler/model-probe.ts @@ -0,0 +1,245 @@ +import { normalizeSDKResponse } from "../../shared" +import { 
log } from "../../shared/logger" +import { normalizeModelFormat } from "../../shared/model-format-normalizer" +import { createPromptTimeoutContext } from "../../shared/prompt-timeout-context" +import type { ModelProbeResult } from "./types" + +type SchedulerSessionClient = { + create?: (args: { + body: { title: string; permission: Array<{ permission: string; action: "deny"; pattern: string }> } + query: { directory: string } + }) => Promise<{ data?: { id?: string }; error?: unknown }> + delete?: (args: { path: { id: string } }) => Promise + prompt?: (args: { + path: { id: string } + body: { + parts: Array<{ type: "text"; text: string }> + tools: { task: boolean; call_omo_agent: boolean; look_at: boolean; read: boolean; question: boolean } + model: { providerID: string; modelID: string } + } + signal?: AbortSignal + }) => Promise + messages?: (args: { path: { id: string } }) => Promise +} + +type SchedulerProbeClient = { + session?: SchedulerSessionClient +} + +type ProbeContext = { + directory: string + client: SchedulerProbeClient +} + +type ProbeConfig = { + probe_enabled: boolean + probe_timeout_ms: number + probe_max_latency_ms: number +} + +const MODEL_SCHEDULER_PROBE_PROMPT = "Reply with exactly OK." 
+ +function isObject(value: unknown): value is Record { + return typeof value === "object" && value !== null +} + +function extractAssistantText(messages: unknown): string | null { + if (!Array.isArray(messages)) return null + + const assistantMessages = messages + .filter((message): message is Record => isObject(message)) + .filter((message) => { + const info = message["info"] + return isObject(info) && info["role"] === "assistant" + }) + + const lastAssistantMessage = assistantMessages.at(-1) + if (!lastAssistantMessage) return null + + const parts = lastAssistantMessage["parts"] + if (!Array.isArray(parts)) return null + + const text = parts + .filter((part): part is Record => isObject(part)) + .filter((part) => part["type"] === "text" && typeof part["text"] === "string") + .map((part) => part["text"]) + .join("\n") + .trim() + + return text.length > 0 ? text : null +} + +function createSkippedProbeResult(model: string, checkedAt: string): ModelProbeResult { + return { + model, + available: true, + status: "skipped", + checkedAt, + } +} + +export function createModelProbeRunner(ctx: ProbeContext, config: ProbeConfig, availableModels: Set) { + const memoizedResults = new Map>() + + const probeModel = async (model: string): Promise => { + const checkedAt = new Date().toISOString() + if (!availableModels.has(model)) { + return { + model, + available: false, + status: "unavailable", + checkedAt, + } + } + + if (!config.probe_enabled) { + return createSkippedProbeResult(model, checkedAt) + } + + const sessionClient = ctx.client.session + if (!sessionClient?.create || !sessionClient.prompt) { + return createSkippedProbeResult(model, checkedAt) + } + + let probeSessionId: string | undefined + let timeoutContext: ReturnType | undefined + + const normalizedModel = normalizeModelFormat(model) + if (!normalizedModel) { + return { + model, + available: true, + status: "error", + checkedAt, + error: "invalid model format", + } + } + + try { + const createResult = await 
sessionClient.create({ + body: { + title: `model scheduler probe: ${model}`, + permission: [{ permission: "question", action: "deny", pattern: "*" }], + }, + query: { directory: ctx.directory }, + }) + + if (createResult.error || !createResult.data?.id) { + return { + model, + available: true, + status: "error", + checkedAt, + error: createResult.error ? String(createResult.error) : "missing probe session id", + } + } + + probeSessionId = createResult.data.id + timeoutContext = createPromptTimeoutContext({}, config.probe_timeout_ms) + const startedAt = Date.now() + + await sessionClient.prompt({ + path: { id: probeSessionId }, + body: { + parts: [{ type: "text", text: MODEL_SCHEDULER_PROBE_PROMPT }], + tools: { + task: false, + call_omo_agent: false, + look_at: false, + read: false, + question: false, + }, + model: normalizedModel, + }, + signal: timeoutContext.signal, + }) + + if (timeoutContext.wasTimedOut()) { + return { + model, + available: true, + status: "timeout", + checkedAt, + error: `probe timed out after ${config.probe_timeout_ms}ms`, + } + } + + const latencyMs = Date.now() - startedAt + if (sessionClient.messages) { + const messagesResult = await sessionClient.messages({ path: { id: probeSessionId } }) + const messages = normalizeSDKResponse(messagesResult, [] as unknown[], { + preferResponseOnMissingData: true, + }) + if (!extractAssistantText(messages)) { + return { + model, + available: true, + status: "error", + checkedAt, + latencyMs, + error: "probe response contained no assistant text", + } + } + } + + if (latencyMs > config.probe_max_latency_ms) { + return { + model, + available: true, + status: "slow", + checkedAt, + latencyMs, + error: `probe latency ${latencyMs}ms exceeded ${config.probe_max_latency_ms}ms`, + } + } + + return { + model, + available: true, + status: "healthy", + checkedAt, + latencyMs, + } + } catch (error) { + return { + model, + available: true, + status: timeoutContext?.wasTimedOut() ? 
"timeout" : "error", + checkedAt, + error: timeoutContext?.wasTimedOut() + ? `probe timed out after ${config.probe_timeout_ms}ms` + : error instanceof Error + ? error.message + : String(error), + } + } finally { + timeoutContext?.cleanup() + if (probeSessionId && sessionClient.delete) { + try { + await sessionClient.delete({ path: { id: probeSessionId } }) + } catch (error) { + log( + `[model-scheduler] failed to delete probe session ${probeSessionId}: ${error instanceof Error ? error.message : String(error)}`, + ) + } + } + } + } + + return { + async probeModels(models: string[]): Promise> { + const uniqueModels = Array.from(new Set(models.filter((model) => model.length > 0))).sort() + + const entries = await Promise.all(uniqueModels.map(async (model) => { + const existingProbe = memoizedResults.get(model) + const probePromise = existingProbe ?? probeModel(model) + if (!existingProbe) { + memoizedResults.set(model, probePromise) + } + return [model, await probePromise] as const + })) + + return Object.fromEntries(entries) + }, + } +} diff --git a/src/features/model-scheduler/routing-store.ts b/src/features/model-scheduler/routing-store.ts new file mode 100644 index 0000000000..f2d352eedc --- /dev/null +++ b/src/features/model-scheduler/routing-store.ts @@ -0,0 +1,42 @@ +import { existsSync, mkdirSync, readFileSync, renameSync, unlinkSync, writeFileSync } from "node:fs" +import { dirname, join } from "node:path" +import { getOpenCodeConfigDir } from "../../shared" +import { MODEL_ROUTING_FILE } from "./constants" +import type { ModelRoutingFile } from "./types" + +function writeJsonAtomic(filePath: string, data: unknown): void { + const parentDir = dirname(filePath) + if (!existsSync(parentDir)) { + mkdirSync(parentDir, { recursive: true }) + } + + const tempPath = `${filePath}.tmp.${Date.now()}` + try { + writeFileSync(tempPath, JSON.stringify(data, null, 2), "utf-8") + renameSync(tempPath, filePath) + } catch (error) { + if (existsSync(tempPath)) { + 
unlinkSync(tempPath) + } + throw error + } +} + +export function getModelRoutingFilePath(): string { + return join(getOpenCodeConfigDir({ binary: "opencode", version: null }), MODEL_ROUTING_FILE) +} + +export function readModelRoutingFile(): ModelRoutingFile | null { + const filePath = getModelRoutingFilePath() + if (!existsSync(filePath)) return null + + try { + return JSON.parse(readFileSync(filePath, "utf-8")) as ModelRoutingFile + } catch { + return null + } +} + +export function writeModelRoutingFile(routing: ModelRoutingFile): void { + writeJsonAtomic(getModelRoutingFilePath(), routing) +} diff --git a/src/features/model-scheduler/scheduler.test.ts b/src/features/model-scheduler/scheduler.test.ts new file mode 100644 index 0000000000..2d8e9d7c52 --- /dev/null +++ b/src/features/model-scheduler/scheduler.test.ts @@ -0,0 +1,376 @@ +import { afterEach, describe, expect, it } from "bun:test" +import { existsSync, mkdtempSync, readFileSync, rmSync, writeFileSync } from "node:fs" +import { join } from "node:path" +import { tmpdir } from "node:os" +import { getModelHealthFilePath, getModelRoutingFilePath, runModelSchedulerCycle } from "./index" + +const tempDirs: string[] = [] + +function createTempDir(prefix: string): string { + const directory = mkdtempSync(join(tmpdir(), prefix)) + tempDirs.push(directory) + return directory +} + +afterEach(() => { + delete process.env.OPENCODE_CONFIG_DIR + delete process.env.XDG_CACHE_HOME + + while (tempDirs.length > 0) { + const directory = tempDirs.pop() + if (directory) { + rmSync(directory, { recursive: true, force: true }) + } + } +}) + +describe("model scheduler", () => { + it("rewrites unhealthy agent and category routing in active mode", async () => { + const configDir = createTempDir("omo-model-scheduler-config-") + const cacheHome = createTempDir("omo-model-scheduler-cache-") + process.env.OPENCODE_CONFIG_DIR = configDir + process.env.XDG_CACHE_HOME = cacheHome + + writeFileSync( + join(configDir, 
"model-routing.json"), + JSON.stringify({ + agentModelMapping: { + Hephaestus: { + primary: "augment-pro/gpt-5.4", + fallback: ["augment-pro/gpt-5.3-codex"], + }, + }, + categoryRouting: { + quick: "augment-pro/claude-haiku-4-5", + }, + }), + "utf-8", + ) + + const ctx = { + directory: "/project", + client: { + provider: { + list: async () => ({ + data: { + connected: ["augment-pro"], + all: [ + { + id: "augment-pro", + models: { + "gpt-5.3-codex": {}, + "claude-haiku-4-5": {}, + }, + }, + ], + }, + }), + }, + session: { + create: async () => ({ data: { id: "probe-session-1" } }), + prompt: async () => {}, + messages: async () => ({ + data: [{ info: { role: "assistant" }, parts: [{ type: "text", text: "OK" }] }], + }), + }, + }, + } + + const result = await runModelSchedulerCycle(ctx, { + mode: "active", + interval_minutes: 60, + }) + + expect(result).not.toBeNull() + + const routing = JSON.parse(readFileSync(getModelRoutingFilePath(), "utf-8")) as { + agentModelMapping: { Hephaestus: { primary: string; fallback: string[] } } + categoryRouting: { quick: string } + scheduler: { lastChangeCount: number } + } + expect(routing.agentModelMapping.Hephaestus.primary).toBe("augment-pro/gpt-5.3-codex") + expect(routing.categoryRouting.quick).toBe("augment-pro/claude-haiku-4-5") + expect(routing.scheduler.lastChangeCount).toBe(1) + expect(existsSync(getModelHealthFilePath())).toBe(true) + + const health = JSON.parse(readFileSync(getModelHealthFilePath(), "utf-8")) as { + probe: { enabled: boolean; checkedModelCount: number; models: Record } + } + expect(health.probe.enabled).toBe(true) + expect(health.probe.checkedModelCount).toBeGreaterThan(0) + expect(health.probe.models["augment-pro/gpt-5.3-codex"].status).toBe("healthy") + }) + + it("records health state without rewriting routing in observe mode", async () => { + const configDir = createTempDir("omo-model-scheduler-observe-config-") + const cacheHome = createTempDir("omo-model-scheduler-observe-cache-") + 
process.env.OPENCODE_CONFIG_DIR = configDir + process.env.XDG_CACHE_HOME = cacheHome + + writeFileSync( + join(configDir, "model-routing.json"), + JSON.stringify({ + agentModelMapping: { + Explore: { + primary: "augment-pro/gpt-5.2-codex", + fallback: ["augment-pro/claude-haiku-4-5"], + }, + }, + }), + "utf-8", + ) + + const ctx = { + directory: "/project", + client: { + provider: { + list: async () => ({ + data: { + connected: ["augment-pro"], + all: [ + { + id: "augment-pro", + models: { + "claude-haiku-4-5": {}, + }, + }, + ], + }, + }), + }, + session: { + create: async () => ({ data: { id: "probe-session-2" } }), + prompt: async () => {}, + messages: async () => ({ + data: [{ info: { role: "assistant" }, parts: [{ type: "text", text: "OK" }] }], + }), + }, + }, + } + + const result = await runModelSchedulerCycle(ctx, { + mode: "observe", + interval_minutes: 60, + }) + + expect(result?.auditEntry.changed).toBe(false) + + const routing = JSON.parse(readFileSync(getModelRoutingFilePath(), "utf-8")) as { + agentModelMapping: { Explore: { primary: string } } + } + expect(routing.agentModelMapping.Explore.primary).toBe("augment-pro/gpt-5.2-codex") + + const health = JSON.parse(readFileSync(getModelHealthFilePath(), "utf-8")) as { + agents: { Explore: { selectedModel: string; status: string } } + } + expect(health.agents.Explore.selectedModel).toBe("augment-pro/claude-haiku-4-5") + expect(health.agents.Explore.status).toBe("healthy") + }) + + it("treats slow probe results as unhealthy and reroutes away from them", async () => { + const configDir = createTempDir("omo-model-scheduler-probe-config-") + const cacheHome = createTempDir("omo-model-scheduler-probe-cache-") + process.env.OPENCODE_CONFIG_DIR = configDir + process.env.XDG_CACHE_HOME = cacheHome + + writeFileSync( + join(configDir, "model-routing.json"), + JSON.stringify({ + agentModelMapping: { + Hephaestus: { + primary: "augment-pro/gpt-5.4", + fallback: ["augment-pro/gpt-5.3-codex"], + }, + }, + }), + 
"utf-8", + ) + + const promptCalls: string[] = [] + const ctx = { + directory: "/project", + client: { + provider: { + list: async () => ({ + data: { + connected: ["augment-pro"], + all: [ + { + id: "augment-pro", + models: { + "gpt-5.4": {}, + "gpt-5.3-codex": {}, + }, + }, + ], + }, + }), + }, + session: { + create: async () => ({ data: { id: `probe-session-${promptCalls.length + 1}` } }), + prompt: async (args: { body: { model: { modelID: string } } }) => { + promptCalls.push(args.body.model.modelID) + if (args.body.model.modelID === "gpt-5.4") { + await new Promise((resolve) => setTimeout(resolve, 25)) + } + }, + messages: async () => ({ + data: [{ info: { role: "assistant" }, parts: [{ type: "text", text: "OK" }] }], + }), + }, + }, + } + + await runModelSchedulerCycle(ctx, { + mode: "active", + interval_minutes: 60, + probe_timeout_ms: 1000, + probe_max_latency_ms: 10, + }) + + const routing = JSON.parse(readFileSync(getModelRoutingFilePath(), "utf-8")) as { + agentModelMapping: { Hephaestus: { primary: string } } + } + const health = JSON.parse(readFileSync(getModelHealthFilePath(), "utf-8")) as { + agents: { Hephaestus: { status: string; reason: string; probeStatus: string } } + probe: { models: Record } + } + + expect(routing.agentModelMapping.Hephaestus.primary).toBe("augment-pro/gpt-5.3-codex") + expect(health.agents.Hephaestus.reason).toBe("latency-too-high") + expect(health.agents.Hephaestus.probeStatus).toBe("healthy") + expect(health.probe.models["augment-pro/gpt-5.4"].status).toBe("slow") + }) + + it("cleans up probe sessions after each scheduler cycle", async () => { + const configDir = createTempDir("omo-model-scheduler-cleanup-config-") + const cacheHome = createTempDir("omo-model-scheduler-cleanup-cache-") + process.env.OPENCODE_CONFIG_DIR = configDir + process.env.XDG_CACHE_HOME = cacheHome + + writeFileSync( + join(configDir, "model-routing.json"), + JSON.stringify({ + agentModelMapping: { + Hephaestus: { + primary: "augment-pro/gpt-5.4", + 
fallback: ["augment-pro/gpt-5.3-codex"], + }, + }, + }), + "utf-8", + ) + + const deletedSessionIds: string[] = [] + const ctx = { + directory: "/project", + client: { + provider: { + list: async () => ({ + data: { + connected: ["augment-pro"], + all: [ + { + id: "augment-pro", + models: { + "gpt-5.4": {}, + "gpt-5.3-codex": {}, + }, + }, + ], + }, + }), + }, + session: { + create: async (args: { body: { title: string } }) => ({ + data: { id: `probe-session-${args.body.title.split(": ")[1]}` }, + }), + delete: async (args: { path: { id: string } }) => { + deletedSessionIds.push(args.path.id) + return {} + }, + prompt: async () => {}, + messages: async () => ({ + data: [{ info: { role: "assistant" }, parts: [{ type: "text", text: "OK" }] }], + }), + }, + }, + } + + await runModelSchedulerCycle(ctx, { + mode: "active", + interval_minutes: 60, + }) + + expect(deletedSessionIds).toContain("probe-session-augment-pro/gpt-5.4") + expect(deletedSessionIds).toContain("probe-session-augment-pro/gpt-5.3-codex") + expect(deletedSessionIds).toHaveLength(2) + }) + + it("probes candidate models in parallel within a single cycle", async () => { + const configDir = createTempDir("omo-model-scheduler-parallel-config-") + const cacheHome = createTempDir("omo-model-scheduler-parallel-cache-") + process.env.OPENCODE_CONFIG_DIR = configDir + process.env.XDG_CACHE_HOME = cacheHome + + writeFileSync( + join(configDir, "model-routing.json"), + JSON.stringify({ + agentModelMapping: { + Hephaestus: { + primary: "augment-pro/gpt-5.4", + fallback: ["augment-pro/gpt-5.3-codex"], + }, + }, + }), + "utf-8", + ) + + let activePrompts = 0 + let maxConcurrentPrompts = 0 + const ctx = { + directory: "/project", + client: { + provider: { + list: async () => ({ + data: { + connected: ["augment-pro"], + all: [ + { + id: "augment-pro", + models: { + "gpt-5.4": {}, + "gpt-5.3-codex": {}, + }, + }, + ], + }, + }), + }, + session: { + create: async (args: { body: { title: string } }) => ({ + data: { id: 
`probe-session-${args.body.title.split(": ")[1]}` }, + }), + prompt: async () => { + activePrompts += 1 + maxConcurrentPrompts = Math.max(maxConcurrentPrompts, activePrompts) + await new Promise((resolve) => setTimeout(resolve, 25)) + activePrompts -= 1 + }, + messages: async () => ({ + data: [{ info: { role: "assistant" }, parts: [{ type: "text", text: "OK" }] }], + }), + }, + }, + } + + await runModelSchedulerCycle(ctx, { + mode: "active", + interval_minutes: 60, + probe_timeout_ms: 1000, + probe_max_latency_ms: 1000, + }) + + expect(maxConcurrentPrompts).toBeGreaterThan(1) + }) +}) diff --git a/src/features/model-scheduler/scheduler.ts b/src/features/model-scheduler/scheduler.ts new file mode 100644 index 0000000000..663d556dae --- /dev/null +++ b/src/features/model-scheduler/scheduler.ts @@ -0,0 +1,483 @@ +import { log } from "../../shared/logger" +import { + fetchAvailableModels, + readProviderModelsCache, + updateConnectedProvidersCache, +} from "../../shared" +import type { ModelSchedulerConfig } from "../../config" +import { collectCandidateModels } from "./candidate-models" +import { DEFAULT_MODEL_SCHEDULER_CONFIG } from "./constants" +import { + appendModelSchedulerAuditEntry, + readModelHealthSnapshot, + writeModelHealthSnapshot, +} from "./health-store" +import { createModelProbeRunner } from "./model-probe" +import { readModelRoutingFile, writeModelRoutingFile } from "./routing-store" +import { buildNextFallbackList, isModelHealthy, selectReplacementModel } from "./selector" +import type { + ModelHealthSnapshot, + ModelProbeResult, + ModelSchedulerAuditEntry, + ResolvedModelSchedulerConfig, + SchedulerChangeReason, + RoutingTargetHealth, + RoutingTargetKind, + SchedulerRunResult, +} from "./types" + +type SchedulerClient = { + provider?: { + list?: () => Promise<{ + data?: { + connected?: string[] + all?: Array<{ id: string; models?: Record }> + } + }> + } + model?: { + list?: () => Promise + } + tui?: { + showToast?: (input: { + body: { + title: 
string + message: string + variant: "info" | "success" | "warning" | "error" + duration?: number + } + }) => Promise + } + session?: { + create?: (args: { + body: { title: string; permission: Array<{ permission: string; action: "deny"; pattern: string }> } + query: { directory: string } + }) => Promise<{ data?: { id?: string }; error?: unknown }> + prompt?: (args: { + path: { id: string } + body: { + parts: Array<{ type: "text"; text: string }> + tools: { task: boolean; call_omo_agent: boolean; look_at: boolean; read: boolean; question: boolean } + model: { providerID: string; modelID: string } + } + signal?: AbortSignal + }) => Promise + messages?: (args: { path: { id: string } }) => Promise + } +} + +type SchedulerContext = { + directory: string + client: SchedulerClient +} + +function resolveConfig(config?: ModelSchedulerConfig): ResolvedModelSchedulerConfig { + return { + enabled: config?.enabled ?? DEFAULT_MODEL_SCHEDULER_CONFIG.enabled, + interval_minutes: config?.interval_minutes ?? DEFAULT_MODEL_SCHEDULER_CONFIG.interval_minutes, + mode: config?.mode ?? DEFAULT_MODEL_SCHEDULER_CONFIG.mode, + preflight_on_session_created: + config?.preflight_on_session_created ?? DEFAULT_MODEL_SCHEDULER_CONFIG.preflight_on_session_created, + failure_threshold: config?.failure_threshold ?? DEFAULT_MODEL_SCHEDULER_CONFIG.failure_threshold, + recovery_threshold: config?.recovery_threshold ?? DEFAULT_MODEL_SCHEDULER_CONFIG.recovery_threshold, + agent_cooldown_minutes: + config?.agent_cooldown_minutes ?? DEFAULT_MODEL_SCHEDULER_CONFIG.agent_cooldown_minutes, + protect_manual_routing: + config?.protect_manual_routing ?? DEFAULT_MODEL_SCHEDULER_CONFIG.protect_manual_routing, + probe_enabled: config?.probe_enabled ?? DEFAULT_MODEL_SCHEDULER_CONFIG.probe_enabled, + probe_timeout_ms: config?.probe_timeout_ms ?? DEFAULT_MODEL_SCHEDULER_CONFIG.probe_timeout_ms, + probe_max_latency_ms: + config?.probe_max_latency_ms ?? 
DEFAULT_MODEL_SCHEDULER_CONFIG.probe_max_latency_ms, + } +} + +function getModelFailureReason(probeResult: ModelProbeResult | undefined): SchedulerChangeReason { + if (!probeResult || !probeResult.available || probeResult.status === "unavailable") { + return "unavailable" + } + if (probeResult.status === "slow") { + return "latency-too-high" + } + if (probeResult.status === "error" || probeResult.status === "timeout") { + return "probe-failed" + } + return "unavailable" +} + +function getRoutingTargetStatus(args: { + effectiveHealthy: boolean + selectedHealthy: boolean + failureReason: SchedulerChangeReason +}): RoutingTargetHealth["status"] { + if (args.effectiveHealthy || args.selectedHealthy) { + return "healthy" + } + if (args.failureReason === "latency-too-high") { + return "degraded" + } + return "offline" +} + +function diffInventory( + previous: ReturnType, + current: ReturnType, +): ModelHealthSnapshot["inventory"] { + const added: Record = {} + const removed: Record = {} + + const providerIds = new Set([ + ...Object.keys(previous?.models ?? {}), + ...Object.keys(current?.models ?? {}), + ]) + + for (const providerId of providerIds) { + const previousModels = new Set( + ((previous?.models?.[providerId] ?? []) as Array).flatMap((entry) => + typeof entry === "string" ? [entry] : entry?.id ? [entry.id] : [], + ), + ) + const currentModels = new Set( + ((current?.models?.[providerId] ?? []) as Array).flatMap((entry) => + typeof entry === "string" ? [entry] : entry?.id ? 
[entry.id] : [], + ), + ) + + const addedModels = Array.from(currentModels).filter((model) => !previousModels.has(model)) + const removedModels = Array.from(previousModels).filter((model) => !currentModels.has(model)) + + if (addedModels.length > 0) added[providerId] = addedModels.sort() + if (removedModels.length > 0) removed[providerId] = removedModels.sort() + } + + return { added, removed } +} + +function getPreviousTargetHealth( + snapshot: ModelHealthSnapshot | null, + kind: RoutingTargetKind, + key: string, +): RoutingTargetHealth | null { + if (!snapshot) return null + return kind === "agent" ? snapshot.agents[key] ?? null : snapshot.categories[key] ?? null +} + +function getCooldownUntil(previous: RoutingTargetHealth | null, cooldownMinutes: number): string | undefined { + if (cooldownMinutes <= 0) return undefined + if (!previous?.changed) return undefined + const baseTime = previous.checkedAt ? Date.parse(previous.checkedAt) : Number.NaN + if (Number.isNaN(baseTime)) return undefined + return new Date(baseTime + cooldownMinutes * 60_000).toISOString() +} + +function isCooldownActive(cooldownUntil: string | undefined, nowIso: string): boolean { + if (!cooldownUntil) return false + return Date.parse(cooldownUntil) > Date.parse(nowIso) +} + +function nextHealthRecord(args: { + key: string + displayName: string + kind: RoutingTargetKind + currentModel: string | null + selectedModel: string | null + isHealthy: boolean + selectedHealthy: boolean + changed: boolean + status: RoutingTargetHealth["status"] + reason: RoutingTargetHealth["reason"] + checkedAt: string + previous: RoutingTargetHealth | null + cooldownUntil?: string + probeResult?: ModelProbeResult + selectedProbeResult?: ModelProbeResult +}): RoutingTargetHealth { + const previousFailures = args.previous?.consecutiveFailures ?? 0 + const previousSuccesses = args.previous?.consecutiveSuccesses ?? 
0 + const currentProbe = args.probeResult + const selectedProbe = args.selectedProbeResult + + return { + key: args.key, + displayName: args.displayName, + kind: args.kind, + currentModel: args.currentModel, + status: args.status, + selectedModel: args.selectedModel, + changed: args.changed, + reason: args.reason, + checkedAt: args.checkedAt, + ...(args.cooldownUntil ? { cooldownUntil: args.cooldownUntil } : {}), + consecutiveFailures: args.status === "healthy" ? 0 : previousFailures + 1, + consecutiveSuccesses: args.status === "healthy" ? previousSuccesses + 1 : 0, + availabilityStatus: currentProbe?.available === false ? "unavailable" : "available", + probeStatus: selectedProbe?.status ?? currentProbe?.status, + probeLatencyMs: selectedProbe?.latencyMs ?? currentProbe?.latencyMs, + probeError: selectedProbe?.error ?? currentProbe?.error, + } +} + +export async function runModelSchedulerCycle( + ctx: SchedulerContext, + config?: ModelSchedulerConfig, +): Promise { + const resolvedConfig = resolveConfig(config) + if (!resolvedConfig.enabled) { + log("[model-scheduler] skipped because scheduler is disabled") + return null + } + + const previousSnapshot = readModelHealthSnapshot() + const previousProviderModels = readProviderModelsCache() + + await updateConnectedProvidersCache(ctx.client) + + const currentProviderModels = readProviderModelsCache() + const availableModels = await fetchAvailableModels(ctx.client) + const connectedProviders = currentProviderModels?.connected ?? [] + const routing = readModelRoutingFile() ?? {} + const nextRouting = structuredClone(routing) + const nowIso = new Date().toISOString() + const changes: ModelSchedulerAuditEntry["changes"] = [] + const agentHealth: Record = {} + const categoryHealth: Record = {} + const modelsToProbe = new Set() + + for (const [agentName, entry] of Object.entries(routing.agentModelMapping ?? 
{})) { + for (const model of collectCandidateModels({ + kind: "agent", + key: agentName, + routingEntry: entry, + currentModel: entry.primary ?? null, + availableModels, + })) { + modelsToProbe.add(model) + } + } + + for (const [categoryName, model] of Object.entries(routing.categoryRouting ?? {})) { + for (const candidate of collectCandidateModels({ + kind: "category", + key: categoryName, + currentModel: model, + availableModels, + })) { + modelsToProbe.add(candidate) + } + } + + const probeRunner = createModelProbeRunner(ctx, resolvedConfig, availableModels) + const probeResults = await probeRunner.probeModels(Array.from(modelsToProbe)) + const healthyModels = new Set( + Object.values(probeResults) + .filter((result) => result.status === "healthy" || result.status === "skipped") + .map((result) => result.model), + ) + + for (const [agentName, entry] of Object.entries(routing.agentModelMapping ?? {})) { + const currentModel = entry.primary ?? null + const previous = getPreviousTargetHealth(previousSnapshot, "agent", agentName) + const cooldownUntil = getCooldownUntil(previous, resolvedConfig.agent_cooldown_minutes) + const currentProbe = currentModel ? probeResults[currentModel] : undefined + const healthyCurrent = isModelHealthy(currentModel, healthyModels) + const previousFailures = previous?.consecutiveFailures ?? 0 + const previousSuccesses = previous?.consecutiveSuccesses ?? 0 + const nextSuccesses = healthyCurrent ? previousSuccesses + 1 : 0 + const meetsRecoveryThreshold = nextSuccesses >= resolvedConfig.recovery_threshold + const effectiveHealthy = healthyCurrent && (previousFailures === 0 || meetsRecoveryThreshold) + const nextFailures = effectiveHealthy ? 
0 : previousFailures + 1 + const meetsFailureThreshold = nextFailures >= resolvedConfig.failure_threshold + const failureReason = getModelFailureReason(currentProbe) + const replacement = selectReplacementModel({ + kind: "agent", + key: agentName, + routingEntry: entry, + currentModel, + availableModels: healthyModels, + }) + const canChange = !isCooldownActive(cooldownUntil, nowIso) + const isProtected = resolvedConfig.protect_manual_routing && !!entry.reason + const nextPrimary = effectiveHealthy || !canChange || !meetsFailureThreshold || isProtected + ? currentModel + : replacement.model + const selectedProbe = nextPrimary ? probeResults[nextPrimary] : undefined + const selectedHealthy = isModelHealthy(nextPrimary, healthyModels) + const changed = currentModel !== nextPrimary && nextPrimary !== null + const status = getRoutingTargetStatus({ + effectiveHealthy, + selectedHealthy, + failureReason, + }) + + agentHealth[agentName] = nextHealthRecord({ + key: agentName, + displayName: agentName, + kind: "agent", + currentModel, + selectedModel: nextPrimary, + isHealthy: effectiveHealthy, + selectedHealthy, + changed, + status, + reason: effectiveHealthy ? replacement.reason : failureReason, + checkedAt: nowIso, + previous, + cooldownUntil, + probeResult: currentProbe, + selectedProbeResult: selectedProbe, + }) + + if (resolvedConfig.mode === "active" && changed && nextRouting.agentModelMapping?.[agentName]) { + nextRouting.agentModelMapping[agentName] = { + ...nextRouting.agentModelMapping[agentName], + primary: nextPrimary, + fallback: buildNextFallbackList({ + currentPrimary: currentModel, + selectedPrimary: nextPrimary, + routingEntry: entry, + availableModels, + }), + } + + changes.push({ + kind: "agent", + key: agentName, + from: currentModel, + to: nextPrimary, + reason: failureReason, + }) + } + } + + for (const [categoryName, model] of Object.entries(routing.categoryRouting ?? 
{})) { + const previous = getPreviousTargetHealth(previousSnapshot, "category", categoryName) + const cooldownUntil = getCooldownUntil(previous, resolvedConfig.agent_cooldown_minutes) + const currentProbe = model ? probeResults[model] : undefined + const healthyCurrent = isModelHealthy(model, healthyModels) + const previousFailures = previous?.consecutiveFailures ?? 0 + const previousSuccesses = previous?.consecutiveSuccesses ?? 0 + const nextSuccesses = healthyCurrent ? previousSuccesses + 1 : 0 + const meetsRecoveryThreshold = nextSuccesses >= resolvedConfig.recovery_threshold + const effectiveHealthy = healthyCurrent && (previousFailures === 0 || meetsRecoveryThreshold) + const nextFailures = effectiveHealthy ? 0 : previousFailures + 1 + const meetsFailureThreshold = nextFailures >= resolvedConfig.failure_threshold + const failureReason = getModelFailureReason(currentProbe) + const replacement = selectReplacementModel({ + kind: "category", + key: categoryName, + currentModel: model, + availableModels: healthyModels, + }) + const canChange = !isCooldownActive(cooldownUntil, nowIso) + const nextModel = effectiveHealthy || !canChange || !meetsFailureThreshold + ? model + : replacement.model + const selectedProbe = nextModel ? probeResults[nextModel] : undefined + const selectedHealthy = isModelHealthy(nextModel, healthyModels) + const changed = model !== nextModel && nextModel !== null + const status = getRoutingTargetStatus({ + effectiveHealthy, + selectedHealthy, + failureReason, + }) + + categoryHealth[categoryName] = nextHealthRecord({ + key: categoryName, + displayName: categoryName, + kind: "category", + currentModel: model, + selectedModel: nextModel, + isHealthy: effectiveHealthy, + selectedHealthy, + changed, + status, + reason: effectiveHealthy ? 
replacement.reason : failureReason, + checkedAt: nowIso, + previous, + cooldownUntil, + probeResult: currentProbe, + selectedProbeResult: selectedProbe, + }) + + if (resolvedConfig.mode === "active" && changed && nextRouting.categoryRouting) { + nextRouting.categoryRouting[categoryName] = nextModel + changes.push({ + kind: "category", + key: categoryName, + from: model, + to: nextModel, + reason: failureReason, + }) + } + } + + nextRouting.lastUpdated = nowIso.slice(0, 10) + nextRouting.scheduler = { + lastRunAt: nowIso, + lastMode: resolvedConfig.mode, + lastChangeCount: changes.length, + } + + if (resolvedConfig.mode === "active" && changes.length > 0) { + writeModelRoutingFile(nextRouting) + } + + const snapshot: ModelHealthSnapshot = { + version: 2, + mode: resolvedConfig.mode, + updatedAt: nowIso, + connectedProviders, + availableModelCount: availableModels.size, + inventory: diffInventory(previousProviderModels, currentProviderModels), + agents: agentHealth, + categories: categoryHealth, + probe: { + enabled: resolvedConfig.probe_enabled, + timeoutMs: resolvedConfig.probe_timeout_ms, + maxLatencyMs: resolvedConfig.probe_max_latency_ms, + checkedModelCount: Object.keys(probeResults).length, + healthyModelCount: Object.values(probeResults).filter((result) => + result.status === "healthy" || result.status === "skipped").length, + models: probeResults, + }, + } + const auditEntry: ModelSchedulerAuditEntry = { + timestamp: nowIso, + mode: resolvedConfig.mode, + changed: changes.length > 0, + availableModelCount: availableModels.size, + connectedProviders, + probeSummary: { + enabled: resolvedConfig.probe_enabled, + checkedModelCount: Object.keys(probeResults).length, + healthyModelCount: Object.values(probeResults).filter((result) => + result.status === "healthy" || result.status === "skipped").length, + }, + changes, + } + + writeModelHealthSnapshot(snapshot) + appendModelSchedulerAuditEntry(auditEntry) + + if (changes.length > 0) { + log("[model-scheduler] 
applied routing changes", { changes }) + await ctx.client.tui?.showToast?.({ + body: { + title: "Model Scheduler", + message: `Adjusted ${changes.length} routing entr${changes.length === 1 ? "y" : "ies"} after health check.`, + variant: "info", + duration: 6000, + }, + }) + } else { + log("[model-scheduler] completed without routing changes", { + availableModelCount: availableModels.size, + connectedProviders, + }) + } + + return { snapshot, auditEntry } +} + +export function getModelSchedulerIntervalMs(config?: ModelSchedulerConfig): number { + return resolveConfig(config).interval_minutes * 60_000 +} diff --git a/src/features/model-scheduler/selector.ts b/src/features/model-scheduler/selector.ts new file mode 100644 index 0000000000..5728e5ff0c --- /dev/null +++ b/src/features/model-scheduler/selector.ts @@ -0,0 +1,88 @@ +import { AGENT_MODEL_REQUIREMENTS, CATEGORY_MODEL_REQUIREMENTS, fuzzyMatchModel } from "../../shared" +import type { RoutingEntry, RoutingTargetKind, SchedulerChangeReason } from "./types" + +function normalizeKey(value: string): string { + return value.trim().toLowerCase().replace(/[^a-z0-9]+/g, "-").replace(/^-+|-+$/g, "") +} + +function findExactModelMatch(candidate: string, availableModels: Set): string | null { + const lowered = candidate.trim().toLowerCase() + for (const model of availableModels) { + if (model.toLowerCase() === lowered) { + return model + } + } + return null +} + +export function isModelHealthy(model: string | null | undefined, availableModels: Set): boolean { + if (!model) return false + return findExactModelMatch(model, availableModels) !== null +} + +export function resolveHealthyModel(model: string | null | undefined, availableModels: Set): string | null { + if (!model) return null + return findExactModelMatch(model, availableModels) +} + +export function selectReplacementModel(args: { + kind: RoutingTargetKind + key: string + routingEntry?: RoutingEntry | null + currentModel: string | null + availableModels: Set +}): 
{ model: string | null; reason: SchedulerChangeReason } { + const { kind, key, routingEntry, currentModel, availableModels } = args + + const healthyCurrent = resolveHealthyModel(currentModel, availableModels) + if (healthyCurrent) { + return { model: healthyCurrent, reason: "unavailable" } + } + + for (const fallback of routingEntry?.fallback ?? []) { + const exactFallback = resolveHealthyModel(fallback, availableModels) + if (exactFallback) { + return { model: exactFallback, reason: "existing-fallback" } + } + } + + const requirements = kind === "agent" + ? AGENT_MODEL_REQUIREMENTS[normalizeKey(key)] + : CATEGORY_MODEL_REQUIREMENTS[normalizeKey(key)] + + if (!requirements) { + return { model: null, reason: "unavailable" } + } + + for (const fallbackEntry of requirements.fallbackChain) { + const match = fuzzyMatchModel(fallbackEntry.model, availableModels, fallbackEntry.providers) + if (match) { + return { model: match, reason: "requirements-fallback" } + } + } + + return { model: null, reason: "unavailable" } +} + +export function buildNextFallbackList(args: { + currentPrimary: string | null + selectedPrimary: string | null + routingEntry?: RoutingEntry | null + availableModels: Set +}): string[] { + const candidates = [ + ...(args.routingEntry?.fallback ?? []), + ...(args.currentPrimary && args.currentPrimary !== args.selectedPrimary ? 
[args.currentPrimary] : []), + ] + const nextFallbacks: string[] = [] + + for (const candidate of candidates) { + const resolved = resolveHealthyModel(candidate, args.availableModels) + if (!resolved) continue + if (resolved === args.selectedPrimary) continue + if (nextFallbacks.includes(resolved)) continue + nextFallbacks.push(resolved) + } + + return nextFallbacks +} diff --git a/src/features/model-scheduler/types.ts b/src/features/model-scheduler/types.ts new file mode 100644 index 0000000000..26ae345dd4 --- /dev/null +++ b/src/features/model-scheduler/types.ts @@ -0,0 +1,110 @@ +import type { ModelSchedulerConfig, ModelSchedulerMode } from "../../config" + +export type RoutingTargetKind = "agent" | "category" +export type HealthStatus = "healthy" | "offline" | "degraded" | "unknown" +export type SchedulerChangeReason = + | "existing-fallback" + | "requirements-fallback" + | "unavailable" + | "probe-failed" + | "latency-too-high" + +export type ModelProbeStatus = "healthy" | "unavailable" | "timeout" | "error" | "slow" | "skipped" + +export type ModelProbeResult = { + model: string + available: boolean + status: ModelProbeStatus + checkedAt: string + latencyMs?: number + error?: string +} + +export type RoutingTargetHealth = { + key: string + displayName: string + kind: RoutingTargetKind + currentModel: string | null + status: HealthStatus + selectedModel: string | null + changed: boolean + reason: SchedulerChangeReason + checkedAt: string + cooldownUntil?: string + consecutiveFailures: number + consecutiveSuccesses: number + availabilityStatus?: "available" | "unavailable" + probeStatus?: ModelProbeStatus + probeLatencyMs?: number + probeError?: string +} + +export type ModelHealthSnapshot = { + version: 2 + mode: ModelSchedulerMode + updatedAt: string + connectedProviders: string[] + availableModelCount: number + inventory: { + added: Record + removed: Record + } + agents: Record + categories: Record + probe: { + enabled: boolean + timeoutMs: number + 
maxLatencyMs: number + checkedModelCount: number + healthyModelCount: number + models: Record + } +} + +export type ModelSchedulerAuditEntry = { + timestamp: string + mode: ModelSchedulerMode + changed: boolean + availableModelCount: number + connectedProviders: string[] + probeSummary?: { + enabled: boolean + checkedModelCount: number + healthyModelCount: number + } + changes: Array<{ + kind: RoutingTargetKind + key: string + from: string | null + to: string | null + reason: SchedulerChangeReason + }> +} + +export type RoutingEntry = { + primary?: string + fallback?: string[] + reason?: string +} + +export type ModelRoutingFile = { + version?: string + lastUpdated?: string + note?: string + providerSummary?: Record + agentModelMapping?: Record + categoryRouting?: Record + summary?: Record + scheduler?: { + lastRunAt?: string + lastMode?: ModelSchedulerMode + lastChangeCount?: number + } +} + +export type SchedulerRunResult = { + snapshot: ModelHealthSnapshot + auditEntry: ModelSchedulerAuditEntry +} + +export type ResolvedModelSchedulerConfig = Required diff --git a/src/hooks/auto-update-checker/hook.test.ts b/src/hooks/auto-update-checker/hook.test.ts index 33d91b48bb..9b8873fd72 100644 --- a/src/hooks/auto-update-checker/hook.test.ts +++ b/src/hooks/auto-update-checker/hook.test.ts @@ -6,9 +6,14 @@ const mockUpdateAndShowConnectedProvidersCacheStatus = mock(async () => {}) const mockShowLocalDevToast = mock(async () => {}) const mockShowVersionToast = mock(async () => {}) const mockRunBackgroundUpdateCheck = mock(async () => {}) +const mockRunModelSchedulerCycle = mock(async () => null) +const mockGetModelSchedulerIntervalMs = mock(() => 3600000) const mockGetCachedVersion = mock(() => "3.6.0") const mockGetLocalDevVersion = mock<(directory: string) => string | null>(() => null) +const intervalCallbacks: Array<() => void> = [] +const originalSetInterval = globalThis.setInterval + mock.module("./hook/config-errors-toast", () => ({ showConfigErrorsIfAny: 
mockShowConfigErrorsIfAny, })) @@ -27,6 +32,11 @@ mock.module("./hook/startup-toasts", () => ({ showVersionToast: mockShowVersionToast, })) +mock.module("../../features/model-scheduler", () => ({ + runModelSchedulerCycle: mockRunModelSchedulerCycle, + getModelSchedulerIntervalMs: mockGetModelSchedulerIntervalMs, +})) + mock.module("./hook/background-update-check", () => ({ runBackgroundUpdateCheck: mockRunBackgroundUpdateCheck, })) @@ -61,15 +71,27 @@ beforeEach(() => { mockShowLocalDevToast.mockClear() mockShowVersionToast.mockClear() mockRunBackgroundUpdateCheck.mockClear() + mockRunModelSchedulerCycle.mockClear() + mockGetModelSchedulerIntervalMs.mockClear() mockGetCachedVersion.mockClear() mockGetLocalDevVersion.mockClear() mockGetCachedVersion.mockReturnValue("3.6.0") mockGetLocalDevVersion.mockReturnValue(null) + mockGetModelSchedulerIntervalMs.mockReturnValue(3600000) + + globalThis.setInterval = (((handler: TimerHandler) => { + if (typeof handler === "function") { + intervalCallbacks.push(handler) + } + return { unref: () => {} } as unknown as number + }) as typeof setInterval) }) afterEach(() => { delete process.env.OPENCODE_CLI_RUN_MODE + intervalCallbacks.length = 0 + globalThis.setInterval = originalSetInterval }) describe("createAutoUpdateCheckerHook", () => { @@ -100,6 +122,7 @@ describe("createAutoUpdateCheckerHook", () => { expect(mockShowLocalDevToast).not.toHaveBeenCalled() expect(mockShowVersionToast).not.toHaveBeenCalled() expect(mockRunBackgroundUpdateCheck).not.toHaveBeenCalled() + expect(mockRunModelSchedulerCycle).not.toHaveBeenCalled() }) it("runs all startup checks on normal session.created", async () => { @@ -119,8 +142,10 @@ describe("createAutoUpdateCheckerHook", () => { expect(mockShowConfigErrorsIfAny).toHaveBeenCalledTimes(1) expect(mockUpdateAndShowConnectedProvidersCacheStatus).toHaveBeenCalledTimes(1) expect(mockShowModelCacheWarningIfNeeded).toHaveBeenCalledTimes(1) + expect(mockRunModelSchedulerCycle).toHaveBeenCalledTimes(1) 
expect(mockShowVersionToast).toHaveBeenCalledTimes(1) expect(mockRunBackgroundUpdateCheck).toHaveBeenCalledTimes(1) + expect(intervalCallbacks).toHaveLength(1) }) it("ignores subagent sessions (parentID present)", async () => { @@ -144,6 +169,7 @@ describe("createAutoUpdateCheckerHook", () => { expect(mockShowLocalDevToast).not.toHaveBeenCalled() expect(mockShowVersionToast).not.toHaveBeenCalled() expect(mockRunBackgroundUpdateCheck).not.toHaveBeenCalled() + expect(mockRunModelSchedulerCycle).not.toHaveBeenCalled() }) it("runs only once (hasChecked guard)", async () => { @@ -168,6 +194,7 @@ describe("createAutoUpdateCheckerHook", () => { expect(mockShowConfigErrorsIfAny).toHaveBeenCalledTimes(1) expect(mockUpdateAndShowConnectedProvidersCacheStatus).toHaveBeenCalledTimes(1) expect(mockShowModelCacheWarningIfNeeded).toHaveBeenCalledTimes(1) + expect(mockRunModelSchedulerCycle).toHaveBeenCalledTimes(1) expect(mockShowVersionToast).toHaveBeenCalledTimes(1) expect(mockRunBackgroundUpdateCheck).toHaveBeenCalledTimes(1) }) @@ -190,6 +217,7 @@ describe("createAutoUpdateCheckerHook", () => { expect(mockShowConfigErrorsIfAny).toHaveBeenCalledTimes(1) expect(mockUpdateAndShowConnectedProvidersCacheStatus).toHaveBeenCalledTimes(1) expect(mockShowModelCacheWarningIfNeeded).toHaveBeenCalledTimes(1) + expect(mockRunModelSchedulerCycle).toHaveBeenCalledTimes(1) expect(mockShowLocalDevToast).toHaveBeenCalledTimes(1) expect(mockShowVersionToast).not.toHaveBeenCalled() expect(mockRunBackgroundUpdateCheck).not.toHaveBeenCalled() @@ -215,6 +243,7 @@ describe("createAutoUpdateCheckerHook", () => { expect(mockShowLocalDevToast).not.toHaveBeenCalled() expect(mockShowVersionToast).not.toHaveBeenCalled() expect(mockRunBackgroundUpdateCheck).not.toHaveBeenCalled() + expect(mockRunModelSchedulerCycle).not.toHaveBeenCalled() }) it("passes correct toast message with sisyphus enabled", async () => { diff --git a/src/hooks/auto-update-checker/hook.ts b/src/hooks/auto-update-checker/hook.ts index 
b915f9e554..8a96017e19 100644 --- a/src/hooks/auto-update-checker/hook.ts +++ b/src/hooks/auto-update-checker/hook.ts @@ -1,5 +1,6 @@ import type { PluginInput } from "@opencode-ai/plugin" import { log } from "../../shared/logger" +import { getModelSchedulerIntervalMs, runModelSchedulerCycle } from "../../features/model-scheduler" import { getCachedVersion, getLocalDevVersion } from "./checker" import type { AutoUpdateCheckerOptions } from "./types" import { runBackgroundUpdateCheck } from "./hook/background-update-check" @@ -9,7 +10,7 @@ import { showModelCacheWarningIfNeeded } from "./hook/model-cache-warning" import { showLocalDevToast, showVersionToast } from "./hook/startup-toasts" export function createAutoUpdateCheckerHook(ctx: PluginInput, options: AutoUpdateCheckerOptions = {}) { - const { showStartupToast = true, isSisyphusEnabled = false, autoUpdate = true } = options + const { showStartupToast = true, isSisyphusEnabled = false, autoUpdate = true, modelScheduler } = options const isCliRunMode = process.env.OPENCODE_CLI_RUN_MODE === "true" const getToastMessage = (isUpdate: boolean, latestVersion?: string): string => { @@ -24,6 +25,21 @@ export function createAutoUpdateCheckerHook(ctx: PluginInput, options: AutoUpdat } let hasChecked = false + let schedulerStarted = false + + const startSchedulerLoop = () => { + if (schedulerStarted) return + schedulerStarted = true + + const intervalMs = getModelSchedulerIntervalMs(modelScheduler) + const interval = setInterval(() => { + runModelSchedulerCycle(ctx, modelScheduler).catch((error) => { + log("[auto-update-checker] model scheduler interval failed:", error) + }) + }, intervalMs) + + interval.unref?.() + } return { event: ({ event }: { event: { type: string; properties?: unknown } }) => { @@ -45,6 +61,14 @@ export function createAutoUpdateCheckerHook(ctx: PluginInput, options: AutoUpdat await updateAndShowConnectedProvidersCacheStatus(ctx) await showModelCacheWarningIfNeeded(ctx) + if 
(modelScheduler?.preflight_on_session_created ?? true) { + await runModelSchedulerCycle(ctx, modelScheduler).catch((error) => { + log("[auto-update-checker] model scheduler preflight failed:", error) + }) + } + + startSchedulerLoop() + if (localDevVersion) { if (showStartupToast) { showLocalDevToast(ctx, displayVersion, isSisyphusEnabled).catch(() => {}) diff --git a/src/hooks/auto-update-checker/types.ts b/src/hooks/auto-update-checker/types.ts index 550e5137fd..097028bdc2 100644 --- a/src/hooks/auto-update-checker/types.ts +++ b/src/hooks/auto-update-checker/types.ts @@ -1,3 +1,5 @@ +import type { ModelSchedulerConfig } from "../../config" + export interface NpmDistTags { latest: string [key: string]: string @@ -26,4 +28,5 @@ export interface AutoUpdateCheckerOptions { showStartupToast?: boolean isSisyphusEnabled?: boolean autoUpdate?: boolean + modelScheduler?: ModelSchedulerConfig } diff --git a/src/plugin/hooks/create-session-hooks.ts b/src/plugin/hooks/create-session-hooks.ts index daa5e4ff54..3fd18e7760 100644 --- a/src/plugin/hooks/create-session-hooks.ts +++ b/src/plugin/hooks/create-session-hooks.ts @@ -184,6 +184,7 @@ export function createSessionHooks(args: { showStartupToast: isHookEnabled("startup-toast"), isSisyphusEnabled: pluginConfig.sisyphus_agent?.disabled !== true, autoUpdate: pluginConfig.auto_update ?? 
true, + modelScheduler: pluginConfig.model_scheduler, })) : null diff --git a/src/shared/connected-providers-cache.ts b/src/shared/connected-providers-cache.ts index e9a1ec368f..130408c13c 100644 --- a/src/shared/connected-providers-cache.ts +++ b/src/shared/connected-providers-cache.ts @@ -2,6 +2,7 @@ import { existsSync, readFileSync, writeFileSync, mkdirSync } from "fs" import { join } from "path" import { log } from "./logger" import { getOmoOpenCodeCacheDir } from "./data-path" +import { readUserConfiguredModels } from "./opencode-config-reader" const CONNECTED_PROVIDERS_CACHE_FILE = "connected-providers.json" const PROVIDER_MODELS_CACHE_FILE = "provider-models.json" @@ -169,14 +170,31 @@ export async function updateConnectedProvidersCache(client: { writeConnectedProvidersCache(connected) + const userConfiguredModels = readUserConfiguredModels() const modelsByProvider: Record = {} const allProviders = result.data?.all ?? [] for (const provider of allProviders) { if (provider.models) { - const modelIds = Object.keys(provider.models) - if (modelIds.length > 0) { - modelsByProvider[provider.id] = modelIds + const allModelIds = Object.keys(provider.models) + + if (userConfiguredModels?.has(provider.id)) { + const whitelist = userConfiguredModels.get(provider.id)! 
+ const filteredModelIds = allModelIds.filter(modelId => whitelist.has(modelId)) + + if (filteredModelIds.length > 0) { + modelsByProvider[provider.id] = filteredModelIds + log("[connected-providers-cache] Filtered models by user config", { + provider: provider.id, + total: allModelIds.length, + filtered: filteredModelIds.length, + kept: filteredModelIds, + }) + } + } else { + if (allModelIds.length > 0) { + modelsByProvider[provider.id] = allModelIds + } } } } diff --git a/src/shared/index.ts b/src/shared/index.ts index 39b36482a7..f7e66fcd59 100644 --- a/src/shared/index.ts +++ b/src/shared/index.ts @@ -45,6 +45,7 @@ export type { export * from "./model-availability" export * from "./fallback-model-availability" export * from "./connected-providers-cache" +export * from "./opencode-config-reader" export * from "./session-utils" export * from "./tmux" export * from "./model-suggestion-retry" diff --git a/src/shared/opencode-config-reader.test.ts b/src/shared/opencode-config-reader.test.ts new file mode 100644 index 0000000000..f1e909e633 --- /dev/null +++ b/src/shared/opencode-config-reader.test.ts @@ -0,0 +1,131 @@ +import { describe, test, expect, beforeEach, afterEach } from "bun:test" +import { mkdirSync, writeFileSync, rmSync, existsSync } from "fs" +import { join } from "path" +import { readUserConfiguredModels } from "./opencode-config-reader" + +const TEST_CONFIG_DIR = join(process.cwd(), "test-opencode-config") + +describe("readUserConfiguredModels", () => { + beforeEach(() => { + if (existsSync(TEST_CONFIG_DIR)) { + rmSync(TEST_CONFIG_DIR, { recursive: true, force: true }) + } + mkdirSync(TEST_CONFIG_DIR, { recursive: true }) + process.env.OPENCODE_CONFIG_DIR = TEST_CONFIG_DIR + }) + + afterEach(() => { + delete process.env.OPENCODE_CONFIG_DIR + if (existsSync(TEST_CONFIG_DIR)) { + rmSync(TEST_CONFIG_DIR, { recursive: true, force: true }) + } + }) + + describe("#given opencode.json with provider models", () => { + describe("#when readUserConfiguredModels 
called", () => { + test("#then returns whitelist map", () => { + const config = { + provider: { + minimax: { + models: { + "MiniMax-M2.5-highspeed": { name: "Highspeed" }, + "MiniMax-M2.5": { name: "Standard" }, + }, + }, + openai: { + models: { + "gpt-5.2": { name: "GPT-5.2" }, + }, + }, + }, + } + + writeFileSync(join(TEST_CONFIG_DIR, "opencode.json"), JSON.stringify(config, null, 2)) + + const result = readUserConfiguredModels() + + expect(result).not.toBeNull() + expect(result!.size).toBe(2) + expect(result!.get("minimax")).toEqual(new Set(["MiniMax-M2.5-highspeed", "MiniMax-M2.5"])) + expect(result!.get("openai")).toEqual(new Set(["gpt-5.2"])) + }) + }) + }) + + describe("#given opencode.jsonc with comments", () => { + describe("#when readUserConfiguredModels called", () => { + test("#then parses JSONC correctly", () => { + const configContent = `{ + // Provider configuration + "provider": { + "minimax": { + "models": { + "MiniMax-M2.5-highspeed": { "name": "Highspeed" } + } + } + } + }` + + writeFileSync(join(TEST_CONFIG_DIR, "opencode.jsonc"), configContent) + + const result = readUserConfiguredModels() + + expect(result).not.toBeNull() + expect(result!.size).toBe(1) + expect(result!.get("minimax")).toEqual(new Set(["MiniMax-M2.5-highspeed"])) + }) + }) + }) + + describe("#given provider with no models", () => { + describe("#when readUserConfiguredModels called", () => { + test("#then excludes provider from whitelist", () => { + const config = { + provider: { + minimax: { + models: { + "MiniMax-M2.5-highspeed": { name: "Highspeed" }, + }, + }, + openai: { + apiKey: "sk-xxx", + }, + }, + } + + writeFileSync(join(TEST_CONFIG_DIR, "opencode.json"), JSON.stringify(config, null, 2)) + + const result = readUserConfiguredModels() + + expect(result).not.toBeNull() + expect(result!.size).toBe(1) + expect(result!.has("minimax")).toBe(true) + expect(result!.has("openai")).toBe(false) + }) + }) + }) + + describe("#given no config file", () => { + describe("#when 
readUserConfiguredModels called", () => { + test("#then returns null", () => { + const result = readUserConfiguredModels() + expect(result).toBeNull() + }) + }) + }) + + describe("#given config with no provider section", () => { + describe("#when readUserConfiguredModels called", () => { + test("#then returns null", () => { + const config = { + plugin: ["oh-my-opencode"], + } + + writeFileSync(join(TEST_CONFIG_DIR, "opencode.json"), JSON.stringify(config, null, 2)) + + const result = readUserConfiguredModels() + expect(result).toBeNull() + }) + }) + }) +}) diff --git a/src/shared/opencode-config-reader.ts b/src/shared/opencode-config-reader.ts new file mode 100644 index 0000000000..eca55dd715 --- /dev/null +++ b/src/shared/opencode-config-reader.ts @@ -0,0 +1,85 @@ +import { existsSync, readFileSync } from "fs" +import { join } from "path" +import { log } from "./logger" +import { parseJsonc } from "./jsonc-parser" +import { getOpenCodeConfigDir } from "./opencode-config-dir" + +interface OpenCodeProviderConfig { + models?: Record<string, unknown> +} + +interface OpenCodeConfig { + provider?: Record<string, OpenCodeProviderConfig> +} + +/** + * Read opencode.json/opencode.jsonc and extract user-configured models per provider. + * Returns a map of provider ID → Set of configured model IDs.
+ * + * @example + * // opencode.json contains: + * // { "provider": { "minimax": { "models": { "MiniMax-M2.5-highspeed": {...} } } } } + * + * const whitelist = readUserConfiguredModels() + * // → Map { "minimax" => Set { "MiniMax-M2.5-highspeed" } } + */ +export function readUserConfiguredModels(): Map<string, Set<string>> | null { + const configDir = getOpenCodeConfigDir({ binary: "opencode", version: null }) + const configPaths = [ + join(configDir, "opencode.json"), + join(configDir, "opencode.jsonc"), + ] + + let configContent: string | null = null + let configPath: string | null = null + + for (const path of configPaths) { + if (existsSync(path)) { + try { + configContent = readFileSync(path, "utf-8") + configPath = path + break + } catch (err) { + log("[opencode-config-reader] Error reading config file", { path, error: String(err) }) + } + } + } + + if (!configContent || !configPath) { + log("[opencode-config-reader] No opencode config file found") + return null + } + + try { + const config = parseJsonc<OpenCodeConfig>(configContent) + + if (!config?.provider) { + log("[opencode-config-reader] No provider config found") + return null + } + + const whitelist = new Map<string, Set<string>>() + + for (const [providerId, providerConfig] of Object.entries(config.provider)) { + if (!providerConfig?.models) { + continue + } + + const modelIds = Object.keys(providerConfig.models) + if (modelIds.length > 0) { + whitelist.set(providerId, new Set(modelIds)) + } + } + + log("[opencode-config-reader] Extracted user-configured models", { + providerCount: whitelist.size, + providers: Array.from(whitelist.keys()), + totalModels: Array.from(whitelist.values()).reduce((sum, set) => sum + set.size, 0), + }) + + return whitelist + } catch (err) { + log("[opencode-config-reader] Error parsing config file", { path: configPath, error: String(err) }) + return null + } +}