diff --git a/src/plugin-handlers/provider-config-handler.test.ts b/src/plugin-handlers/provider-config-handler.test.ts new file mode 100644 index 0000000000..421dd5d0ec --- /dev/null +++ b/src/plugin-handlers/provider-config-handler.test.ts @@ -0,0 +1,84 @@ +/// + +import { describe, expect, test } from "bun:test" +import { applyProviderConfig } from "./provider-config-handler" +import { createModelCacheState } from "../plugin-state" +import { clearVisionCapableModelsCache, readVisionCapableModelsCache } from "../shared/vision-capable-models-cache" + +describe("applyProviderConfig", () => { + test("caches vision-capable models from modalities and capabilities", () => { + // given + const modelCacheState = createModelCacheState() + const visionCapableModelsCache = modelCacheState.visionCapableModelsCache + if (!visionCapableModelsCache) { + throw new Error("visionCapableModelsCache should be initialized") + } + const config = { + provider: { + rundao: { + models: { + "public/qwen3.5-397b": { + modalities: { + input: ["text", "image"], + }, + }, + "public/text-only": { + modalities: { + input: ["text"], + }, + }, + }, + }, + google: { + models: { + "gemini-3-flash": { + capabilities: { + input: { + image: true, + }, + }, + }, + }, + }, + }, + } satisfies Record + + // when + applyProviderConfig({ config, modelCacheState }) + + // then + expect(Array.from(visionCapableModelsCache.keys())).toEqual([ + "rundao/public/qwen3.5-397b", + "google/gemini-3-flash", + ]) + expect(readVisionCapableModelsCache()).toEqual([ + { providerID: "rundao", modelID: "public/qwen3.5-397b" }, + { providerID: "google", modelID: "gemini-3-flash" }, + ]) + }) + + test("clears stale vision-capable models when provider config changes", () => { + // given + const modelCacheState = createModelCacheState() + const visionCapableModelsCache = modelCacheState.visionCapableModelsCache + if (!visionCapableModelsCache) { + throw new Error("visionCapableModelsCache should be initialized") + } + visionCapableModelsCache.set("stale/old-model", { + providerID: "stale", + modelID: "old-model", + }) + + // when + applyProviderConfig({ + config: { provider: {} }, + modelCacheState, + }) + + // then + expect(visionCapableModelsCache.size).toBe(0) + expect(readVisionCapableModelsCache()).toEqual([]) + }) +}) + +clearVisionCapableModelsCache() diff --git a/src/plugin-handlers/provider-config-handler.ts b/src/plugin-handlers/provider-config-handler.ts index 75964d20be..cc01508bd1 100644 --- a/src/plugin-handlers/provider-config-handler.ts +++ b/src/plugin-handlers/provider-config-handler.ts @@ -1,10 +1,31 @@ -import type { ModelCacheState } from "../plugin-state"; +import type { ModelCacheState, VisionCapableModel } from "../plugin-state"; +import { setVisionCapableModelsCache } from "../shared/vision-capable-models-cache" type ProviderConfig = { options?: { headers?: Record }; - models?: Record; + models?: Record; }; +type ProviderModelConfig = { + limit?: { context?: number }; + modalities?: { + input?: string[]; + }; + capabilities?: { + input?: { + image?: boolean; + }; + }; +} + +function supportsImageInput(modelConfig: ProviderModelConfig | undefined): boolean { + if (modelConfig?.modalities?.input?.includes("image")) { + return true + } + + return modelConfig?.capabilities?.input?.image === true +} + export function applyProviderConfig(params: { config: Record; modelCacheState: ModelCacheState; @@ -17,6 +38,12 @@ export function applyProviderConfig(params: { params.modelCacheState.anthropicContext1MEnabled = anthropicBeta?.includes("context-1m") ?? false; + const visionCapableModelsCache = params.modelCacheState.visionCapableModelsCache + ?? new Map() + params.modelCacheState.visionCapableModelsCache = visionCapableModelsCache + visionCapableModelsCache.clear() + setVisionCapableModelsCache(visionCapableModelsCache) + if (!providers) return; for (const [providerID, providerConfig] of Object.entries(providers)) { @@ -24,6 +51,13 @@ export function applyProviderConfig(params: { if (!models) continue; for (const [modelID, modelConfig] of Object.entries(models)) { + if (supportsImageInput(modelConfig)) { + visionCapableModelsCache.set( + `${providerID}/${modelID}`, + { providerID, modelID }, + ) + } + const contextLimit = modelConfig?.limit?.context; if (!contextLimit) continue; diff --git a/src/plugin-state.ts b/src/plugin-state.ts index 5f20c02333..e4f3a0afaa 100644 --- a/src/plugin-state.ts +++ b/src/plugin-state.ts @@ -1,11 +1,18 @@ +export type VisionCapableModel = { + providerID: string + modelID: string +} + export interface ModelCacheState { modelContextLimitsCache: Map; + visionCapableModelsCache?: Map; anthropicContext1MEnabled: boolean; } export function createModelCacheState(): ModelCacheState { return { modelContextLimitsCache: new Map(), + visionCapableModelsCache: new Map(), anthropicContext1MEnabled: false, }; } diff --git a/src/shared/vision-capable-models-cache.ts b/src/shared/vision-capable-models-cache.ts new file mode 100644 index 0000000000..3aaf629b51 --- /dev/null +++ b/src/shared/vision-capable-models-cache.ts @@ -0,0 +1,17 @@ +import type { VisionCapableModel } from "../plugin-state" + +let visionCapableModelsCache = new Map() + +export function setVisionCapableModelsCache( + cache: Map, +): void { + visionCapableModelsCache = cache +} + +export function readVisionCapableModelsCache(): VisionCapableModel[] { + return Array.from(visionCapableModelsCache.values()) +} + +export function clearVisionCapableModelsCache(): void { + visionCapableModelsCache = new Map() +} diff --git a/src/tools/look-at/multimodal-agent-metadata.test.ts b/src/tools/look-at/multimodal-agent-metadata.test.ts new file mode 100644 index 0000000000..d47a377e58 --- /dev/null +++ b/src/tools/look-at/multimodal-agent-metadata.test.ts @@ -0,0 +1,115 @@ +/// + +import { afterEach, beforeEach, describe, expect, mock, spyOn, test } from "bun:test" +import type { PluginInput } from "@opencode-ai/plugin" +import { resolveMultimodalLookerAgentMetadata } from "./multimodal-agent-metadata" +import { setVisionCapableModelsCache, clearVisionCapableModelsCache } from "../../shared/vision-capable-models-cache" +import * as connectedProvidersCache from "../../shared/connected-providers-cache" +import * as modelAvailability from "../../shared/model-availability" + +function createPluginInput(agentData: Array>): PluginInput { + const client = {} as PluginInput["client"] + Object.assign(client, { + app: { + agents: mock(async () => ({ data: agentData })), + }, + }) + + return { + client, + project: {} as PluginInput["project"], + directory: "/project", + worktree: "/project", + serverUrl: new URL("http://localhost"), + $: {} as PluginInput["$"], + } +} + +describe("resolveMultimodalLookerAgentMetadata", () => { + beforeEach(() => { + clearVisionCapableModelsCache() + }) + + afterEach(() => { + clearVisionCapableModelsCache() + ;(modelAvailability.fetchAvailableModels as unknown as { mockRestore?: () => void }).mockRestore?.() + ;(connectedProvidersCache.readConnectedProvidersCache as unknown as { mockRestore?: () => void }).mockRestore?.() + }) + + test("returns configured multimodal-looker model when it already matches a vision-capable override", async () => { + // given + setVisionCapableModelsCache(new Map([ + [ + "rundao/public/qwen3.5-397b", + { providerID: "rundao", modelID: "public/qwen3.5-397b" }, + ], + ])) + spyOn(modelAvailability, "fetchAvailableModels").mockResolvedValue( + new Set(["rundao/public/qwen3.5-397b"]), + ) + spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["rundao"]) + const ctx = createPluginInput([ + { + name: "multimodal-looker", + model: { providerID: "rundao", modelID: "public/qwen3.5-397b" }, + }, + ]) + + // when + const result = await resolveMultimodalLookerAgentMetadata(ctx) + + // then + expect(result).toEqual({ + agentModel: { providerID: "rundao", modelID: "public/qwen3.5-397b" }, + agentVariant: undefined, + }) + }) + + test("prefers connected vision-capable provider models before the hardcoded fallback chain", async () => { + // given + setVisionCapableModelsCache(new Map([ + [ + "rundao/public/qwen3.5-397b", + { providerID: "rundao", modelID: "public/qwen3.5-397b" }, + ], + ])) + spyOn(modelAvailability, "fetchAvailableModels").mockResolvedValue( + new Set(["openai/gpt-5.4", "rundao/public/qwen3.5-397b"]), + ) + spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai", "rundao"]) + const ctx = createPluginInput([ + { + name: "multimodal-looker", + model: { providerID: "openai", modelID: "gpt-5.4" }, + variant: "medium", + }, + ]) + + // when + const result = await resolveMultimodalLookerAgentMetadata(ctx) + + // then + expect(result).toEqual({ + agentModel: { providerID: "rundao", modelID: "public/qwen3.5-397b" }, + agentVariant: undefined, + }) + }) + + test("falls back to the hardcoded multimodal chain when no dynamic vision model exists", async () => { + // given + spyOn(modelAvailability, "fetchAvailableModels").mockResolvedValue( + new Set(["google/gemini-3-flash"]), + ) + spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["google"]) + const ctx = createPluginInput([]) + + // when + const result = await resolveMultimodalLookerAgentMetadata(ctx) + + // then + expect(result).toEqual({ + agentModel: { providerID: "google", modelID: "gemini-3-flash" }, + agentVariant: undefined, + }) + }) +}) diff --git a/src/tools/look-at/multimodal-agent-metadata.ts b/src/tools/look-at/multimodal-agent-metadata.ts index e24c8b6fba..a96f9471ec 100644 --- a/src/tools/look-at/multimodal-agent-metadata.ts +++ b/src/tools/look-at/multimodal-agent-metadata.ts @@ -1,6 +1,11 @@ import type { PluginInput } from "@opencode-ai/plugin" import { MULTIMODAL_LOOKER_AGENT } from "./constants" -import { log } from "../../shared" +import { fetchAvailableModels } from "../../shared/model-availability" +import { log } from "../../shared/logger" +import { readConnectedProvidersCache } from "../../shared/connected-providers-cache" +import { resolveModelPipeline } from "../../shared/model-resolution-pipeline" +import { readVisionCapableModelsCache } from "../../shared/vision-capable-models-cache" +import { buildMultimodalLookerFallbackChain } from "./multimodal-fallback-chain" type AgentModel = { providerID: string; modelID: string } @@ -19,6 +24,20 @@ function isObject(value: unknown): value is Record { return typeof value === "object" && value !== null } +function getFullModelKey(model: AgentModel): string { + return `${model.providerID}/${model.modelID}` +} + +function parseAgentModel(model: string): AgentModel | undefined { + const [providerID, ...modelIDParts] = model.split("/") + const modelID = modelIDParts.join("/") + if (!providerID || modelID.length === 0) { + return undefined + } + + return { providerID, modelID } +} + function toAgentInfo(value: unknown): AgentInfo | null { if (!isObject(value)) return null const name = typeof value["name"] === "string" ? value["name"] : undefined @@ -33,22 +52,83 @@ function toAgentInfo(value: unknown): AgentInfo | null { return { name, model, variant } } +async function resolveRegisteredAgentMetadata( + ctx: PluginInput, +): Promise { + const agentsResult = await ctx.client.app?.agents?.() + const agentsRaw = isObject(agentsResult) ? agentsResult["data"] : undefined + const agents = Array.isArray(agentsRaw) ? agentsRaw.map(toAgentInfo).filter(Boolean) : [] + + const matched = agents.find( + (agent) => agent?.name?.toLowerCase() === MULTIMODAL_LOOKER_AGENT.toLowerCase() + ) + + return { + agentModel: matched?.model, + agentVariant: matched?.variant, + } +} + +async function resolveDynamicAgentMetadata( + ctx: PluginInput, + visionCapableModels = readVisionCapableModelsCache(), +): Promise { + const fallbackChain = buildMultimodalLookerFallbackChain(visionCapableModels) + const connectedProviders = readConnectedProvidersCache() + const availableModels = await fetchAvailableModels(ctx.client, { + connectedProviders, + }) + + const resolution = resolveModelPipeline({ + constraints: { + availableModels, + connectedProviders, + }, + policy: { + fallbackChain, + }, + }) + + const agentModel = resolution ? parseAgentModel(resolution.model) : undefined + return { + agentModel, + agentVariant: resolution?.variant, + } +} + +function isConfiguredVisionModel( + configuredModel: AgentModel | undefined, + dynamicModel: AgentModel | undefined, +): boolean { + if (!configuredModel || !dynamicModel) { + return false + } + + return getFullModelKey(configuredModel) === getFullModelKey(dynamicModel) +} + export async function resolveMultimodalLookerAgentMetadata( ctx: PluginInput ): Promise { try { - const agentsResult = await ctx.client.app?.agents?.() - const agentsRaw = isObject(agentsResult) ? agentsResult["data"] : undefined - const agents = Array.isArray(agentsRaw) ? agentsRaw.map(toAgentInfo).filter(Boolean) : [] + const registeredMetadata = await resolveRegisteredAgentMetadata(ctx) + const visionCapableModels = readVisionCapableModelsCache() + + if (registeredMetadata.agentModel && visionCapableModels.length === 0) { + return registeredMetadata + } - const matched = agents.find( - (agent) => agent?.name?.toLowerCase() === MULTIMODAL_LOOKER_AGENT.toLowerCase() - ) + const dynamicMetadata = await resolveDynamicAgentMetadata(ctx, visionCapableModels) - return { - agentModel: matched?.model, - agentVariant: matched?.variant, + if (isConfiguredVisionModel(registeredMetadata.agentModel, dynamicMetadata.agentModel)) { + return registeredMetadata } + + if (dynamicMetadata.agentModel) { + return dynamicMetadata + } + + return registeredMetadata } catch (error) { log("[look_at] Failed to resolve multimodal-looker model info", error) return {} diff --git a/src/tools/look-at/multimodal-fallback-chain.test.ts b/src/tools/look-at/multimodal-fallback-chain.test.ts new file mode 100644 index 0000000000..d37c5b6eca --- /dev/null +++ b/src/tools/look-at/multimodal-fallback-chain.test.ts @@ -0,0 +1,31 @@ +describe("buildMultimodalLookerFallbackChain", () => { + it("builds fallback chain from vision-capable models", async () => { + // given + const { buildMultimodalLookerFallbackChain } = await import("./multimodal-fallback-chain") + const visionCapableModels = [ + { providerID: "openai", modelID: "gpt-5.4" }, + { providerID: "opencode", modelID: "gpt-5.4" }, + ] + + // when + const result = buildMultimodalLookerFallbackChain(visionCapableModels) + + // then + const gpt54Entries = result.filter((entry) => entry.model === "gpt-5.4") + expect(gpt54Entries.length).toBeGreaterThan(0) + }) + + it("avoids duplicates when adding hardcoded entries", async () => { + // given + const { buildMultimodalLookerFallbackChain } = await import("./multimodal-fallback-chain") + const visionCapableModels = [{ providerID: "openai", modelID: "gpt-5.4" }] + + // when + const result = buildMultimodalLookerFallbackChain(visionCapableModels) + + // then + expect(result.length).toBeGreaterThan(0) + expect(result[0].model).toBe("gpt-5.4") + expect(result[0].providers).toContain("openai") + }) +}) diff --git a/src/tools/look-at/multimodal-fallback-chain.ts b/src/tools/look-at/multimodal-fallback-chain.ts new file mode 100644 index 0000000000..2e0f65de1c --- /dev/null +++ b/src/tools/look-at/multimodal-fallback-chain.ts @@ -0,0 +1,49 @@ +import type { FallbackEntry } from "../../shared/model-requirements" +import { AGENT_MODEL_REQUIREMENTS } from "../../shared/model-requirements" +import type { VisionCapableModel } from "../../plugin-state" + +const MULTIMODAL_LOOKER_REQUIREMENT = AGENT_MODEL_REQUIREMENTS["multimodal-looker"] + +function getFullModelKey(providerID: string, modelID: string): string { + return `${providerID}/${modelID}` +} + +export function isHardcodedMultimodalFallbackModel(model: VisionCapableModel): boolean { + return MULTIMODAL_LOOKER_REQUIREMENT.fallbackChain.some((entry) => + entry.providers.some((providerID) => + getFullModelKey(providerID, entry.model) === getFullModelKey(model.providerID, model.modelID), + ), + ) +} + +export function buildMultimodalLookerFallbackChain( + visionCapableModels: VisionCapableModel[], +): FallbackEntry[] { + const seen = new Set() + const fallbackChain: FallbackEntry[] = [] + + for (const visionCapableModel of visionCapableModels) { + const key = getFullModelKey(visionCapableModel.providerID, visionCapableModel.modelID) + if (seen.has(key)) continue + + seen.add(key) + fallbackChain.push({ + providers: [visionCapableModel.providerID], + model: visionCapableModel.modelID, + }) + } + + for (const entry of MULTIMODAL_LOOKER_REQUIREMENT.fallbackChain) { + const providerModelKeys = entry.providers.map((providerID) => + getFullModelKey(providerID, entry.model), + ) + if (providerModelKeys.every((key) => seen.has(key))) { + continue + } + + providerModelKeys.forEach((key) => seen.add(key)) + fallbackChain.push(entry) + } + + return fallbackChain +}