diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7a655c1..e60e7d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,6 +44,7 @@ jobs: pnpm --filter @maschina/events build pnpm --filter @maschina/nats build pnpm --filter @maschina/jobs build + pnpm --filter @maschina/model build pnpm --filter @maschina/telemetry build pnpm --filter @maschina/usage build pnpm --filter @maschina/billing build @@ -127,6 +128,7 @@ jobs: pnpm --filter @maschina/events build pnpm --filter @maschina/nats build pnpm --filter @maschina/jobs build + pnpm --filter @maschina/model build pnpm --filter @maschina/telemetry build pnpm --filter @maschina/usage build pnpm --filter @maschina/billing build @@ -231,8 +233,9 @@ jobs: uv pip install -e packages/agents --system uv pip install -e packages/risk --system uv pip install -e "packages/sdk/python[dev]" --system + uv pip install -e "services/runtime[dev]" --system uv pip install pytest pytest-asyncio pytest-mock --system - - run: pytest packages/runtime/tests packages/agents/tests packages/risk/tests packages/sdk/python/tests -v + - run: pytest packages/runtime/tests packages/agents/tests packages/risk/tests packages/sdk/python/tests services/runtime/tests -v # ── Gate ────────────────────────────────────────────────────────────────── diff --git a/.lintstagedrc.json b/.lintstagedrc.json deleted file mode 100644 index 6389702..0000000 --- a/.lintstagedrc.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "*.{ts,tsx,js,jsx,mjs,cjs,json,jsonc}": ["biome check --write --no-errors-on-unmatched"], - "*.py": ["ruff check --fix", "ruff format"], - "*.rs": ["rustfmt --edition 2021"] -} diff --git a/.lintstagedrc.mjs b/.lintstagedrc.mjs new file mode 100644 index 0000000..21a2210 --- /dev/null +++ b/.lintstagedrc.mjs @@ -0,0 +1,23 @@ +import { existsSync } from "node:fs"; + +// Filter helper — lint-staged 15 can temporarily drop newly-tracked files +// from disk during its stash/restore cycle. Filter to only existing files +// before passing to formatters/linters. +const existing = (files) => files.filter(existsSync); + +export default { + "*.{ts,tsx,js,jsx,mjs,cjs,json,jsonc}": ["biome check --write --no-errors-on-unmatched"], + + "*.py": (files) => { + const ex = existing(files); + if (!ex.length) return []; + const paths = ex.join(" "); + return [`ruff check --fix ${paths}`, `ruff format ${paths}`]; + }, + + "*.rs": (files) => { + const ex = existing(files); + if (!ex.length) return []; + return [`rustfmt --edition 2021 ${ex.join(" ")}`]; + }, +}; diff --git a/CHANGELOG.md b/CHANGELOG.md index fa2af6f..0ae3b06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,24 @@ Format: [Semantic Versioning](https://semver.org) — `[version] YYYY-MM-DD` ## [Unreleased] +### Added (2026-03-08 — Model routing) +- `packages/model/src/catalog.ts` — TypeScript model catalog: 3 Anthropic cloud models + 3 Ollama local models, per-tier access gates, billing multipliers (Haiku 1x, Sonnet 3x, Opus 15x, Ollama 0x) +- `packages/model/src/index.ts` — Barrel export +- `packages/model/src/catalog.test.ts` — 20 vitest tests covering multipliers, tier access, validation, resolution +- `packages/model/tsconfig.json` + build script — TS package alongside existing Python code +- `packages/validation` — `RunAgentSchema` gains optional `model` field +- `packages/jobs` — `AgentExecuteJob` gains `model` + `systemPrompt` fields; `dispatchAgentRun` updated +- `services/api` — Model access validation at run dispatch; resolves system prompt from `agent.config.systemPrompt`; passes model + system prompt through job queue +- `services/daemon` — `AgentExecuteJob`, `JobToRun` gain `model` + `system_prompt`; `RuntimeRequest` now sends all fields the Python runtime needs (`plan_tier`, `model`, `system_prompt`, `max_tokens`, `timeout_secs`); URL fixed from `/execute` → `/run` +- `services/daemon` — `RunOutput.payload` renamed to `output_payload` to match Python `RunResponse` +- `services/runtime` — Full model routing in `runner.py`: routes by model ID prefix (ollama/* vs Anthropic), applies billing multiplier, lazy-imports Anthropic client per request; drops global Ollama flag +- `services/runtime/tests/test_runner_routing.py` — Unit tests for multiplier + routing helpers (no real LLM calls) +- CI + pytest scripts updated to include `services/runtime` tests + +### Fixed (2026-03-08 — Model routing) +- Daemon was calling `/execute` endpoint on Python runtime — correct endpoint is `/run` +- Daemon `RuntimeRequest` was missing `plan_tier`, `model`, `system_prompt`, `timeout_secs` fields that the Python `RunRequest` model requires + ### Fixed (2026-03-07 — Session N+1: backend boot + E2E) - All 31 TS packages now build clean (`pnpm turbo build --filter='./packages/*'`) - `packages/cache/src/client.ts` — ioredis ESM default import via `(Redis as any)` constructor cast diff --git a/package.json b/package.json index 31b75f5..08c2966 100644 --- a/package.json +++ b/package.json @@ -29,8 +29,9 @@ "cargo:run:cli": "cargo run -p maschina-cli", "cargo:run:code": "cargo run -p maschina-code", - "pytest": "pytest packages/runtime packages/agents packages/ml packages/model packages/risk packages/sdk/python services/worker", + "pytest": "pytest packages/runtime packages/agents packages/ml packages/model packages/risk packages/sdk/python services/worker services/runtime", "pytest:runtime": "pytest packages/runtime", + "pytest:runtime-service": "pytest services/runtime", "pytest:agents": "pytest packages/agents", "pytest:ml": "pytest packages/ml", "pytest:model": "pytest packages/model", @@ -186,7 +187,7 @@ "ci": "pnpm check && pnpm build:packages && pnpm test", "ci:ts": "pnpm check && pnpm build:packages && turbo test --filter='!@maschina/daemon' --filter='!@maschina/cli' --filter='!@maschina/code' --filter='!@maschina/rust'", "ci:rust": "pnpm check:rust && pnpm build:rust && pnpm test:rust", - "ci:python": "pytest tests packages/runtime packages/agents packages/ml packages/risk packages/sdk/python services/worker", + "ci:python": "pytest tests packages/runtime packages/agents packages/ml packages/risk packages/sdk/python services/worker services/runtime", "ci:e2e": "turbo test --filter=@maschina/tests", "ci:integration": "vitest run --root tests/integration" }, diff --git a/packages/jobs/src/dispatch.ts b/packages/jobs/src/dispatch.ts index cf5600b..06bc952 100644 --- a/packages/jobs/src/dispatch.ts +++ b/packages/jobs/src/dispatch.ts @@ -31,6 +31,8 @@ export async function dispatchAgentRun(opts: { agentId: string; userId: string; tier: string; + model: string; + systemPrompt: string; inputPayload: unknown; timeoutSecs: number; }): Promise { diff --git a/packages/jobs/src/types.ts b/packages/jobs/src/types.ts index e2ea3e0..05b0bb9 100644 --- a/packages/jobs/src/types.ts +++ b/packages/jobs/src/types.ts @@ -15,6 +15,8 @@ export interface AgentExecuteJob { agentId: string; userId: string; tier: string; + model: string; + systemPrompt: string; inputPayload: unknown; timeoutSecs: number; } diff --git a/packages/model/package.json b/packages/model/package.json index 24947f2..d7d370d 100644 --- a/packages/model/package.json +++ b/packages/model/package.json @@ -3,11 +3,27 @@ "version": "0.0.0", "private": true, "type": "module", + "main": "./dist/index.js", + "types": "./dist/index.d.ts", + "exports": { + ".": { + "import": "./dist/index.js", + "types": "./dist/index.d.ts" + } + }, "scripts": { + "build": "tsc", "train": "python -m maschina_model.train", "eval": "python -m maschina_model.eval", "infer": "python -m maschina_model.infer", "test": "pytest", "clean": "rm -rf dist __pycache__ .pytest_cache" + }, + "dependencies": { + "@maschina/plans": "workspace:*" + }, + "devDependencies": { + "@maschina/tsconfig": "workspace:*", + "typescript": "^5" } } diff --git a/packages/model/src/catalog.test.ts b/packages/model/src/catalog.test.ts new file mode 100644 index 0000000..25aa083 --- /dev/null +++ b/packages/model/src/catalog.test.ts @@ -0,0 +1,119 @@ +import { describe, expect, it } from "vitest"; +import { + DEFAULT_MODEL, + getAllowedModels, + getModel, + getModelMultiplier, + resolveModel, + validateModelAccess, +} from "./catalog.js"; + +describe("getModel", () => { + it("returns the model def for a known ID", () => { + const m = getModel("claude-haiku-4-5-20251001"); + expect(m).toBeDefined(); + expect(m?.provider).toBe("anthropic"); + expect(m?.multiplier).toBe(1); + }); + + it("returns undefined for an unknown ID", () => { + expect(getModel("gpt-99")).toBeUndefined(); + }); +}); + +describe("getModelMultiplier", () => { + it("returns 1 for haiku", () => expect(getModelMultiplier("claude-haiku-4-5-20251001")).toBe(1)); + it("returns 3 for sonnet", () => expect(getModelMultiplier("claude-sonnet-4-6")).toBe(3)); + it("returns 15 for opus", () => expect(getModelMultiplier("claude-opus-4-6")).toBe(15)); + it("returns 0 for ollama models", () => expect(getModelMultiplier("ollama/llama3.2")).toBe(0)); + it("returns 1 for unknown model (safe default)", () => + expect(getModelMultiplier("unknown")).toBe(1)); +}); + +describe("getAllowedModels", () => { + it("access tier only gets local ollama models", () => { + const allowed = getAllowedModels("access"); + expect(allowed.every((m) => m.isLocal)).toBe(true); + }); + + it("m1 tier can use haiku and ollama", () => { + const ids = getAllowedModels("m1").map((m) => m.id); + expect(ids).toContain("claude-haiku-4-5-20251001"); + expect(ids).toContain("ollama/llama3.2"); + expect(ids).not.toContain("claude-sonnet-4-6"); + expect(ids).not.toContain("claude-opus-4-6"); + }); + + it("m5 tier can use haiku and sonnet but not opus", () => { + const ids = getAllowedModels("m5").map((m) => m.id); + expect(ids).toContain("claude-haiku-4-5-20251001"); + expect(ids).toContain("claude-sonnet-4-6"); + expect(ids).not.toContain("claude-opus-4-6"); + }); + + it("m10 tier can use all models", () => { + const ids = getAllowedModels("m10").map((m) => m.id); + expect(ids).toContain("claude-haiku-4-5-20251001"); + expect(ids).toContain("claude-sonnet-4-6"); + expect(ids).toContain("claude-opus-4-6"); + }); + + it("internal tier can use all models", () => { + const ids = getAllowedModels("internal").map((m) => m.id); + expect(ids).toContain("claude-opus-4-6"); + }); +}); + +describe("validateModelAccess", () => { + it("allows access tier to use ollama", () => { + const result = validateModelAccess("access", "ollama/llama3.2"); + expect(result.allowed).toBe(true); + }); + + it("denies access tier from using haiku", () => { + const result = validateModelAccess("access", "claude-haiku-4-5-20251001"); + expect(result.allowed).toBe(false); + expect(result.reason).toMatch(/m1/); + }); + + it("denies m1 tier from using sonnet", () => { + const result = validateModelAccess("m1", "claude-sonnet-4-6"); + expect(result.allowed).toBe(false); + expect(result.reason).toMatch(/m5/); + }); + + it("denies m5 tier from using opus", () => { + const result = validateModelAccess("m5", "claude-opus-4-6"); + expect(result.allowed).toBe(false); + expect(result.reason).toMatch(/m10/); + }); + + it("allows m10 tier to use opus", () => { + const result = validateModelAccess("m10", "claude-opus-4-6"); + expect(result.allowed).toBe(true); + }); + + it("denies unknown model with clear error", () => { + const result = validateModelAccess("enterprise", "gpt-99"); + expect(result.allowed).toBe(false); + expect(result.reason).toMatch(/Unknown model/); + }); +}); + +describe("resolveModel", () => { + it("returns the requested model if allowed", () => { + expect(resolveModel("m5", "claude-haiku-4-5-20251001")).toBe("claude-haiku-4-5-20251001"); + }); + + it("falls back to tier default if requested model is denied", () => { + // m1 requesting opus → should fall back to m1 default + expect(resolveModel("m1", "claude-opus-4-6")).toBe(DEFAULT_MODEL.m1); + }); + + it("returns tier default when no model is requested", () => { + expect(resolveModel("access")).toBe("ollama/llama3.2"); + expect(resolveModel("m1")).toBe("claude-haiku-4-5-20251001"); + expect(resolveModel("m5")).toBe("claude-sonnet-4-6"); + expect(resolveModel("m10")).toBe("claude-opus-4-6"); + }); +}); diff --git a/packages/model/src/catalog.ts b/packages/model/src/catalog.ts new file mode 100644 index 0000000..8ec5271 --- /dev/null +++ b/packages/model/src/catalog.ts @@ -0,0 +1,148 @@ +import type { PlanTier } from "@maschina/plans"; + +// ─── Model definitions ──────────────────────────────────────────────────────── +// multiplier: tokens billed = actual_tokens * multiplier +// Ollama (local) = 0 — never deducted from quota +// Haiku = 1 — 1:1 deduction +// Sonnet = 3 — 3x deduction per token +// Opus = 15 — 15x deduction per token +// +// minTier: minimum plan tier required to use this model via cloud execution. +// Local Ollama models have minTier "access" (always allowed). + +export interface ModelDef { + id: string; + displayName: string; + provider: "anthropic" | "ollama"; + /** Token billing multiplier. 0 = no deduction (local). */ + multiplier: number; + /** Minimum tier for cloud access. */ + minTier: PlanTier; + /** Whether this is a local Ollama model. */ + isLocal: boolean; +} + +export const MODEL_CATALOG: ModelDef[] = [ + // ─── Anthropic cloud models ───────────────────────────────────────────── + { + id: "claude-haiku-4-5-20251001", + displayName: "Claude Haiku", + provider: "anthropic", + multiplier: 1, + minTier: "m1", + isLocal: false, + }, + { + id: "claude-sonnet-4-6", + displayName: "Claude Sonnet", + provider: "anthropic", + multiplier: 3, + minTier: "m5", + isLocal: false, + }, + { + id: "claude-opus-4-6", + displayName: "Claude Opus", + provider: "anthropic", + multiplier: 15, + minTier: "m10", + isLocal: false, + }, + + // ─── Local Ollama models (Access tier and up) ──────────────────────────── + { + id: "ollama/llama3.2", + displayName: "Llama 3.2 (local)", + provider: "ollama", + multiplier: 0, + minTier: "access", + isLocal: true, + }, + { + id: "ollama/llama3.1", + displayName: "Llama 3.1 (local)", + provider: "ollama", + multiplier: 0, + minTier: "access", + isLocal: true, + }, + { + id: "ollama/mistral", + displayName: "Mistral (local)", + provider: "ollama", + multiplier: 0, + minTier: "access", + isLocal: true, + }, +]; + +const TIER_RANK: Record = { + access: 0, + m1: 1, + m5: 2, + m10: 3, + teams: 4, + enterprise: 5, + internal: 5, +}; + +/** Default model for a given plan tier. */ +export const DEFAULT_MODEL: Record = { + access: "ollama/llama3.2", + m1: "claude-haiku-4-5-20251001", + m5: "claude-sonnet-4-6", + m10: "claude-opus-4-6", + teams: "claude-sonnet-4-6", + enterprise: "claude-opus-4-6", + internal: "claude-opus-4-6", +}; + +/** Returns all models accessible at or below the given tier. */ +export function getAllowedModels(tier: PlanTier): ModelDef[] { + return MODEL_CATALOG.filter((m) => TIER_RANK[tier] >= TIER_RANK[m.minTier]); +} + +/** Returns the model definition by ID, or undefined if not found. */ +export function getModel(modelId: string): ModelDef | undefined { + return MODEL_CATALOG.find((m) => m.id === modelId); +} + +/** Returns the billing multiplier for a model. Returns 1 if model not found. */ +export function getModelMultiplier(modelId: string): number { + return getModel(modelId)?.multiplier ?? 1; +} + +export interface ModelAccessResult { + allowed: boolean; + reason?: string; + model: ModelDef | undefined; +} + +/** + * Validates whether a given tier may use a given model. + * Returns { allowed: true, model } on success. + * Returns { allowed: false, reason } if denied. + */ +export function validateModelAccess(tier: PlanTier, modelId: string): ModelAccessResult { + const model = getModel(modelId); + if (!model) { + return { allowed: false, reason: `Unknown model: ${modelId}`, model: undefined }; + } + if (TIER_RANK[tier] < TIER_RANK[model.minTier]) { + return { + allowed: false, + reason: `Model ${model.displayName} requires the ${model.minTier} plan or higher.`, + model, + }; + } + return { allowed: true, model }; +} + +/** Returns the default model ID for a tier, resolving to the best allowed model. */ +export function resolveModel(tier: PlanTier, requested?: string): string { + if (requested) { + const { allowed } = validateModelAccess(tier, requested); + if (allowed) return requested; + } + return DEFAULT_MODEL[tier]; +} diff --git a/packages/model/src/index.ts b/packages/model/src/index.ts new file mode 100644 index 0000000..3d227dc --- /dev/null +++ b/packages/model/src/index.ts @@ -0,0 +1,10 @@ +export { + MODEL_CATALOG, + DEFAULT_MODEL, + getAllowedModels, + getModel, + getModelMultiplier, + validateModelAccess, + resolveModel, +} from "./catalog.js"; +export type { ModelDef, ModelAccessResult } from "./catalog.js"; diff --git a/packages/model/tsconfig.json b/packages/model/tsconfig.json new file mode 100644 index 0000000..415a499 --- /dev/null +++ b/packages/model/tsconfig.json @@ -0,0 +1,6 @@ +{ + "extends": "@maschina/tsconfig/node.json", + "compilerOptions": { "outDir": "./dist", "rootDir": "./src" }, + "include": ["src"], + "exclude": ["node_modules", "dist"] +} diff --git a/packages/validation/src/schemas/agent.ts b/packages/validation/src/schemas/agent.ts index fc49347..5a15b26 100644 --- a/packages/validation/src/schemas/agent.ts +++ b/packages/validation/src/schemas/agent.ts @@ -19,6 +19,8 @@ export const UpdateAgentSchema = z.object({ export const RunAgentSchema = z.object({ input: z.record(z.unknown()).optional().default({}), + /** Optional model override. Validated against the caller's plan tier. */ + model: z.string().optional(), timeout: z .number() .int() diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 38e5150..40185e1 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -896,7 +896,18 @@ importers: packages/ml: {} - packages/model: {} + packages/model: + dependencies: + '@maschina/plans': + specifier: workspace:* + version: link:../plans + devDependencies: + '@maschina/tsconfig': + specifier: workspace:* + version: link:../tsconfig + typescript: + specifier: ^5 + version: 5.9.3 packages/nats: dependencies: @@ -1452,6 +1463,9 @@ importers: '@maschina/jobs': specifier: workspace:* version: link:../../packages/jobs + '@maschina/model': + specifier: workspace:* + version: link:../../packages/model '@maschina/nats': specifier: workspace:* version: link:../../packages/nats diff --git a/services/api/Dockerfile b/services/api/Dockerfile index cd937ea..b50bb68 100644 --- a/services/api/Dockerfile +++ b/services/api/Dockerfile @@ -14,6 +14,7 @@ COPY packages/db/package.json ./packages/db/package.json COPY packages/email/package.json ./packages/email/package.json COPY packages/events/package.json ./packages/events/package.json COPY packages/jobs/package.json ./packages/jobs/package.json +COPY packages/model/package.json ./packages/model/package.json COPY packages/nats/package.json ./packages/nats/package.json COPY packages/notifications/package.json ./packages/notifications/package.json COPY packages/plans/package.json ./packages/plans/package.json @@ -38,6 +39,7 @@ RUN pnpm --filter @maschina/plans build RUN pnpm --filter @maschina/events build RUN pnpm --filter @maschina/nats build RUN pnpm --filter @maschina/jobs build +RUN pnpm --filter @maschina/model build RUN pnpm --filter @maschina/telemetry build RUN pnpm --filter @maschina/usage build RUN pnpm --filter @maschina/billing build diff --git a/services/api/package.json b/services/api/package.json index 699066e..fafb080 100644 --- a/services/api/package.json +++ b/services/api/package.json @@ -19,6 +19,7 @@ "@maschina/email": "workspace:*", "@maschina/events": "workspace:*", "@maschina/jobs": "workspace:*", + "@maschina/model": "workspace:*", "@maschina/nats": "workspace:*", "@maschina/notifications": "workspace:*", "@maschina/plans": "workspace:*", diff --git a/services/api/src/routes/agents.ts b/services/api/src/routes/agents.ts index 52406c9..db51cfa 100644 --- a/services/api/src/routes/agents.ts +++ b/services/api/src/routes/agents.ts @@ -3,6 +3,7 @@ import { agents } from "@maschina/db"; import { and, eq, isNull } from "@maschina/db"; import { Subjects } from "@maschina/events"; import { dispatchAgentRun } from "@maschina/jobs"; +import { resolveModel, validateModelAccess } from "@maschina/model"; import { publishSafe } from "@maschina/nats"; import { recordAgentExecution } from "@maschina/usage"; import { @@ -147,9 +148,26 @@ app.post( if (!agent) throw new HTTPException(404, { message: "Agent not found" }); - // Get the user's current plan for timeout config - const { getPlan } = await import("@maschina/plans"); - const plan = getPlan(user.tier); + // Validate model access and resolve to the appropriate model for this tier + if (input.model) { + const access = validateModelAccess(user.tier, input.model); + if (!access.allowed) { + throw new HTTPException(403, { + message: access.reason ?? "Model not available on your plan.", + }); + } + } + const resolvedModel = resolveModel(user.tier, input.model); + + // Resolve system prompt from agent config, fall back to a sensible default + const agentConfig = (agent.config ?? {}) as Record; + const systemPrompt = + typeof agentConfig.systemPrompt === "string" + ? agentConfig.systemPrompt + : `You are a Maschina ${agent.type} agent named "${agent.name}". Complete the task provided.`; + + // Convert timeout from ms (API input) to seconds (runtime) + const timeoutSecs = Math.floor((input.timeout ?? 300_000) / 1000); // Insert the agent_runs row const { agentRuns } = await import("@maschina/db"); @@ -169,8 +187,10 @@ app.post( agentId, userId: user.id, tier: user.tier, + model: resolvedModel, + systemPrompt, inputPayload: input.input ?? {}, - timeoutSecs: 300, + timeoutSecs, }); // Publish event (fire-and-forget — realtime service fans this out to WebSocket clients) diff --git a/services/daemon/src/orchestrator/analyze.rs b/services/daemon/src/orchestrator/analyze.rs index f1da3e9..b0293e8 100644 --- a/services/daemon/src/orchestrator/analyze.rs +++ b/services/daemon/src/orchestrator/analyze.rs @@ -63,7 +63,7 @@ async fn persist_success( WHERE id = $4 "#, ) - .bind(&output.payload) + .bind(&output.output_payload) .bind(output.input_tokens as i64) .bind(output.output_tokens as i64) .bind(run.id) diff --git a/services/daemon/src/orchestrator/scan.rs b/services/daemon/src/orchestrator/scan.rs index 0f29d12..4f5afe1 100644 --- a/services/daemon/src/orchestrator/scan.rs +++ b/services/daemon/src/orchestrator/scan.rs @@ -17,6 +17,8 @@ pub struct AgentExecuteJob { pub agent_id: Uuid, pub user_id: Uuid, pub tier: String, + pub model: String, + pub system_prompt: String, pub input_payload: serde_json::Value, pub timeout_secs: u64, } @@ -119,6 +121,8 @@ pub async fn scan_and_dispatch(state: AppState) -> Result<()> { agent_id: job.agent_id, user_id: job.user_id, plan_tier: job.tier, + model: job.model, + system_prompt: job.system_prompt, input_payload: job.input_payload, timeout_secs: job.timeout_secs as i64, }; diff --git a/services/daemon/src/orchestrator/scan_compat.rs b/services/daemon/src/orchestrator/scan_compat.rs index 77e6bf6..67f3fa3 100644 --- a/services/daemon/src/orchestrator/scan_compat.rs +++ b/services/daemon/src/orchestrator/scan_compat.rs @@ -8,6 +8,10 @@ pub struct JobToRun { pub agent_id: Uuid, pub user_id: Uuid, pub plan_tier: String, + /// Resolved model ID (e.g. "claude-haiku-4-5-20251001" or "ollama/llama3.2"). + pub model: String, + /// System prompt resolved from agent config at dispatch time. + pub system_prompt: String, pub input_payload: serde_json::Value, pub timeout_secs: i64, } diff --git a/services/daemon/src/runtime/mod.rs b/services/daemon/src/runtime/mod.rs index 86fdae7..823e3cb 100644 --- a/services/daemon/src/runtime/mod.rs +++ b/services/daemon/src/runtime/mod.rs @@ -4,32 +4,44 @@ use crate::state::AppState; use serde::{Deserialize, Serialize}; /// Output returned by the Python runtime after a successful agent execution. +/// Must match services/runtime/src/models.py::RunResponse exactly. #[derive(Debug, Deserialize)] pub struct RunOutput { - pub payload: serde_json::Value, + pub output_payload: serde_json::Value, pub input_tokens: u64, pub output_tokens: u64, } /// Request body sent to the Python runtime service. +/// Must match services/runtime/src/models.py::RunRequest exactly. #[derive(Debug, Serialize)] struct RuntimeRequest<'a> { run_id: uuid::Uuid, agent_id: uuid::Uuid, user_id: uuid::Uuid, + plan_tier: &'a str, + model: &'a str, + system_prompt: &'a str, + max_tokens: u32, input_payload: &'a serde_json::Value, + timeout_secs: i64, } /// Dispatch a run to the Python runtime and await the result. /// The caller is responsible for enforcing the timeout wrapper. pub async fn dispatch(state: &AppState, run: &QueuedRun) -> Result { - let url = format!("{}/execute", state.config.runtime_url); + let url = format!("{}/run", state.config.runtime_url); let body = RuntimeRequest { run_id: run.id, agent_id: run.agent_id, user_id: run.user_id, + plan_tier: &run.plan_tier, + model: &run.model, + system_prompt: &run.system_prompt, + max_tokens: 4096, input_payload: &run.input_payload, + timeout_secs: run.timeout_secs, }; let response = state diff --git a/services/runtime/src/runner.py b/services/runtime/src/runner.py index 557d406..6a86447 100644 --- a/services/runtime/src/runner.py +++ b/services/runtime/src/runner.py @@ -1,6 +1,16 @@ """ Agent execution — delegates to maschina-runtime (the shared execution package) and runs risk checks via maschina-risk before and after the LLM call. + +Model routing: + - Models starting with "ollama/" → OllamaRunner (local, no token quota deduction) + - All other models → AnthropicRunner (cloud, billed with multiplier) + +Token billing multipliers (applied to raw token counts before returning): + claude-haiku-* → 1x + claude-sonnet-* → 3x + claude-opus-* → 15x + ollama/* → 0x (local, never deducted from quota) """ import logging @@ -15,13 +25,33 @@ logger = logging.getLogger(__name__) -# Lazily import Anthropic only if an API key is configured -if not settings.use_ollama: - import anthropic +# ─── Token billing multipliers ──────────────────────────────────────────────── +# Must stay in sync with packages/model/src/catalog.ts + +_MULTIPLIERS: list[tuple[str, int]] = [ + ("claude-haiku-", 1), + ("claude-sonnet-", 3), + ("claude-opus-", 15), + ("ollama/", 0), +] + +_DEFAULT_MULTIPLIER = 1 + + +def _get_multiplier(model: str) -> int: + for prefix, mult in _MULTIPLIERS: + if model.startswith(prefix): + return mult + return _DEFAULT_MULTIPLIER + - _anthropic_client = anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key) -else: - _anthropic_client = None # type: ignore[assignment] +def _is_ollama(model: str) -> bool: + return model.startswith("ollama/") + + +def _ollama_model_name(model: str) -> str: + """Strip 'ollama/' prefix to get the bare Ollama model name.""" + return model[len("ollama/") :] def _extract_user_message(input_payload: dict[str, Any]) -> str: @@ -38,9 +68,10 @@ async def execute(req: RunRequest) -> RunResponse: Pipeline: 1. Risk-check the user input (block prompt injection / oversized inputs) - 2. Run the agent via maschina-runtime AgentRunner + 2. Route to the appropriate runner based on model prefix 3. Risk-scan the output (flag PII leakage) - 4. Return structured response + 4. Apply token billing multiplier to reported token counts + 5. Return structured response """ user_message = _extract_user_message(req.input_payload) @@ -50,20 +81,34 @@ async def execute(req: RunRequest) -> RunResponse: codes = ", ".join(f.code for f in risk.flags) raise ValueError(f"Input blocked by risk check: {codes}") - # ── Execute via maschina-runtime ──────────────────────────────────────── - if settings.use_ollama: + # ── Route to runner ───────────────────────────────────────────────────── + if _is_ollama(req.model): + # Local Ollama — use the model name from the request runner = OllamaRunner( base_url=settings.ollama_base_url, - model=settings.ollama_model, + model=_ollama_model_name(req.model), system_prompt=req.system_prompt, max_tokens=min(req.max_tokens, settings.max_output_tokens), timeout_secs=req.timeout_secs, ) else: + # Cloud Anthropic model — lazy-import to avoid requiring the key for local dev + try: + import anthropic + + client = anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key) + except ImportError as exc: + raise RuntimeError("anthropic package not installed") from exc + + if not settings.anthropic_api_key: + raise RuntimeError( + f"ANTHROPIC_API_KEY is not set but model '{req.model}' requires cloud execution" + ) + from maschina_runtime import AgentRunner runner = AgentRunner( - client=_anthropic_client, + client=client, system_prompt=req.system_prompt, model=req.model, max_tokens=min(req.max_tokens, settings.max_output_tokens), @@ -85,20 +130,30 @@ async def execute(req: RunRequest) -> RunResponse: extra={"run_id": req.run_id, "flags": [f.code for f in output_risk.flags]}, ) + # ── Apply billing multiplier ──────────────────────────────────────────── + # Multiply raw token counts so the daemon's quota deduction reflects cost. + # Ollama multiplier = 0, so local runs never deduct from the cloud quota. + multiplier = _get_multiplier(req.model) + billed_input_tokens = result.input_tokens * multiplier + billed_output_tokens = result.output_tokens * multiplier + logger.info( "run completed", extra={ "run_id": req.run_id, "model": req.model, "turns": result.turns, - "input_tokens": result.input_tokens, - "output_tokens": result.output_tokens, + "raw_input_tokens": result.input_tokens, + "raw_output_tokens": result.output_tokens, + "billed_input_tokens": billed_input_tokens, + "billed_output_tokens": billed_output_tokens, + "multiplier": multiplier, }, ) return RunResponse( run_id=req.run_id, output_payload={"text": result.output}, - input_tokens=result.input_tokens, - output_tokens=result.output_tokens, + input_tokens=billed_input_tokens, + output_tokens=billed_output_tokens, ) diff --git a/services/runtime/tests/__init__.py b/services/runtime/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/runtime/tests/test_runner_routing.py b/services/runtime/tests/test_runner_routing.py new file mode 100644 index 0000000..c8088fa --- /dev/null +++ b/services/runtime/tests/test_runner_routing.py @@ -0,0 +1,89 @@ +""" +Unit tests for model routing and billing multiplier logic in runner.py. +Tests the private helpers directly — no actual LLM calls made. +""" + +import sys +import types +import unittest.mock as mock + +# ─── Stub out packages that aren't installed in CI ─────────────────────────── + + +def _stub_module(name: str, **attrs): + mod = types.ModuleType(name) + for k, v in attrs.items(): + setattr(mod, k, v) + sys.modules[name] = mod + return mod + + +# Stub maschina_risk +_stub_module( + "maschina_risk", + check_input=lambda text: mock.MagicMock(approved=True, flags=[]), + check_output=lambda text: mock.MagicMock(flags=[]), +) + +# Stub maschina_runtime as a package (needs __path__ so submodule imports work) +_rt = _stub_module("maschina_runtime", RunInput=mock.MagicMock, AgentRunner=mock.MagicMock) +_rt.__path__ = [] # marks it as a package to Python's import system + +# Stub maschina_runtime.models (imported by ollama_runner.py at module level) +_stub_module("maschina_runtime.models", RunInput=mock.MagicMock, RunResult=mock.MagicMock) + +# Stub openai (imported by ollama_runner.py) +_stub_module("openai", AsyncOpenAI=mock.MagicMock) + +# Stub src.config.settings before importing runner +settings_mock = mock.MagicMock() +settings_mock.anthropic_api_key = "" +settings_mock.ollama_base_url = "http://localhost:11434/v1" +settings_mock.ollama_model = "llama3.2" +settings_mock.max_output_tokens = 16_384 +settings_mock.use_ollama = True + +config_mod = _stub_module("src.config", settings=settings_mock) + +# Now we can import the helpers +from src.runner import _get_multiplier, _is_ollama, _ollama_model_name # noqa: E402 + +# ─── Multiplier tests ───────────────────────────────────────────────────────── + + +class TestGetMultiplier: + def test_haiku_is_1x(self): + assert _get_multiplier("claude-haiku-4-5-20251001") == 1 + + def test_sonnet_is_3x(self): + assert _get_multiplier("claude-sonnet-4-6") == 3 + + def test_opus_is_15x(self): + assert _get_multiplier("claude-opus-4-6") == 15 + + def test_ollama_is_0x(self): + assert _get_multiplier("ollama/llama3.2") == 0 + assert _get_multiplier("ollama/mistral") == 0 + + def test_unknown_model_defaults_to_1x(self): + assert _get_multiplier("gpt-99") == 1 + assert _get_multiplier("") == 1 + + +# ─── Routing helpers ────────────────────────────────────────────────────────── + + +class TestIsOllama: + def test_ollama_prefix(self): + assert _is_ollama("ollama/llama3.2") is True + assert _is_ollama("ollama/mistral") is True + + def test_anthropic_is_not_ollama(self): + assert _is_ollama("claude-haiku-4-5-20251001") is False + assert _is_ollama("claude-sonnet-4-6") is False + + +class TestOllamaModelName: + def test_strips_prefix(self): + assert _ollama_model_name("ollama/llama3.2") == "llama3.2" + assert _ollama_model_name("ollama/mistral") == "mistral"