Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions src/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
matchesModelIdentifier,
normalizeTextCatalog,
} from './model-catalog.js';
import { doesResponseMatchModel, isMatchingModelName } from './model-matching.js';
import { generateSeed } from './seed.js';

const FALLBACK_MODELS = [
Expand Down Expand Up @@ -437,8 +438,11 @@ async function handleChatResponse(initialResponse, model, endpoint) {
client,
);
if (response?.model && !isMatchingModelName(response.model, model)) {
const aliasMatch = doesResponseMatchModel(response, model);
console.warn(
'Model mismatch detected after tool call. Expected %s, received %s.',
aliasMatch
? 'Model mismatch detected after tool call. Expected %s, received %s. Proceeding based on alias metadata.'
: 'Model mismatch detected after tool call. Expected %s, received %s.',
model.id,
response?.model,
);
Expand Down Expand Up @@ -739,18 +743,6 @@ function buildEndpointSequence(model) {
return result;
}

function isMatchingModelName(value, model) {
if (!value && value !== 0) return false;
const normalized = String(value).trim().toLowerCase();
if (!normalized) return false;
const identifiers = model?.identifiers;
if (identifiers?.has?.(normalized)) return true;
if (normalized.includes('/')) {
const last = normalized.split('/').pop();
if (last && identifiers?.has?.(last)) return true;
}
return false;
}

async function requestChatCompletion(model, endpoints) {
if (!model) {
Expand All @@ -775,7 +767,18 @@ async function requestChatCompletion(model, endpoints) {
},
client,
);
if (!response?.model || isMatchingModelName(response.model, model)) {
if (!response?.model) {
return { response, endpoint };
}
if (isMatchingModelName(response.model, model)) {
return { response, endpoint };
}
if (doesResponseMatchModel(response, model)) {
console.warn(
'Model mismatch detected. Expected %s, received %s. Proceeding based on alias metadata.',
model.id,
response.model,
);
return { response, endpoint };
}
attemptErrors.push(
Expand Down
130 changes: 130 additions & 0 deletions src/model-matching.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
const RESPONSE_MODEL_FIELDS = [
'alias',
'model_alias',
'modelAlias',
'canonical_model',
'canonicalModel',
'resolved_model',
'resolvedModel',
'primary_model',
'primaryModel',
'requested_model',
'requestedModel',
'requested',
'backend_model',
'backendModel',
'provider_model',
'providerModel',
'origin_model',
'originModel',
'served_model',
'servedModel',
'model_name',
'modelName',
'model_id',
'modelId',
'target_model',
'targetModel',
];

function normalize(value) {
if (value == null) return null;
const text = String(value).trim();
return text ? text.toLowerCase() : null;
}

function addCandidate(set, value) {
if (value == null) return;
if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
const normalized = normalize(value);
if (normalized) set.add(normalized);
return;
}
if (Array.isArray(value)) {
value.forEach(entry => addCandidate(set, entry));
return;
}
if (typeof value === 'object') {
for (const key of ['id', 'model', 'name', 'alias', 'slug']) {
if (key in value) {
addCandidate(set, value[key]);
}
}
}
}

export function collectResponseModelNames(response) {
const names = new Set();
if (!response || typeof response !== 'object') {
return [];
}

for (const key of RESPONSE_MODEL_FIELDS) {
addCandidate(names, response[key]);
}

if (Array.isArray(response?.aliases)) {
response.aliases.forEach(entry => addCandidate(names, entry));
}
if (Array.isArray(response?.models)) {
response.models.forEach(entry => addCandidate(names, entry));
}
if (Array.isArray(response?.modelAliases)) {
response.modelAliases.forEach(entry => addCandidate(names, entry));
}
if (Array.isArray(response?.available_models)) {
response.available_models.forEach(entry => addCandidate(names, entry));
}

const metadata = response?.metadata;
if (metadata && typeof metadata === 'object') {
for (const key of RESPONSE_MODEL_FIELDS) {
addCandidate(names, metadata[key]);
}
if (Array.isArray(metadata.aliases)) {
metadata.aliases.forEach(entry => addCandidate(names, entry));
}
}

const reported = normalize(response?.model);
if (reported) {
names.delete(reported);
}

return Array.from(names);
}

export function isMatchingModelName(value, model) {
if (!value && value !== 0) return false;
const normalized = normalize(value);
if (!normalized) return false;
const identifiers = model?.identifiers;
if (identifiers?.has?.(normalized)) return true;
if (normalized.includes('/')) {
const last = normalized.split('/').pop();
if (last && identifiers?.has?.(last)) return true;
}
return false;
}

export function doesResponseMatchModel(response, model) {
if (!response || typeof response !== 'object') {
return false;
}
if (isMatchingModelName(response.model, model)) {
return true;
}
const candidates = collectResponseModelNames(response);
for (const candidate of candidates) {
if (isMatchingModelName(candidate, model)) {
return true;
}
}
return false;
}

export const __testing = {
collectResponseModelNames,
isMatchingModelName,
normalize,
};
37 changes: 37 additions & 0 deletions tests/model-response-match.test.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import assert from 'node:assert/strict';
import { createFallbackModel } from '../src/model-catalog.js';
import { doesResponseMatchModel, isMatchingModelName } from '../src/model-matching.js';

export const name = 'Model response metadata allows alias matching';

export async function run() {
const model = createFallbackModel('unity', 'Unity Seed Model', ['seed', 'openai']);

assert(isMatchingModelName('unity', model), 'Model should match its own identifier');

assert(
doesResponseMatchModel({ model: 'unity' }, model),
'Exact model name should match',
);

assert(
doesResponseMatchModel(
{ model: 'mistral-small', requested_model: 'unity' },
model,
),
'requested_model should allow alias matching',
);

assert(
doesResponseMatchModel(
{ model: 'mistral-small', metadata: { alias: 'Pollinations/Unity' } },
model,
),
'metadata alias should allow alias matching',
);

assert(
!doesResponseMatchModel({ model: 'mistral-small' }, model),
'Unknown model identifiers should not match',
);
}