diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index f49b678..cdefd99 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -4,8 +4,7 @@ on: push: branches: [main] paths: - - 'lib/index.d.ts' - - 'lib/Branch.js' + - 'src/**' - 'package.json' - 'typedoc.json' - 'README.md' diff --git a/.github/workflows/gpu-test.yml b/.github/workflows/gpu-test.yml index 37d9fd5..f165c51 100644 --- a/.github/workflows/gpu-test.yml +++ b/.github/workflows/gpu-test.yml @@ -6,7 +6,6 @@ on: paths: - 'liblloyal' - 'llama.cpp' - - 'lib/**' - 'src/**' - 'test/**' - 'CMakeLists.txt' @@ -108,6 +107,18 @@ jobs: - name: Configure Docker for Artifact Registry run: gcloud auth configure-docker ${{ secrets.GCP_REGION }}-docker.pkg.dev --quiet + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: 24 + cache: 'npm' + + - name: Compile TypeScript (src + tests) + run: | + npm ci --ignore-scripts + npm run build:ts + npm run build:test + - name: Download package artifact uses: actions/download-artifact@v4 with: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 980703b..0c3511b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -218,7 +218,7 @@ jobs: export LD_LIBRARY_PATH="${PKG_BIN}:${LD_LIBRARY_PATH:-}" echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" node -e " - const { loadBinary } = require('./lib'); + const { loadBinary } = require('./dist'); const addon = loadBinary(); console.log('βœ“ Platform package loaded successfully'); console.log(' Exports:', Object.keys(addon)); @@ -255,7 +255,7 @@ jobs: Write-Host "VULKAN_SDK: $env:VULKAN_SDK" Write-Host "CUDA_PATH: $env:CUDA_PATH" node -e " - const { loadBinary } = require('./lib'); + const { loadBinary } = require('./dist'); const addon = loadBinary(); console.log('βœ“ Platform package loaded successfully'); console.log(' Exports:', Object.keys(addon)); @@ -385,7 +385,13 @@ jobs: env: NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} - # 
Publish main package + # Build TypeScript before publishing main package + - name: Install dependencies + run: npm ci --ignore-scripts + + - name: Build TypeScript + run: npm run build:ts + - name: Sync package versions run: node scripts/sync-versions.js diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b8e01c8..cab8485 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -132,6 +132,12 @@ jobs: with: node-version: 24 + - name: Install dependencies + run: npm ci --ignore-scripts + + - name: Build TypeScript + run: npm run build:ts + - name: Pack package run: npm pack @@ -141,9 +147,9 @@ jobs: echo "📦 Package contents:" cat package-contents.txt - # Verify lib/ JavaScript is included - if ! grep -q "package/lib/index.js" package-contents.txt; then - echo "❌ ERROR: lib/index.js not in package!" + # Verify dist/ JavaScript is included + if ! grep -q "package/dist/index.js" package-contents.txt; then + echo "❌ ERROR: dist/index.js not in package!" 
exit 1 fi diff --git a/.gitignore b/.gitignore index 44b0e3f..4deb625 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Build outputs build/ +dist/ prebuilds/ *.node /include/ @@ -40,5 +41,8 @@ tmp/ dist/ packages/darwin-arm64 +# Compiled test artifacts (built by tsconfig.test.json for GPU CI) +test/*.js + # CI infra scripts (injected from lloyal-infra during CI) ci/ \ No newline at end of file diff --git a/.npmignore b/.npmignore index 95ac8ef..2662537 100644 --- a/.npmignore +++ b/.npmignore @@ -34,8 +34,9 @@ tests/ examples/ docs/ -# C++ source files (users get prebuilt binaries, not source) +# C++ and TS source files (users get prebuilt binaries + compiled JS, not source) src/ +tsconfig.json # Test models (too large for npm) models/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 170a818..7f1a294 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -116,6 +116,21 @@ endif() # This also sets up the llama/llama.h include structure automatically add_subdirectory(${LIBLLOYAL_DIR} liblloyal) +# ============================================================================= +# md4c (Markdown parser for structure extraction) +# ============================================================================= + +include(FetchContent) +FetchContent_Declare( + md4c + GIT_REPOSITORY https://github.com/mity/md4c + GIT_TAG release-0.5.2 +) +set(BUILD_MD2HTML_EXECUTABLE OFF CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(md4c) +FetchContent_GetProperties(md4c) +set(MD4C_INCLUDE_DIR "${md4c_SOURCE_DIR}/src") + # ============================================================================= # Addon Sources # ============================================================================= @@ -124,6 +139,7 @@ set(ADDON_SOURCES src/binding.cpp src/BackendManager.cpp src/SessionContext.cpp + src/Util.cpp ) # ============================================================================= @@ -136,6 +152,7 @@ add_library(${PROJECT_NAME} MODULE ${ADDON_SOURCES} ${CMAKE_JS_SRC}) 
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_JS_INC} ${NODE_ADDON_API_DIR} + ${MD4C_INCLUDE_DIR} src ) @@ -147,6 +164,7 @@ target_include_directories(${PROJECT_NAME} PRIVATE target_link_libraries(${PROJECT_NAME} PRIVATE liblloyal::liblloyal common + md4c ${CMAKE_JS_LIB} ) diff --git a/README.md b/README.md index 49fb41d..a1783d3 100644 --- a/README.md +++ b/README.md @@ -218,7 +218,7 @@ See [`examples/grammar/`](./examples/grammar/) for the full branch fork pattern. Full API documentation: **[lloyal-ai.github.io/lloyal.node](https://lloyal-ai.github.io/lloyal.node/)** -Generated from [`lib/index.d.ts`](./lib/index.d.ts) with TypeDoc. +Generated from [`src/index.ts`](./src/index.ts) with TypeDoc. --- diff --git a/examples/best-of-n/README.md b/examples/best-of-n/README.md deleted file mode 100644 index e9d973d..0000000 --- a/examples/best-of-n/README.md +++ /dev/null @@ -1,118 +0,0 @@ -# Best-of-N Sampling with Perplexity Selection - -Demonstrates why best-of-n beats single generation: generate N diverse candidates, select the most coherent by perplexity. - -## Run It - -```bash -node best-of-n.mjs -``` - -## What You'll See - -``` -BASELINE: Single generation (T=0.3) - PPL: 2.07 | "In the realm where the moon dipped..." - -BEST-OF-5: Generate 5 candidates (T=0.9), select lowest PPL - [1] PPL: 2.95 | "In the heart of a moonlit forest..." - [2] PPL: 4.41 | "Under the cloak of a midnight moon..." - [3] PPL: 3.09 | "As the last wisps of sunlight..." - [4] PPL: 3.42 | "Under the moon's silvery glow..." - [5] PPL: 3.46 | "Under the emerald canopy..." - -RESULTS - Best candidate [1] (PPL 2.95) - PPL range: 2.95 - 4.41 (Ξ”1.46) -``` - -## How It Works - -| Step | What Happens | -|------|--------------| -| 1. Prefill | Decode prompt on seq 0 | -| 2. Capture logits | Copy logits buffer (critical for fair comparison) | -| 3. Generate N candidates | Each forks KV, samples from captured logits, then continues | -| 4. 
Track PPL | Accumulate surprisal per candidate | -| 5. Select best | Lowest perplexity wins | - -## Key Implementation Detail - -After prefilling, the logits buffer contains P(next_token | prompt). When we fork to multiple sequences, **each candidate's first token must sample from these same captured logits**: - -```javascript -// Capture after prefill -const capturedLogits = new Float32Array(ctx.getLogits()); - -// Each candidate: -// 1. Sample first token from captured logits (tsampler) -const token = sampleWithStrategy(capturedLogits, { params, workspace, prng }); - -// 2. Compute surprisal from captured logits (native C++) -const surprisal = ctx.modelSurprisal(token, 'nats', capturedLogits); -``` - -Without this, later candidates would sample from earlier candidates' states - unfair comparison. - -## Why Perplexity? - -``` -PPL = exp(average surprisal) = "how surprised is the model?" -``` - -| PPL | Meaning | -|-----|---------| -| Low | Model is confident in what it wrote | -| High | Model was uncertain, may have inconsistencies | - -Best-of-N trades compute for quality: -- High temp generates **diverse** candidates (explore) -- PPL filtering selects **coherent** ones (exploit) - -## Key APIs - -| Method | Description | -|--------|-------------| -| `kvSeqCopy(src, dst)` | Fork KV cache (O(1) tag copy) | -| `getLogits()` | Get raw logits (zero-copy view) | -| `modelSurprisal(token, base?, logits?)` | Surprisal from current or captured logits | -| `createPerplexityTracker()` | Create tracker handle | -| `addSurprisal(tracker, value)` | Accumulate to tracker | -| `getPerplexity(tracker)` | Get current PPL | - -## Native Metrics API - -The native `modelSurprisal()` accepts an optional `logits` parameter for captured logits: - -```javascript -// First token: surprisal from captured logits -const firstSurprisal = ctx.modelSurprisal(token, 'nats', capturedLogits); - -// Subsequent tokens: current context logits (default) -const surprisal = ctx.modelSurprisal(token); 
-``` - -All math runs in C++ - no JS overhead for softmax/log operations. - -## tsampler Integration - -[@lloyal-labs/tsampler](https://www.npmjs.com/package/@lloyal-labs/tsampler) handles sampling from captured logits: - -```javascript -import { sampleWithStrategy, SamplerWorkspace, Xoroshiro128Plus } from '@lloyal-labs/tsampler'; - -const token = sampleWithStrategy(capturedLogits, { - params: { temperature: 0.9, topP: 0.95 }, - workspace, - prng, -}); -``` - -**Division of labor:** -- **tsampler**: Sampling (temperature, topP, topK) from arbitrary logits -- **Native API**: Metrics (surprisal, entropy, perplexity) from arbitrary logits - -## References - -- [Stiennon et al. 2020](https://arxiv.org/abs/2009.01325) - "Learning to summarize from human feedback" (Best-of-N in RLHF) -- [tsampler](https://github.com/lloyal-ai/tsampler) - Pure TypeScript sampling with llama.cpp parity diff --git a/examples/best-of-n/best-of-n.mjs b/examples/best-of-n/best-of-n.mjs deleted file mode 100644 index 22c9328..0000000 --- a/examples/best-of-n/best-of-n.mjs +++ /dev/null @@ -1,244 +0,0 @@ -#!/usr/bin/env node -/** - * Best-of-N Sampling with Perplexity Selection (Parallel Streaming) - * - * Demonstrates why best-of-n beats single generation: - * - Generate N candidates with high temperature (diverse) - * - Select best by perplexity (model's confidence in its output) - * - Lower perplexity = more coherent, higher quality - * - * Based on: "Best-of-N" / "Rejection Sampling" used in RLHF pipelines - * See: Stiennon et al. 2020 "Learning to summarize from human feedback" - * - * KEY IMPLEMENTATION DETAIL: - * Uses the Branch API for parallel generation. The root branch prefills the - * prompt and captures logits. When forking to multiple candidates, each fork - * inherits the root's logits snapshot, ensuring all candidates start from - * the same probability distribution. 
- * - * Usage: - * node best-of-n.mjs [model-path] # Human-readable output - * node best-of-n.mjs [model-path] --jsonl # JSONL output for testing - */ - -import * as path from 'node:path'; -import { fileURLToPath } from 'node:url'; -import { createContext, Branch } from '../../lib/index.js'; - -const __dirname = path.dirname(fileURLToPath(import.meta.url)); -const DEFAULT_MODEL = path.resolve( - __dirname, - '../../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf' -); - -// Parse args -const args = process.argv.slice(2); -const jsonlMode = args.includes('--jsonl'); -const modelPath = args.find(a => !a.startsWith('--')) || DEFAULT_MODEL; - -/** Emit output - JSONL or human-readable */ -function emit(event, data) { - if (jsonlMode) { - console.log(JSON.stringify({ event, ...data })); - } -} - -/** Collect tokens from a branch's async iterator, return text + perplexity. */ -async function generateWithBranch(branch, maxTokens, ctx) { - const tokens = []; - for await (const { token } of branch) { - tokens.push(token); - if (tokens.length >= maxTokens) break; - } - const ppl = branch.perplexity; - return { - text: await ctx.detokenize(tokens), - ppl: Number.isFinite(ppl) ? ppl : 999, - tokenCount: tokens.length, - }; -} - -async function main() { - const N = 5; // Number of candidates - const MAX_TOKENS = 60; - const HIGH_TEMP = 0.9; // High temp for diversity - const LOW_TEMP = 0.3; // Low temp for single baseline - - if (!jsonlMode) { - console.log('Best-of-N Sampling Demo (Parallel Streaming)'); - console.log('=============================================\n'); - console.log('Why best-of-n works:'); - console.log(' 1. Generate N candidates with HIGH temperature (diverse)'); - console.log(' 2. Score each by perplexity (model confidence)'); - console.log(' 3. 
Select LOWEST perplexity (most coherent)\n'); - console.log(`Loading model: ${path.basename(modelPath)}`); - } - - emit('start', { model: path.basename(modelPath), n: N, maxTokens: MAX_TOKENS, highTemp: HIGH_TEMP, lowTemp: LOW_TEMP }); - - const nCtx = parseInt(process.env.LLAMA_CTX_SIZE || '2048', 10); - const ctx = await createContext({ - modelPath, - nCtx, - nSeqMax: N + 2, // Need slots for N candidates + baseline + trunk - }); - - // Use chat template for consistent behavior - const userPrompt = 'Write a creative opening sentence for a fantasy novel.'; - const messages = [{ role: 'user', content: userPrompt }]; - const { prompt } = await ctx.formatChat(JSON.stringify(messages)); - - if (!jsonlMode) { - console.log(`\nPrompt: "${userPrompt}"`); - } - - // Prefill prompt via root branch - const promptTokens = await ctx.tokenize(prompt); - - const root = Branch.create(ctx, 0, { - temperature: HIGH_TEMP, - topP: 0.95, - }); - await root.prefill(promptTokens); - - if (!jsonlMode) { - console.log(`\nPrefill complete. 
Prompt length: ${promptTokens.length} tokens`); - } - - // === Baseline: Single generation with low temperature === - if (!jsonlMode) { - console.log('\n' + '='.repeat(70)); - console.log('BASELINE: Single generation (forked from root)'); - console.log('='.repeat(70)); - } - - // Fork baseline from root β€” inherits KV prefix + logits snapshot - const baselineBranch = await root.fork(); - - const baseline = await generateWithBranch(baselineBranch, MAX_TOKENS, ctx); - - emit('baseline', { ppl: baseline.ppl, text: baseline.text, tokenCount: baseline.tokenCount }); - - if (!jsonlMode) { - console.log(` PPL: ${baseline.ppl.toFixed(2)} | "${baseline.text}"`); - } - - await baselineBranch.prune(); - - // === Best-of-N: Parallel candidates with high temperature === - if (!jsonlMode) { - console.log('\n' + '='.repeat(70)); - console.log(`BEST-OF-${N}: Generate ${N} candidates in parallel (T=${HIGH_TEMP})`); - console.log('='.repeat(70)); - } - - // Fork N candidate branches from root - // Each fork gets: copied logits snapshot + copied KV cache + copied sampler - // CRITICAL: Reseed each branch's sampler for diversity (otherwise all produce identical output) - const branches = []; - for (let i = 0; i < N; i++) { - const branch = await root.fork(); - branch.reseedSampler(1000 + i); // Unique seed per branch - branches.push(branch); - } - - // Generate each candidate sequentially β€” same total GPU work, simpler flow - const candidates = []; - for (let i = 0; i < N; i++) { - const tokens = []; - for await (const { token, text } of branches[i]) { - tokens.push(token); - emit('token', { candidateIndex: i, text, index: tokens.length - 1 }); - if (tokens.length >= MAX_TOKENS) break; - } - - const ppl = branches[i].perplexity; - const text = await ctx.detokenize(tokens); - candidates.push({ - text, - ppl: Number.isFinite(ppl) ? 
ppl : 999, - tokenCount: tokens.length, - }); - - emit('candidate', { index: i + 1, ppl: candidates[i].ppl, text, tokenCount: tokens.length }); - - if (!jsonlMode) { - const truncated = text.length > 55 ? text.slice(0, 55) + '...' : text; - console.log(` [${i + 1}] PPL: ${candidates[i].ppl.toFixed(2).padStart(6)} | "${truncated}"`); - } - - await branches[i].prune(); - } - await root.prune(); - - // Select best - const best = candidates.reduce((a, b) => (a.ppl < b.ppl ? a : b)); - const worst = candidates.reduce((a, b) => (a.ppl > b.ppl ? a : b)); - const bestIdx = candidates.indexOf(best) + 1; - - // Analysis - const improvement = (baseline.ppl - best.ppl) / baseline.ppl; - const pplRange = worst.ppl - best.ppl; - - emit('complete', { - bestIndex: bestIdx, - bestPpl: best.ppl, - bestText: best.text, - worstPpl: worst.ppl, - baselinePpl: baseline.ppl, - pplRange, - improvement, - bestBeatBaseline: best.ppl < baseline.ppl, - }); - - if (!jsonlMode) { - // === Results === - console.log('\n' + '='.repeat(70)); - console.log('RESULTS'); - console.log('='.repeat(70)); - - console.log(`\n Best candidate [${bestIdx}] (PPL ${best.ppl.toFixed(2)}):`); - console.log(` "${best.text}"`); - - console.log(`\n Baseline (PPL ${baseline.ppl.toFixed(2)}):`); - console.log(` "${baseline.text}"`); - - console.log('\n Analysis:'); - console.log(` - PPL range across candidates: ${best.ppl.toFixed(2)} - ${worst.ppl.toFixed(2)} (Ξ”${pplRange.toFixed(2)})`); - if (best.ppl < baseline.ppl) { - console.log(` - Best-of-${N} beat baseline by ${(improvement * 100).toFixed(1)}% lower PPL`); - } else { - console.log(` - Baseline was already good (low temp = focused)`); - } - - console.log('\n' + '='.repeat(70)); - console.log('KEY INSIGHT'); - console.log('='.repeat(70)); - console.log(` - Perplexity = exp(average surprisal) = "how surprised is the model?" 
- Lower PPL = model is confident in what it wrote = usually more coherent - Higher PPL = model was uncertain = may have inconsistencies - - Best-of-N trades compute for quality: - - High temp generates diverse candidates (explore the space) - - PPL filtering selects the coherent ones (exploit quality) - - Implementation note: - Uses the Branch API for parallel generation. After prefilling the - prompt, we create a root branch and capture its logits. When forking - to N candidates, each fork inherits the root's logits snapshot, - ensuring all candidates start from the same probability distribution. - Generation happens in round-robin fashion, interleaving tokens across - all candidates. -`); - } - - ctx.dispose(); -} - -main().catch((err) => { - console.error('Error:', err.message); - console.error(err.stack); - process.exit(1); -}); diff --git a/examples/chat/chat.mjs b/examples/chat/chat.ts similarity index 72% rename from examples/chat/chat.mjs rename to examples/chat/chat.ts index 4ec2ea0..a4cd00b 100644 --- a/examples/chat/chat.mjs +++ b/examples/chat/chat.ts @@ -3,8 +3,8 @@ * Simple chat example using lloyal.node * * Usage: - * node chat.mjs /path/to/model.gguf - * node chat.mjs # uses default model path + * npx tsx chat.ts /path/to/model.gguf + * npx tsx chat.ts # uses default model path * * This example demonstrates: * - Branch API for token generation (produce/commit two-phase) @@ -15,23 +15,22 @@ import * as readline from "node:readline"; import * as path from "node:path"; -import { fileURLToPath } from "node:url"; -import { createContext, Branch } from "../../lib/index.js"; +import { createContext, Branch } from "../../dist/index.js"; +import type { SessionContext, FormattedChatResult } from "../../dist/index.js"; -const __dirname = path.dirname(fileURLToPath(import.meta.url)); const DEFAULT_MODEL = path.resolve( __dirname, "../../models/Phi-3.5-mini-instruct-Q4_K_M.gguf", ); -async function main() { +async function main(): Promise<void> { const modelPath = 
process.argv[2] || DEFAULT_MODEL; console.log(`Loading model: ${modelPath}`); console.log("This may take a moment...\n"); const nCtx = parseInt(process.env.LLAMA_CTX_SIZE || '2048', 10); - const ctx = await createContext({ + const ctx: SessionContext = await createContext({ modelPath, nCtx, threads: 4, @@ -40,19 +39,19 @@ async function main() { console.log("Model loaded! Type your message and press Enter."); console.log("Commands: /clear to reset, /quit to exit\n"); - const messages = []; - let branch = null; - let fmt = null; - const sep = ctx.getTurnSeparator(); + const messages: Array<{role: string; content: string; reasoning_content?: string}> = []; + let branch: InstanceType<typeof Branch> | null = null; + let fmt: FormattedChatResult | null = null; + const sep: number[] = ctx.getTurnSeparator(); const rl = readline.createInterface({ input: process.stdin, output: process.stdout, }); - const askUser = () => rl.question("> ", handleInput); + const askUser = (): void => { rl.question("> ", handleInput); }; - async function handleInput(input) { + async function handleInput(input: string): Promise<void> { const trimmed = input.trim(); if (trimmed === "/quit" || trimmed === "/exit") { @@ -111,13 +110,13 @@ async function main() { console.log("\n"); // Parse output: separates reasoning from content for thinking models - const parsed = ctx.parseChatOutput(rawOutput, fmt.format, { - reasoningFormat: fmt.reasoningFormat, - thinkingForcedOpen: fmt.thinkingForcedOpen, - parser: fmt.parser, + const parsed = ctx.parseChatOutput(rawOutput, fmt!.format, { + reasoningFormat: fmt!.reasoningFormat, + thinkingForcedOpen: fmt!.thinkingForcedOpen, + parser: fmt!.parser, }); - const msg = { role: "assistant", content: parsed.content }; + const msg: {role: string; content: string; reasoning_content?: string} = { role: "assistant", content: parsed.content }; if (parsed.reasoningContent) { msg.reasoning_content = parsed.reasoningContent; } @@ -129,7 +128,7 @@ async function main() { askUser(); } 
-main().catch((err) => { - console.error("Error:", err.message); +main().catch((err: unknown) => { + console.error("Error:", (err as Error).message); process.exit(1); }); diff --git a/examples/deep-research/agreement.ts b/examples/deep-research/agreement.ts new file mode 100644 index 0000000..58380e1 --- /dev/null +++ b/examples/deep-research/agreement.ts @@ -0,0 +1,142 @@ +/** + * Per-section agreement analysis via bigram Jaccard similarity. + * + * Pure string math — no model calls. Used by the verify phase to quantify + * where N diverge attempts agree (confident) vs disagree (hallucination risk). + */ + +export interface SectionAgreement { + label: string; // section header or "¶1", "¶2", etc. + score: number; // 0–1 average pairwise bigram Jaccard +} + +export interface AgreementResult { + overall: number; // mean of section scores + sections: SectionAgreement[]; // per-section breakdown +} + +// ── Internals ───────────────────────────────────────────────────── + +interface Section { + key: string; // normalized header for matching, or positional index + label: string; // display label + body: string; // section text +} + +const HEADER_RE = /^#{1,4}\s+/m; + +function normalizeKey(header: string): string { + return header.toLowerCase().replace(/[^\w\s]/g, '').trim(); +} + +function extractSections(text: string): Section[] { + const hasHeaders = HEADER_RE.test(text); + + if (hasHeaders) { + const parts = text.split(/^(#{1,4}\s+.+)$/m).filter(Boolean); + const sections: Section[] = []; + for (let i = 0; i < parts.length; i++) { + const match = parts[i].match(/^#{1,4}\s+(.+)$/); + if (match) { + const header = match[1].trim(); + const body = (parts[i + 1] ?? '').trim(); + sections.push({ key: normalizeKey(header), label: header, body }); + i++; // skip body part + } + } + return sections.length ? 
sections : paragraphSections(text); + } + + return paragraphSections(text); +} + +function paragraphSections(text: string): Section[] { + return text.split(/\n{2,}/) + .map(p => p.trim()) + .filter(Boolean) + .map((body, i) => ({ key: String(i), label: `¶${i + 1}`, body })); +} + +function wordBigrams(text: string): Set<string> { + const words = text.split(/\s+/).filter(Boolean); + const bigrams = new Set<string>(); + for (let i = 0; i < words.length - 1; i++) { + bigrams.add(`${words[i]} ${words[i + 1]}`); + } + return bigrams; +} + +function jaccard(a: Set<string>, b: Set<string>): number { + if (a.size === 0 && b.size === 0) return 1; + let intersection = 0; + const [smaller, larger] = a.size <= b.size ? [a, b] : [b, a]; + for (const x of smaller) if (larger.has(x)) intersection++; + const union = a.size + b.size - intersection; + return union === 0 ? 1 : intersection / union; +} + +function averagePairwiseJaccard(texts: string[]): number { + if (texts.length < 2) return 1; + const bigramSets = texts.map(wordBigrams); + let sum = 0; + let pairs = 0; + for (let i = 0; i < bigramSets.length; i++) { + for (let j = i + 1; j < bigramSets.length; j++) { + sum += jaccard(bigramSets[i], bigramSets[j]); + pairs++; + } + } + return sum / pairs; +} + +// ── Public API ──────────────────────────────────────────────────── + +export function computeAgreement(outputs: string[]): AgreementResult { + if (outputs.length < 2) return { overall: 1, sections: [] }; + + const allSections = outputs.map(extractSections); + const hasHeaders = allSections.some(ss => ss.length > 0 && ss[0].key !== '0'); + + if (hasHeaders) { + // Collect all unique section keys across attempts + const keySet = new Map<string, string>(); // key → label (first seen) + for (const ss of allSections) { + for (const s of ss) { + if (!keySet.has(s.key)) keySet.set(s.key, s.label); + } + } + + const sections: SectionAgreement[] = [...keySet.entries()].map(([key, label]) => { + const bodies = allSections + .map(ss => ss.find(s => s.key === key)?.body) + 
.filter((b): b is string => b != null && b.length > 0); + // Sections present in only one attempt get score 0 + const score = bodies.length < 2 ? 0 : averagePairwiseJaccard(bodies); + return { label, score }; + }); + + const overall = sections.length + ? sections.reduce((s, x) => s + x.score, 0) / sections.length + : 0; + + return { overall, sections }; + } + + // Positional matching for headerless content + const maxSections = Math.max(...allSections.map(ss => ss.length)); + const sections: SectionAgreement[] = []; + + for (let i = 0; i < maxSections; i++) { + const bodies = allSections + .map(ss => ss[i]?.body) + .filter((b): b is string => b != null && b.length > 0); + const score = bodies.length < 2 ? 0 : averagePairwiseJaccard(bodies); + sections.push({ label: `¶${i + 1}`, score }); + } + + const overall = sections.length + ? sections.reduce((s, x) => s + x.score, 0) / sections.length + : 0; + + return { overall, sections }; +} diff --git a/examples/deep-research/harness.ts b/examples/deep-research/harness.ts new file mode 100644 index 0000000..8b6f687 --- /dev/null +++ b/examples/deep-research/harness.ts @@ -0,0 +1,415 @@ +import * as fs from 'node:fs'; +import * as path from 'node:path'; +import { call, scoped } from 'effection'; +import type { Operation, Channel } from 'effection'; +import { Branch, Session } from '../../dist'; +import type { SessionContext } from '../../dist'; +import { + Ctx, + generate, useAgentPool, runAgents, diverge, withSharedRoot, +} from '../../dist/agents'; +import type { Tool, AgentPoolResult, DivergeResult } from '../../dist/agents'; +import type { WorkflowEvent, OpTiming } from './tui'; +import { computeAgreement } from './agreement'; +import { reportTool } from './tools'; + +/** Load a task prompt file. Convention: system prompt above `---`, user content below. 
*/ +function loadTask(name: string): { system: string; user: string } { + const raw = fs.readFileSync(path.resolve(__dirname, `tasks/${name}.md`), 'utf8').trim(); + const sep = raw.indexOf('\n---\n'); + if (sep === -1) return { system: raw, user: '' }; + return { system: raw.slice(0, sep).trim(), user: raw.slice(sep + 5).trim() }; +} + +const PLAN = loadTask('plan'); +const RESEARCH = loadTask('research'); +const VERIFY = loadTask('verify'); +const EVAL = loadTask('eval'); +const REPORT = loadTask('report'); + +// ── Options ─────────────────────────────────────────────────────── + +export interface WorkflowOpts { + session: Session; + toolMap: Map<string, Tool>; + toolsJson: string; + agentCount: number; + verifyCount: number; + maxTurns: number; + trace: boolean; + events: Channel<WorkflowEvent, void>; +} + +// ── Agent task builder ──────────────────────────────────────────── + +function agentTasks(questions: string[], toolsJson: string, parent: Branch, seed?: number) { + return questions.map((q, i) => ({ + systemPrompt: RESEARCH.system, + content: q, + tools: toolsJson, + parent, + seed: seed != null ? 
seed + i : undefined, + })); +} + +const reportOnlyTools = JSON.stringify([reportTool.schema]); + +function* reportPass( + pool: AgentPoolResult, + opts: WorkflowOpts, +): Operation<void> { + const hardCut = pool.agents.filter(a => !a.findings && !a.branch.disposed); + if (hardCut.length === 0) return; + + // Free KV from successful agents before spawning reporters + for (const a of pool.agents) { + if (a.findings && !a.branch.disposed) a.branch.pruneSync(); + } + + const reporters = yield* runAgents({ + tasks: hardCut.map(a => ({ + systemPrompt: REPORT.system, + content: REPORT.user, + tools: reportOnlyTools, + parent: a.branch, + })), + tools: new Map([['report', reportTool]]), + terminalTool: 'report', + trace: opts.trace, + pressure: { softLimit: 200, hardLimit: 64 }, + }); + + hardCut.forEach((a, i) => { + if (reporters.agents[i]?.findings) a.findings = reporters.agents[i].findings; + }); +} + +// ── Operations ──────────────────────────────────────────────────── + +function* plan(query: string, opts: WorkflowOpts): Operation<{ questions: string[]; tokenCount: number; timeMs: number }> { + const ctx: SessionContext = yield* Ctx.expect(); + const t = performance.now(); + + const schema = { + type: 'object', + properties: { + questions: { + type: 'array', + items: { type: 'string' }, + minItems: 2, + maxItems: opts.agentCount, + }, + }, + required: ['questions'], + }; + const grammar: string = yield* call(() => ctx.jsonSchemaToGrammar(JSON.stringify(schema))); + + const userContent = PLAN.user + .replace('{{count}}', String(opts.agentCount)) + .replace('{{query}}', () => query); + + const messages = [ + { role: 'system', content: PLAN.system }, + { role: 'user', content: userContent }, + ]; + const { prompt }: { prompt: string } = yield* call(() => ctx.formatChat(JSON.stringify(messages))); + + let output: string; + let tokenCount: number; + + const parent = opts.session.trunk ?? 
undefined; + if (parent) { + const lead: Branch = yield* call(() => parent.fork()); + try { + lead.setGrammar(grammar); + const sep = ctx.getTurnSeparator(); + const delta: number[] = yield* call(() => ctx.tokenize(prompt, false)); + yield* call(() => lead.prefill([...sep, ...delta])); + + ({ output, tokenCount } = yield* call(async () => { + let o = ''; + let tc = 0; + for await (const { text } of lead) { o += text; tc++; } + return { output: o, tokenCount: tc }; + })); + } finally { + yield* call(() => lead.prune()); + } + } else { + const result = yield* generate({ prompt, grammar, params: { temperature: 0.3 } }); + output = result.output; + tokenCount = result.tokenCount; + } + + let questions: string[]; + try { + questions = JSON.parse(output).questions.slice(0, opts.agentCount); + if (!questions.length) throw new Error('empty'); + } catch { + questions = Array.from({ length: opts.agentCount }, (_, i) => `${query} (aspect ${i + 1})`); + } + + const timeMs = performance.now() - t; + yield* opts.events.send({ type: 'plan', questions, tokenCount, timeMs }); + return { questions, tokenCount, timeMs }; +} + +function* research( + questions: string[], + opts: WorkflowOpts, +): Operation<{ pool: AgentPoolResult; sharedPrefixLength: number; timeMs: number }> { + yield* opts.events.send({ type: 'research:start', agentCount: questions.length }); + const t = performance.now(); + + const { result: pool, prefixLen: sharedPrefixLength } = yield* withSharedRoot( + { systemPrompt: RESEARCH.system, tools: opts.toolsJson }, + function*(root, prefixLen) { + const pool = yield* useAgentPool({ + tasks: agentTasks(questions, opts.toolsJson, root), + tools: opts.toolMap, maxTurns: opts.maxTurns, trace: opts.trace, + terminalTool: 'report', + pressure: { softLimit: 2048 }, + }); + + yield* reportPass(pool, opts); + return { result: pool, prefixLen }; + }, + ); + + const timeMs = performance.now() - t; + yield* opts.events.send({ type: 'research:done', pool, timeMs }); + return { 
pool, sharedPrefixLength, timeMs }; +} + +function* warmResearch( + questions: string[], + opts: WorkflowOpts, +): Operation<{ pool: AgentPoolResult; timeMs: number }> { + yield* opts.events.send({ type: 'research:start', agentCount: questions.length }); + const t = performance.now(); + + const pool = yield* scoped(function*() { + const pool = yield* useAgentPool({ + tasks: agentTasks(questions, opts.toolsJson, opts.session.trunk!, Date.now()), + tools: opts.toolMap, maxTurns: opts.maxTurns, trace: opts.trace, + terminalTool: 'report', + pressure: { softLimit: 1024 }, + }); + + yield* reportPass(pool, opts); + return pool; + }); + + const timeMs = performance.now() - t; + yield* opts.events.send({ type: 'research:done', pool, timeMs }); + return { pool, timeMs }; +} + +function* verify( + pool: AgentPoolResult, + questions: string[], + query: string, + opts: WorkflowOpts, +): Operation<{ result: DivergeResult; timeMs: number }> { + const ctx: SessionContext = yield* Ctx.expect(); + const findingsText = pool.agents + .map((a, i) => `Q: ${questions[i]}\nA: ${(a.findings || '').trim()}`) + .join('\n\n'); + + const userContent = VERIFY.user + .replace('{{findings}}', findingsText) + .replace('{{query}}', query); + + const messages = [ + { role: 'system', content: VERIFY.system }, + { role: 'user', content: userContent }, + ]; + const { prompt }: { prompt: string } = yield* call(() => ctx.formatChat(JSON.stringify(messages))); + + yield* opts.events.send({ type: 'verify:start', count: opts.verifyCount }); + const t = performance.now(); + const result = yield* diverge({ + prompt, + attempts: opts.verifyCount, + params: { temperature: 0.7 }, + }); + const timeMs = performance.now() - t; + const agreement = computeAgreement(result.attempts.map(a => a.output)); + yield* opts.events.send({ type: 'verify:agreement', result: agreement }); + yield* opts.events.send({ type: 'verify:done', result, timeMs }); + return { result, timeMs }; +} + +function* evaluate( + verifyResult: 
DivergeResult, + opts: WorkflowOpts, +): Operation<{ converged: boolean | null; tokenCount: number; timeMs: number }> { + const ctx: SessionContext = yield* Ctx.expect(); + + const responsesText = verifyResult.attempts + .map((a, i) => `Response ${i + 1}: ${a.output.trim()}`) + .join('\n\n'); + + const userContent = EVAL.user.replace('{{responses}}', responsesText); + + const messages = [ + { role: 'system', content: EVAL.system }, + { role: 'user', content: userContent }, + ]; + + const evalSchema = { + type: 'object', + properties: { converged: { type: 'boolean' } }, + required: ['converged'], + }; + const grammar: string = yield* call(() => ctx.jsonSchemaToGrammar(JSON.stringify(evalSchema))); + const { prompt }: { prompt: string } = yield* call(() => ctx.formatChat(JSON.stringify(messages))); + + const t = performance.now(); + const result = yield* generate({ + prompt, + grammar, + params: { temperature: 0 }, + parse: (output: string) => { + try { return JSON.parse(output).converged as boolean; } + catch { return null; } + }, + }); + const timeMs = performance.now() - t; + yield* opts.events.send({ type: 'eval:done', converged: result.parsed as boolean | null, tokenCount: result.tokenCount, timeMs }); + return { converged: result.parsed as boolean | null, tokenCount: result.tokenCount, timeMs }; +} + +function* answer(verifyResult: DivergeResult, opts: WorkflowOpts): Operation { + yield* opts.events.send({ type: 'answer', text: verifyResult.bestOutput }); +} + +function* promote(verifyResult: DivergeResult, opts: WorkflowOpts): Operation { + yield* call(() => opts.session.promote(verifyResult.best)); +} + +function* respond( + pool: AgentPoolResult, + query: string, + opts: WorkflowOpts, +): Operation<{ tokenCount: number; timeMs: number }> { + const agentFindings = pool.agents + .map((a: { findings: string | null }, i: number) => + a.findings ? 
`[Agent ${i}] ${a.findings.trim()}` : null) + .filter(Boolean) + .join('\n\n'); + + yield* call(() => opts.session.prefillUser(agentFindings + ? `Research findings:\n${agentFindings}\n\nUser question: ${query}\n\nAnswer based on the research findings above.` + : query)); + + yield* opts.events.send({ type: 'response:start' }); + const t = performance.now(); + let tokenCount = 0; + const trunk = opts.session.trunk!; + for (;;) { + const { token, text, isStop } = trunk.produceSync(); + if (isStop) break; + yield* call(() => trunk.commit(token)); + tokenCount++; + yield* opts.events.send({ type: 'response:text', text }); + } + const timeMs = performance.now() - t; + yield* opts.events.send({ type: 'response:done' }); + return { tokenCount, timeMs }; +} + +function* summarize( + timings: OpTiming[], + opts: WorkflowOpts, + extra?: { kvLine?: string }, +): Operation { + const ctx: SessionContext = yield* Ctx.expect(); + const p = ctx._storeKvPressure(); + const ctxTotal = p.nCtx || 1; + yield* opts.events.send({ + type: 'stats', timings, + kvLine: extra?.kvLine, + ctxPct: Math.round(100 * p.cellsUsed / ctxTotal), + ctxPos: p.cellsUsed, + ctxTotal, + }); +} + +// ── Workflow compositions ──────────────────────────────────────── + +function* coldQuery(query: string, opts: WorkflowOpts): Operation { + const t0 = performance.now(); + + const p = yield* plan(query, opts); + const r = yield* research(p.questions, opts); + const v = yield* verify(r.pool, p.questions, query, opts); + const e = yield* evaluate(v.result, opts); + yield* answer(v.result, opts); + yield* promote(v.result, opts); + + const timings: OpTiming[] = [ + { label: 'Plan', tokens: p.tokenCount, detail: '', timeMs: p.timeMs }, + { + label: 'Research', tokens: r.pool.totalTokens, + detail: `(${r.pool.agents.map(a => a.tokenCount).join(' + ')}) ${r.pool.totalToolCalls} tools`, + timeMs: r.timeMs, + }, + { + label: 'Verify', tokens: v.result.totalTokens, + detail: `(${v.result.attempts.map(a => 
a.tokenCount).join(' + ')})`, + timeMs: v.timeMs, + }, + { label: 'Eval', tokens: e.tokenCount, detail: `converged: ${e.converged ? 'yes' : 'no'}`, timeMs: e.timeMs }, + ]; + + const kvSaved = r.sharedPrefixLength * (p.questions.length - 1) + + v.result.prefixLength * (v.result.attempts.length - 1); + const kvLine = `KV shared ${r.sharedPrefixLength} \u00d7 ${p.questions.length - 1} + ${v.result.prefixLength} \u00d7 ${v.result.attempts.length - 1} = ${kvSaved.toLocaleString()} tok saved`; + + yield* summarize(timings, opts, { kvLine }); + + yield* opts.events.send({ + type: 'complete', + data: { + planTokens: p.tokenCount, + agentTokens: r.pool.totalTokens, researchSteps: r.pool.steps, + agentPpl: r.pool.agents.map(a => a.ppl), + verifyTokens: v.result.totalTokens, verifySteps: v.result.steps, + evalTokens: e.tokenCount, converged: e.converged, + totalToolCalls: r.pool.totalToolCalls, + prefixTokens: v.result.prefixLength, + sharedPrefixTokens: r.sharedPrefixLength, + agentCount: p.questions.length, attemptCount: v.result.attempts.length, + wallTimeMs: Math.round(performance.now() - t0), + planMs: Math.round(p.timeMs), researchMs: Math.round(r.timeMs), + verifyMs: Math.round(v.timeMs), evalMs: Math.round(e.timeMs), + ...r.pool.counters, + }, + }); +} + +function* warmQuery(query: string, opts: WorkflowOpts): Operation { + const p = yield* plan(query, opts); + const r = yield* warmResearch(p.questions, opts); + const resp = yield* respond(r.pool, query, opts); + + const timings: OpTiming[] = [ + { label: 'Plan', tokens: p.tokenCount, detail: '', timeMs: p.timeMs }, + { + label: 'Research', tokens: r.pool.totalTokens, + detail: `(${r.pool.agents.map(a => a.tokenCount).join(' + ')}) ${r.pool.totalToolCalls} tools`, + timeMs: r.timeMs, + }, + { label: 'Response', tokens: resp.tokenCount, detail: '', timeMs: resp.timeMs }, + ]; + + yield* summarize(timings, opts); +} + +// ── Entry point ────────────────────────────────────────────────── + +export function* 
handleQuery(query: string, opts: WorkflowOpts): Operation { + yield* opts.events.send({ type: 'query', query, warm: !!opts.session.trunk }); + yield* (opts.session.trunk ? warmQuery : coldQuery)(query, opts); +} diff --git a/examples/deep-research/main.ts b/examples/deep-research/main.ts new file mode 100644 index 0000000..edeaf10 --- /dev/null +++ b/examples/deep-research/main.ts @@ -0,0 +1,219 @@ +#!/usr/bin/env node +/** + * Deep Research β€” CLI entry point + * + * Wiring only: setup, TUI subscriber, REPL. + * Orchestration lives in harness.ts. Presentation lives in tui.ts. + * + * Usage: + * npx tsx examples/deep-research/main.ts [model-path] --corpus [--query ] [options] + */ + +import * as fs from "node:fs"; +import * as path from "node:path"; +import * as readline from "node:readline"; +import { + main, + ensure, + createSignal, + spawn, + each, + call, + action, +} from "effection"; +import { createContext } from "../../dist"; +import type { SessionContext } from "../../dist"; +import { initAgents } from "../../dist/agents"; +import { c, log, setJsonlMode, setVerboseMode, fmtSize, createView } from "./tui"; +import type { WorkflowEvent } from "./tui"; +import { loadResources, chunkResources } from "./resources/files"; +import { createReranker } from "./reranker"; +import { createTools } from "./tools"; +import { handleQuery } from "./harness"; +import type { WorkflowOpts } from "./harness"; + +// ── CLI args ───────────────────────────────────────────────────── + +const DEFAULT_MODEL = path.resolve( + __dirname, + "../../models/Qwen3-4B-Instruct-2507-Q4_K_M.gguf", +); +const DEFAULT_RERANKER = path.resolve( + __dirname, + "../../models/qwen3-reranker-0.6b-q4_k_m.gguf", +); + +const args = process.argv.slice(2); +const jsonlMode = args.includes("--jsonl"); +const verbose = args.includes("--verbose"); +const trace = args.includes("--trace"); + +function argVal(flag: string): string | null { + const i = args.indexOf(flag); + return i !== -1 ? 
args[i + 1] : null; +} +const flagIndices = new Set( + ["--reranker", "--corpus", "--query"].flatMap((f) => { + const i = args.indexOf(f); + return i !== -1 ? [i, i + 1] : []; + }), +); + +const rerankModelPath = argVal("--reranker") || DEFAULT_RERANKER; +const corpusDir = argVal("--corpus"); +const initialQuery = argVal("--query"); +const modelPath = + args.find((a, i) => !a.startsWith("--") && !flagIndices.has(i)) || + DEFAULT_MODEL; + +if (!corpusDir) { + process.stdout.write( + `Usage: npx tsx examples/deep-research/main.ts [model-path] --corpus [--query ] [--reranker ]\nMissing: --corpus\n`, + ); + process.exit(1); +} + +if (jsonlMode) setJsonlMode(true); +if (verbose) setVerboseMode(true); +if (!verbose && !jsonlMode && !trace) { + try { + fs.closeSync(2); + fs.openSync(process.platform === "win32" ? "\\\\.\\NUL" : "/dev/null", "w"); + } catch { + /* non-fatal */ + } +} + +const AGENT_COUNT = 3; +const VERIFY_COUNT = 3; +const MAX_TOOL_TURNS = 20; + +// ── Main ───────────────────────────────────────────────────────── + +main(function* () { + const resources = loadResources(corpusDir!); + const chunks = chunkResources(resources); + + const modelName = path.basename(modelPath).replace(/-Q\w+\.gguf$/, ""); + const rerankName = path + .basename(rerankModelPath) + .replace(/-q\w+\.gguf$/i, ""); + + log(); + log( + `${c.bold} Deep Research${c.reset} ${c.dim}\u2014 Structured Concurrency Runtime${c.reset}`, + ); + log(); + log( + ` ${c.green}\u25cf${c.reset} Loading ${c.bold}${modelName}${c.reset} ${c.dim}(${fmtSize(fs.statSync(modelPath).size)}, KV: Q4_0)${c.reset}`, + ); + + const nCtx = parseInt(process.env.LLAMA_CTX_SIZE || "16384", 10); + const ctx: SessionContext = yield* call(() => + createContext({ + modelPath, + nCtx, + nSeqMax: Math.max(AGENT_COUNT, VERIFY_COUNT) * 2 + 1, + typeK: "q4_0", + typeV: "q4_0", + }), + ); + + log( + ` ${c.green}\u25cf${c.reset} Loading ${c.bold}${rerankName}${c.reset} ${c.dim}(${fmtSize(fs.statSync(rerankModelPath).size)}, 
reranker)${c.reset}`, + ); + + const reranker = yield* call(() => + createReranker(rerankModelPath, { nSeqMax: 8, nCtx: 4096 }), + ); + yield* ensure(() => { + reranker.dispose(); + }); + yield* call(() => reranker.tokenizeChunks(chunks)); + + const corpusIsFile = + resources.length === 1 && fs.statSync(corpusDir!).isFile(); + const corpusLabel = corpusIsFile + ? path.basename(corpusDir!) + : `${path.basename(corpusDir!)}/ \u2014 ${resources.length} files`; + log( + ` ${c.dim} Corpus: ${corpusLabel} \u2192 ${chunks.length} chunks${c.reset}`, + ); + + const { toolMap, toolsJson } = createTools({ resources, chunks, reranker }); + const { session, events } = yield* initAgents(ctx); + + // View subscriber β€” all presentation lives here + const view = createView({ + model: path.basename(modelPath), + reranker: path.basename(rerankModelPath), + agentCount: AGENT_COUNT, + verifyCount: VERIFY_COUNT, + chunkCount: chunks.length, + }); + yield* spawn(function* () { + yield* view.subscribe(events); + }); + + const harnessOpts: WorkflowOpts = { + session, + toolMap, + toolsJson, + events, + agentCount: AGENT_COUNT, + verifyCount: VERIFY_COUNT, + maxTurns: MAX_TOOL_TURNS, + trace, + }; + + // Initial query + if (initialQuery) { + yield* handleQuery(initialQuery, harnessOpts); + if (jsonlMode) return; + } + + // REPL β€” Signal bridges readline into Effection scope + log( + ` ${c.dim}${session.trunk ? 
"Ask a follow-up question" : "Enter your research question"} or /quit to exit${c.reset}`, + ); + log(); + + const inputSignal = createSignal(); + const rl = readline.createInterface({ + input: process.stdin, + output: process.stdout, + }); + rl.setPrompt(` ${c.dim}>${c.reset} `); + + yield* spawn(function* () { + yield* action((resolve) => { + rl.on("line", (line: string) => inputSignal.send(line.trim())); + rl.on("close", () => { + inputSignal.close(); + resolve(); + }); + return () => rl.close(); + }); + }); + + rl.prompt(); + for (const input of yield* each(inputSignal)) { + if (!input || input === "/quit") break; + try { + yield* handleQuery(input, harnessOpts); + } catch (err) { + log(` ${c.red}Error: ${(err as Error).message}${c.reset}`); + } + yield* each.next(); + try { + rl.prompt(); + } catch { + break; + } + } +}).catch((err: unknown) => { + process.stdout.write( + `Error: ${(err as Error).message}\n${(err as Error).stack}\n`, + ); + process.exit(1); +}); diff --git a/examples/deep-research/reranker.ts b/examples/deep-research/reranker.ts new file mode 100644 index 0000000..118e17e --- /dev/null +++ b/examples/deep-research/reranker.ts @@ -0,0 +1,59 @@ +import { Rerank } from "../../dist"; +import type { Chunk } from "./resources/types"; +import type { Reranker, ScoredResult } from "./tools/types"; + +export async function createReranker( + modelPath: string, + opts?: { nSeqMax?: number; nCtx?: number }, +): Promise { + const rerank = await Rerank.create({ modelPath, ...opts }); + + return { + score(query: string, chunks: Chunk[]): AsyncIterable { + const inner = rerank.score( + query, + chunks.map((c) => c.tokens), + 10, + ); + return { + [Symbol.asyncIterator](): AsyncIterator { + const it = inner[Symbol.asyncIterator](); + return { + async next(): Promise> { + const { value, done } = await it.next(); + if (done) + return { + value: undefined as unknown as ScoredResult, + done: true, + }; + return { + value: { + filled: value.filled, + total: 
value.total, + results: value.results.map((r) => ({ + file: chunks[r.index].resource, + heading: chunks[r.index].heading, + score: r.score, + startLine: chunks[r.index].startLine, + endLine: chunks[r.index].endLine, + })), + }, + done: false, + }; + }, + }; + }, + }; + }, + + async tokenizeChunks(chunks: Chunk[]): Promise { + for (const chunk of chunks) { + chunk.tokens = await rerank.tokenize(chunk.text); + } + }, + + dispose() { + rerank.dispose(); + }, + }; +} diff --git a/examples/deep-research/resources/files.ts b/examples/deep-research/resources/files.ts new file mode 100644 index 0000000..4004374 --- /dev/null +++ b/examples/deep-research/resources/files.ts @@ -0,0 +1,73 @@ +import * as fs from 'node:fs'; +import * as path from 'node:path'; +import { loadBinary } from '../../../dist'; +import type { Resource, Chunk } from './types'; + +interface Section { heading: string; level: number; startLine: number; endLine: number } +const { parseMarkdown } = loadBinary() as unknown as { parseMarkdown(text: string): Section[] }; + +export function loadResources(dir: string): Resource[] { + if (!fs.existsSync(dir)) { + process.stdout.write(`Error: corpus not found: ${dir}\n`); + process.exit(1); + } + const stat = fs.statSync(dir); + if (stat.isFile()) { + return [{ name: path.basename(dir), content: fs.readFileSync(dir, 'utf8') }]; + } + const files = fs.readdirSync(dir).filter((f) => f.endsWith('.md')); + if (!files.length) { + process.stdout.write(`Error: no .md files in: ${dir}\n`); + process.exit(1); + } + return files.map((f) => ({ + name: f, + content: fs.readFileSync(path.join(dir, f), 'utf8'), + })); +} + +/** Split plain text into chunks on blank-line paragraph boundaries */ +function chunkByParagraph(res: Resource): Chunk[] { + const lines = res.content.split('\n'); + const chunks: Chunk[] = []; + let start = 0; + for (let i = 0; i <= lines.length; i++) { + const blank = i === lines.length || !lines[i].trim(); + if (blank && i > start) { + const text = 
lines.slice(start, i).join('\n').trim(); + if (text) { + chunks.push({ + resource: res.name, + heading: text.slice(0, 60).replace(/\n/g, ' ') + (text.length > 60 ? '…' : ''), + text, tokens: [], + startLine: start + 1, + endLine: i, + }); + } + } + if (blank) start = i + 1; + } + return chunks; +} + +export function chunkResources(resources: Resource[]): Chunk[] { + const out: Chunk[] = []; + for (const res of resources) { + const sections = parseMarkdown(res.content); + // Single section covering the whole file = no headings found β†’ paragraph split + if (sections.length <= 1 && res.content.split('\n').length > 10) { + out.push(...chunkByParagraph(res)); + continue; + } + const lines = res.content.split('\n'); + for (const sec of sections) { + const text = lines.slice(sec.startLine - 1, sec.endLine).join('\n').trim(); + if (!text) continue; + out.push({ + resource: res.name, heading: sec.heading || res.name, text, tokens: [], + startLine: sec.startLine, endLine: sec.endLine, + }); + } + } + return out; +} diff --git a/examples/deep-research/resources/types.ts b/examples/deep-research/resources/types.ts new file mode 100644 index 0000000..17242b1 --- /dev/null +++ b/examples/deep-research/resources/types.ts @@ -0,0 +1,10 @@ +export interface Resource { name: string; content: string } + +export interface Chunk { + resource: string; + heading: string; + text: string; + tokens: number[]; + startLine: number; + endLine: number; +} diff --git a/examples/deep-research/tasks/eval.md b/examples/deep-research/tasks/eval.md new file mode 100644 index 0000000..d555374 --- /dev/null +++ b/examples/deep-research/tasks/eval.md @@ -0,0 +1,5 @@ +You are a consistency checker. Compare the responses and determine if they convey the same core meaning. Output JSON only. +--- +Do these responses agree on the key points? 
+ +{{responses}} diff --git a/examples/deep-research/tasks/plan.md b/examples/deep-research/tasks/plan.md new file mode 100644 index 0000000..05bba9a --- /dev/null +++ b/examples/deep-research/tasks/plan.md @@ -0,0 +1,3 @@ +You break research queries into sub-questions. Output JSON only. +--- +Break this into {{count}} independent sub-questions for parallel research: "{{query}}" diff --git a/examples/deep-research/tasks/report.md b/examples/deep-research/tasks/report.md new file mode 100644 index 0000000..189a41b --- /dev/null +++ b/examples/deep-research/tasks/report.md @@ -0,0 +1,3 @@ +You are a research reporter. Call the report tool with a concise summary (under 200 words) of the key findings from the research above. Focus on the most important discoveries and conclusions. +--- +Report your findings. diff --git a/examples/deep-research/tasks/research.md b/examples/deep-research/tasks/research.md new file mode 100644 index 0000000..60b25c2 --- /dev/null +++ b/examples/deep-research/tasks/research.md @@ -0,0 +1,12 @@ +You are a research assistant analyzing a knowledge base. Your tools: +- **grep**: regex pattern matching — use for precise, exhaustive retrieval +- **search**: semantic relevance ranking — use to discover related content +- **read_file**: read specific line ranges — use to verify and get context +- **report**: submit your final findings with evidence + +Process — follow every step in order: +1. Grep with short, simple patterns first. Use single keywords or two-word phrases — never combine multiple clauses with `.*`. Run multiple greps if needed. +2. Use search to discover content that grep may miss (different phrasing, synonyms). +3. Read every matching line with read_file to verify in context. Do not rely on grep/search summaries alone. +4. Grep again with a different pattern targeting what you have NOT yet found. This is a completeness check, not confirmation of existing results. +5.
Report with line numbers and direct quotes as evidence. State what you found and what you checked. diff --git a/examples/deep-research/tasks/verify.md b/examples/deep-research/tasks/verify.md new file mode 100644 index 0000000..0713358 --- /dev/null +++ b/examples/deep-research/tasks/verify.md @@ -0,0 +1,7 @@ +Synthesize the research findings into a coherent, concise summary. +--- +Research findings: + +{{findings}} + +Synthesize these into a brief summary answering: "{{query}}" diff --git a/examples/deep-research/tools/grep.ts b/examples/deep-research/tools/grep.ts new file mode 100644 index 0000000..bc3ae5f --- /dev/null +++ b/examples/deep-research/tools/grep.ts @@ -0,0 +1,67 @@ +import { Tool } from '../../../dist/agents'; +import type { JsonSchema } from '../../../dist/agents'; +import type { Resource } from '../resources/types'; + +export class GrepTool extends Tool<{ pattern: string; ignoreCase?: boolean }> { + readonly name = 'grep'; + readonly description = 'Search the entire corpus for a regex pattern. Returns every matching line with line numbers and total match count. Complements search() which ranks by relevance β€” grep scans exhaustively.'; + readonly parameters: JsonSchema = { + type: 'object', + properties: { + pattern: { type: 'string', description: 'Regex pattern (e.g. "\\bshor\\b" for whole-word, "hidden_secret" for literal)' }, + ignoreCase: { type: 'boolean', description: 'Case-insensitive matching (default: true)' }, + }, + required: ['pattern'], + }; + + private _resources: Resource[]; + + constructor(resources: Resource[]) { + super(); + this._resources = resources; + } + + async execute(args: { pattern: string; ignoreCase?: boolean }): Promise { + const pattern = args.pattern?.trim(); + if (!pattern) return { error: 'pattern must not be empty' }; + const flags = (args.ignoreCase === false) ? 
'g' : 'gi'; + let re: RegExp; + try { re = new RegExp(pattern, flags); } + catch { return { error: `Invalid regex: ${pattern}` }; } + + const matches: { file: string; line: number; text: string }[] = []; + let totalMatches = 0; + + for (const res of this._resources) { + const lines = res.content.split('\n'); + for (let i = 0; i < lines.length; i++) { + const hits = lines[i].match(re); + if (hits) { + totalMatches += hits.length; + const raw = lines[i].trim(); + let text: string; + if (raw.length <= 200) { + text = raw; + } else { + const idx = raw.search(re); + const start = Math.max(0, idx - 40); + const end = Math.min(raw.length, start + 200); + text = (start > 0 ? '\u2026' : '') + raw.slice(start, end) + (end < raw.length ? '\u2026' : ''); + } + matches.push({ file: res.name, line: i + 1, text }); + } + } + } + + if (totalMatches === 0) { + return { + totalMatches: 0, matchingLines: 0, matches: [], + note: 'Zero matches does NOT mean the topic is absent \u2014 only that this exact pattern was not found. 
Try search() for semantic matching or a broader/simpler regex.', + }; + } + + const limit = 50; + const truncated = matches.length > limit; + return { totalMatches, matchingLines: matches.length, truncated, matches: matches.slice(0, limit) }; + } +} diff --git a/examples/deep-research/tools/index.ts b/examples/deep-research/tools/index.ts new file mode 100644 index 0000000..2145f44 --- /dev/null +++ b/examples/deep-research/tools/index.ts @@ -0,0 +1,23 @@ +import { createToolkit } from '../../../dist/agents'; +import type { Toolkit } from '../../../dist/agents'; +import type { Resource, Chunk } from '../resources/types'; +import type { Reranker } from './types'; +import { SearchTool } from './search'; +import { ReadFileTool } from './read-file'; +import { GrepTool } from './grep'; +import { ReportTool } from './report'; + +export const reportTool = new ReportTool(); + +export function createTools(opts: { + resources: Resource[]; + chunks: Chunk[]; + reranker: Reranker; +}): Toolkit { + return createToolkit([ + new SearchTool(opts.chunks, opts.reranker), + new ReadFileTool(opts.resources), + new GrepTool(opts.resources), + reportTool, + ]); +} diff --git a/examples/deep-research/tools/read-file.ts b/examples/deep-research/tools/read-file.ts new file mode 100644 index 0000000..164a5c5 --- /dev/null +++ b/examples/deep-research/tools/read-file.ts @@ -0,0 +1,41 @@ +import { Tool } from '../../../dist/agents'; +import type { JsonSchema } from '../../../dist/agents'; +import type { Resource } from '../resources/types'; + +export class ReadFileTool extends Tool<{ filename: string; startLine?: number; endLine?: number }> { + readonly name = 'read_file'; + readonly description = 'Read content from a file at specific line ranges. 
Use startLine/endLine from search results.'; + readonly parameters: JsonSchema; + + private _resources: Resource[]; + + constructor(resources: Resource[]) { + super(); + this._resources = resources; + this.parameters = { + type: 'object', + properties: { + filename: { + type: 'string', + description: 'Filename from search results', + enum: resources.map(r => r.name), + }, + startLine: { type: 'number', description: 'Start line (1-indexed, from search results)' }, + endLine: { type: 'number', description: 'End line (1-indexed, from search results)' }, + }, + required: ['filename'], + }; + } + + async execute(args: { filename: string; startLine?: number; endLine?: number } & Record): Promise { + const filename = args.filename || (args.path as string) || ''; + const file = this._resources.find(r => r.name === filename); + if (!file) { + return { error: `File not found: ${filename}. Available: ${this._resources.map(r => r.name).join(', ')}` }; + } + const lines = file.content.split('\n'); + const s = Math.max(0, (args.startLine ?? 1) - 1); + const e = Math.min(lines.length, args.endLine ?? Math.min(100, lines.length)); + return { file: file.name, content: lines.slice(s, e).join('\n') }; + } +} diff --git a/examples/deep-research/tools/report.ts b/examples/deep-research/tools/report.ts new file mode 100644 index 0000000..97f061a --- /dev/null +++ b/examples/deep-research/tools/report.ts @@ -0,0 +1,14 @@ +import { Tool } from '../../../dist/agents'; +import type { JsonSchema } from '../../../dist/agents'; + +export class ReportTool extends Tool<{ findings: string }> { + readonly name = 'report'; + readonly description = 'Submit your final research findings. 
Call this when you have gathered enough information to answer the question.'; + readonly parameters: JsonSchema = { + type: 'object', + properties: { findings: { type: 'string', description: 'Your research findings and answer' } }, + required: ['findings'], + }; + + async execute(): Promise { return {}; } +} diff --git a/examples/deep-research/tools/search.ts b/examples/deep-research/tools/search.ts new file mode 100644 index 0000000..034bc55 --- /dev/null +++ b/examples/deep-research/tools/search.ts @@ -0,0 +1,34 @@ +import { Tool } from '../../../dist/agents'; +import type { JsonSchema, ToolContext } from '../../../dist/agents'; +import type { Chunk } from '../resources/types'; +import type { Reranker } from './types'; + +export class SearchTool extends Tool<{ query: string }> { + readonly name = 'search'; + readonly description = 'Search the knowledge base. Returns sections ranked by relevance with line ranges for read_file.'; + readonly parameters: JsonSchema = { + type: 'object', + properties: { query: { type: 'string', description: 'Search query' } }, + required: ['query'], + }; + + private _chunks: Chunk[]; + private _reranker: Reranker; + + constructor(chunks: Chunk[], reranker: Reranker) { + super(); + this._chunks = chunks; + this._reranker = reranker; + } + + async execute(args: { query: string }, context?: ToolContext): Promise { + const query = args.query?.trim(); + if (!query) return { error: 'query must not be empty' }; + let last; + for await (const { results, filled, total } of this._reranker.score(query, this._chunks)) { + if (context?.onProgress) context.onProgress({ filled, total }); + last = results; + } + return last; + } +} diff --git a/examples/deep-research/tools/types.ts b/examples/deep-research/tools/types.ts new file mode 100644 index 0000000..3f0012a --- /dev/null +++ b/examples/deep-research/tools/types.ts @@ -0,0 +1,21 @@ +import type { Chunk } from '../resources/types'; + +export interface ScoredChunk { + file: string; + heading: 
string; + score: number; + startLine: number; + endLine: number; +} + +export interface ScoredResult { + results: ScoredChunk[]; + filled: number; + total: number; +} + +export interface Reranker { + score(query: string, chunks: Chunk[]): AsyncIterable; + tokenizeChunks(chunks: Chunk[]): Promise; + dispose(): void; +} diff --git a/examples/deep-research/tui.ts b/examples/deep-research/tui.ts new file mode 100644 index 0000000..f720095 --- /dev/null +++ b/examples/deep-research/tui.ts @@ -0,0 +1,500 @@ +import * as fs from 'node:fs'; +import { each } from 'effection'; +import type { Channel, Operation } from 'effection'; +import type { AgentEvent, AgentPoolResult, DivergeResult } from '../../dist/agents'; +import type { AgreementResult } from './agreement'; + +// ── Event types ────────────────────────────────────────────────── + +export interface OpTiming { + label: string; + tokens: number; + detail: string; + timeMs: number; +} + +export type StepEvent = + | { type: 'query'; query: string; warm: boolean } + | { type: 'plan'; questions: string[]; tokenCount: number; timeMs: number } + | { type: 'research:start'; agentCount: number } + | { type: 'research:done'; pool: AgentPoolResult; timeMs: number } + | { type: 'verify:start'; count: number } + | { type: 'verify:done'; result: DivergeResult; timeMs: number } + | { type: 'verify:agreement'; result: AgreementResult } + | { type: 'eval:done'; converged: boolean | null; tokenCount: number; timeMs: number } + | { type: 'answer'; text: string } + | { type: 'response:start' } + | { type: 'response:text'; text: string } + | { type: 'response:done' } + | { type: 'stats'; timings: OpTiming[]; kvLine?: string; ctxPct: number; ctxPos: number; ctxTotal: number } + | { type: 'complete'; data: Record }; + +export type WorkflowEvent = AgentEvent | StepEvent; + +// ── Mode + color ───────────────────────────────────────────────── + +let _jsonlMode = false; +let _verboseMode = false; + +export function setJsonlMode(on: boolean): 
void { _jsonlMode = on; } +export function setVerboseMode(on: boolean): void { _verboseMode = on; } + +const isTTY = process.stdout.isTTY; + +export const c = isTTY ? { + bold: '\x1b[1m', dim: '\x1b[2m', reset: '\x1b[0m', + green: '\x1b[32m', cyan: '\x1b[36m', yellow: '\x1b[33m', red: '\x1b[31m', +} : { bold: '', dim: '', reset: '', green: '', cyan: '', yellow: '', red: '' }; + +// ── Primitives ─────────────────────────────────────────────────── + +let _statusText = ''; + +function status(text: string): void { + if (_jsonlMode || !isTTY) return; + _statusText = text; + process.stdout.write('\r\x1b[K' + text); +} + +function statusClear(): void { + if (!_statusText) return; + _statusText = ''; + process.stdout.write('\r\x1b[K'); +} + +export const log = (...a: unknown[]): void => { + if (_jsonlMode) return; + statusClear(); + console.log(...a); +}; + +function emit(event: string, data: Record): void { + if (_jsonlMode) console.log(JSON.stringify({ event, ...data })); +} + +export const fmtSize = (bytes: number): string => bytes > 1e9 + ? 
(bytes / 1e9).toFixed(1) + ' GB' + : (bytes / 1e6).toFixed(0) + ' MB'; + +const pad = (s: unknown, n: number): string => String(s).padStart(n); + +// ── View state + handler type ──────────────────────────────────── + +interface ViewState { + agentLabel: Map; + nextLabel: number; + agentText: Map; + agentStatus: Map; + agentParent: Map; // childId β†’ parentId (sub-agent tracking) + traceQuery: string; +} + +type ViewHandler = (ev: WorkflowEvent) => void; + +function isSubAgent(state: ViewState, agentId: number): boolean { + return state.agentParent.has(agentId); +} + +function parentLabel(state: ViewState, agentId: number): string { + return label(state, state.agentParent.get(agentId)!); +} + +function label(state: ViewState, agentId: number): string { + let l = state.agentLabel.get(agentId); + if (!l) { l = `A${state.nextLabel++}`; state.agentLabel.set(agentId, l); } + return l; +} + +function resetLabels(state: ViewState): void { + state.nextLabel = 0; + state.agentLabel.clear(); + state.agentStatus.clear(); + state.agentText.clear(); + state.agentParent.clear(); +} + +function renderStatus(state: ViewState): void { + const active = [...state.agentStatus.entries()] + .filter(([id, s]) => s.state !== 'done' && !isSubAgent(state, id)); + if (active.length === 0) return; + + const generating = active.filter(([, s]) => s.state === 'gen'); + if (generating.length === 1 && active.length === 1) { + const [id] = generating[0]; + const raw = (state.agentText.get(id) ?? '').replace(/\n/g, ' ').trimStart(); + const cols = process.stdout.columns || 80; + const maxLen = cols - 12; + const text = raw.length > maxLen ? raw.slice(raw.length - maxLen) : raw; + status(` ${c.dim}\u25c6${c.reset} ${c.yellow}${label(state, id)}${c.reset} ${text}`); + return; + } + + const parts = active.map(([id, s]) => { + const lbl = `${c.yellow}${label(state, id)}${c.reset}`; + if (s.state === 'gen') return `${lbl}: ${s.tokenCount} tok`; + const detail = s.detail ? 
` ${s.detail}` : ''; + return `${lbl}: ${c.cyan}${s.state}${c.reset}${detail}`; + }); + status(` ${c.dim}\u25c6${c.reset} ${parts.join(' ')}`); +} + +// ── View handlers ──────────────────────────────────────────────── + +function queryHandler(state: ViewState, opts: ViewOpts): ViewHandler { + return (ev) => { + if (ev.type !== 'query') return; + state.traceQuery = ev.query; + if (!ev.warm) { + emit('start', { + model: opts.model, reranker: opts.reranker, query: ev.query, + agentCount: opts.agentCount, verifyCount: opts.verifyCount, chunks: opts.chunkCount, + }); + log(); + log(` ${c.dim}Query${c.reset}`); + log(` ${c.bold}${ev.query}${c.reset}`); + } + }; +} + +function planHandler(): ViewHandler { + return (ev) => { + if (ev.type !== 'plan') return; + emit('plan', { questions: ev.questions, planTokens: ev.tokenCount }); + log(`\n ${c.green}\u25cf${c.reset} ${c.bold}Plan${c.reset} ${c.dim}${ev.tokenCount} tok \u00b7 ${(ev.timeMs / 1000).toFixed(1)}s${c.reset}`); + ev.questions.forEach((q: string, i: number) => log(` ${c.dim}${i + 1}.${c.reset} ${q}`)); + }; +} + +function agentHandler(state: ViewState): ViewHandler { + return (ev) => { + switch (ev.type) { + case 'agent:spawn': { + // If parent is a known labeled agent, this is a sub-agent + if (state.agentLabel.has(ev.parentAgentId)) { + state.agentParent.set(ev.agentId, ev.parentAgentId); + } + break; + } + case 'agent:produce': { + const sub = isSubAgent(state, ev.agentId); + state.agentText.set(ev.agentId, (state.agentText.get(ev.agentId) ?? 
'') + ev.text); + state.agentStatus.set(ev.agentId, { state: 'gen', tokenCount: ev.tokenCount, detail: '' }); + if (sub) break; // sub-agents: skip verbose/status output + if (_verboseMode) { + const lbl = label(state, ev.agentId); + if (ev.tokenCount === 1) { + statusClear(); + process.stdout.write(`\n ${c.dim}───${c.reset} ${c.yellow}${lbl}${c.reset} ${c.dim}tokens${c.reset} ${c.dim}───${c.reset}\n `); + } + process.stdout.write(ev.text); + } else { + renderStatus(state); + } + break; + } + case 'agent:tool_call': { + const sub = isSubAgent(state, ev.agentId); + if (_verboseMode && !sub) process.stdout.write('\n'); + state.agentText.delete(ev.agentId); + state.agentStatus.set(ev.agentId, { state: ev.tool, tokenCount: 0, detail: '' }); + emit('tool_call', { agentId: ev.agentId, toolName: ev.tool, arguments: ev.args }); + let toolArgs: Record; + try { toolArgs = JSON.parse(ev.args); } catch { toolArgs = {}; } + const argSummary = ev.tool === 'search' + ? `"${toolArgs.query || ''}"` + : ev.tool === 'grep' + ? `/${toolArgs.pattern || ''}/` + : ev.tool === 'report' ? '' + : `${toolArgs.filename}` + (toolArgs.startLine ? ` L${toolArgs.startLine}-${toolArgs.endLine}` : ''); + if (sub) { + const plbl = `${c.yellow}${parentLabel(state, ev.agentId)}${c.reset}`; + log(` ${c.dim}\u2502${c.reset} ${c.dim}\u2514${c.reset} ${plbl} ${c.cyan}${ev.tool}${c.reset}${argSummary ? `(${argSummary})` : ''}`); + } else { + log(` ${c.dim}\u251c${c.reset} ${c.yellow}${label(state, ev.agentId)}${c.reset} ${c.cyan}${ev.tool}${c.reset}${argSummary ? `(${argSummary})` : ''}`); + } + break; + } + case 'agent:tool_result': { + emit('tool_result', { + agentId: ev.agentId, toolName: ev.tool, + result: ev.result.length > 200 ? ev.result.slice(0, 200) + '...' 
: ev.result, + }); + let preview = ''; + if (ev.tool === 'read_file') { + try { + const firstLine = (JSON.parse(ev.result) as { content: string }).content.split('\n').find((l: string) => l.trim()); + if (firstLine) preview = ` \u00b7 ${firstLine.trim().slice(0, 60)}${firstLine.trim().length > 60 ? '\u2026' : ''}`; + } catch { /* non-fatal */ } + } else if (ev.tool === 'search') { + try { + const top = (JSON.parse(ev.result) as { heading: string }[])[0]; + if (top?.heading) preview = ` \u00b7 ${top.heading}`; + } catch { /* non-fatal */ } + } else if (ev.tool === 'grep') { + try { + const r = JSON.parse(ev.result) as { totalMatches: number; matchingLines: number }; + preview = ` \u00b7 ${r.totalMatches} matches in ${r.matchingLines} lines`; + } catch { /* non-fatal */ } + } + if (isSubAgent(state, ev.agentId)) { + const plbl = `${c.yellow}${parentLabel(state, ev.agentId)}${c.reset}`; + log(` ${c.dim}\u2502${c.reset} ${c.dim}\u2514${c.reset} ${plbl} ${c.dim}\u2190 ${ev.tool} ${ev.result.length}b${preview}${c.reset}`); + } else { + log(` ${c.dim}\u251c${c.reset} ${c.yellow}${label(state, ev.agentId)}${c.reset} ${c.dim}\u2190 ${ev.tool} ${ev.result.length}b${preview}${c.reset}`); + } + break; + } + case 'agent:tool_progress': { + state.agentStatus.set(ev.agentId, { state: ev.tool, tokenCount: 0, detail: `${ev.filled}/${ev.total}` }); + renderStatus(state); + break; + } + case 'agent:report': { + state.agentStatus.set(ev.agentId, { state: 'done', tokenCount: 0, detail: '' }); + const sub = isSubAgent(state, ev.agentId); + const cols = process.stdout.columns || 80; + const displayLabel = sub ? parentLabel(state, ev.agentId) : label(state, ev.agentId); + const lbl = `${c.yellow}${displayLabel}${c.reset}`; + const indent = sub ? ` ${c.dim}\u2502${c.reset} ` : ' '; + const prefix = `${indent}${c.dim}\u2502${c.reset} `; + const wrap = cols - (sub ? 
11 : 8); + + log(`${indent}${c.dim}\u2502${c.reset}`); + log(`${indent}${c.dim}\u251c\u2500\u2500${c.reset} ${lbl} ${c.bold}findings${c.reset}`); + + for (const para of ev.findings.split('\n')) { + if (!para.trim()) { log(prefix); continue; } + const words = para.split(/\s+/); + let line = ''; + for (const word of words) { + if (line && line.length + 1 + word.length > wrap) { + log(`${prefix}${c.dim}${line}${c.reset}`); + line = word; + } else { + line = line ? `${line} ${word}` : word; + } + } + if (line) log(`${prefix}${c.dim}${line}${c.reset}`); + } + log(`${indent}${c.dim}\u2502${c.reset}`); + break; + } + case 'agent:done': + if (_verboseMode && !isSubAgent(state, ev.agentId)) process.stdout.write('\n'); + break; + } + }; +} + +function researchSummaryHandler(state: ViewState): ViewHandler { + function flushTrace(pool: AgentPoolResult): void { + if (!pool.agents.some(a => a.trace?.length)) return; + const filename = `trace-${Date.now()}.json`; + fs.writeFileSync(filename, JSON.stringify({ + query: state.traceQuery, + timestamp: new Date().toISOString(), + agents: pool.agents.map(a => ({ + agentId: a.agentId, label: label(state, a.agentId), + ppl: a.ppl, samplingPpl: a.samplingPpl, + tokenCount: a.tokenCount, toolCallCount: a.toolCallCount, + findings: a.findings, trace: a.trace ?? [], + })), + }, null, 2)); + log(` ${c.dim}Trace written to ${filename}${c.reset}`); + } + + return (ev) => { + switch (ev.type) { + case 'research:start': { + log(`\n ${c.green}\u25cf${c.reset} ${c.bold}Research${c.reset} ${c.dim}${ev.agentCount} agents${c.reset}`); + resetLabels(state); + break; + } + case 'research:done': { + statusClear(); + ev.pool.agents.forEach((a, i) => { + const tree = i === ev.pool.agents.length - 1 ? '\u2514' : '\u251c'; + emit('agent_done', { + index: i, findings: (a.findings || '').slice(0, 500), + toolCalls: a.toolCallCount, tokenCount: a.tokenCount, + ppl: a.ppl, samplingPpl: a.samplingPpl, + }); + const raw = (state.agentText.get(a.agentId) ?? 
'').replace(/\n/g, ' ').trim(); + if (raw) log(` ${c.dim}\u251c${c.reset} ${c.yellow}${label(state, a.agentId)}${c.reset} ${c.dim}\u25b8 ${raw.slice(0, 120)}${raw.length > 120 ? '\u2026' : ''}${c.reset}`); + const pplStr = Number.isFinite(a.ppl) ? ` \u00b7 ppl ${a.ppl.toFixed(2)}` : ''; + log(` ${c.dim}${tree}${c.reset} ${c.yellow}${label(state, a.agentId)}${c.reset} ${c.green}done${c.reset} ${c.dim}${a.tokenCount} tok \u00b7 ${a.toolCallCount} tools${pplStr}${c.reset}`); + }); + log(` ${c.dim}${ev.pool.totalTokens} tok \u00b7 ${ev.pool.totalToolCalls} tools \u00b7 ${(ev.timeMs / 1000).toFixed(1)}s${c.reset}`); + flushTrace(ev.pool); + break; + } + } + }; +} + +function verifyHandler(): ViewHandler { + let pendingAgreement: AgreementResult | null = null; + + return (ev) => { + switch (ev.type) { + case 'verify:start': { + log(`\n ${c.green}\u25cf${c.reset} ${c.bold}Verify${c.reset} ${c.dim}${ev.count} attempts${c.reset}`); + pendingAgreement = null; + break; + } + case 'verify:agreement': { + pendingAgreement = ev.result; + emit('verify_agreement', { + overall: ev.result.overall, + sections: ev.result.sections.map(s => ({ label: s.label, score: s.score })), + }); + break; + } + case 'verify:done': { + ev.result.attempts.forEach((a, i) => { + const tree = i === ev.result.attempts.length - 1 + ? (pendingAgreement ? 
'\u251c' : '\u2514') + : '\u251c'; + emit('attempt_done', { index: i, output: a.output.trim().slice(0, 500), tokenCount: a.tokenCount, ppl: a.ppl }); + log(` ${c.dim}${tree} ${a.tokenCount} tok \u00b7 ppl ${a.ppl.toFixed(2)}${c.reset}`); + }); + if (pendingAgreement && pendingAgreement.sections.length > 0) { + const pct = Math.round(pendingAgreement.overall * 100); + log(` ${c.dim}\u251c${c.reset} Agreement: ${c.bold}${pct}%${c.reset}`); + const sorted = [...pendingAgreement.sections].sort((a, b) => b.score - a.score); + const show = sorted.slice(0, 5); + const maxLabelLen = Math.max(...show.map(s => s.label.length)); + show.forEach((s, i) => { + const tree = i === show.length - 1 && sorted.length <= 5 ? '\u2514' : '\u251c'; + const filled = Math.round(s.score * 10); + const bar = '\u2588'.repeat(filled) + '\u2591'.repeat(10 - filled); + const sPct = pad(Math.round(s.score * 100), 3); + const label = `"${s.label}"`.padEnd(maxLabelLen + 2); + log(` ${c.dim}${tree}${c.reset} ${c.dim}${label}${c.reset} ${sPct}% ${bar}`); + }); + if (sorted.length > 5) { + log(` ${c.dim}\u2514 \u2026 ${sorted.length - 5} more${c.reset}`); + } + } + log(` ${c.dim}${ev.result.totalTokens} tok \u00b7 ${(ev.timeMs / 1000).toFixed(1)}s${c.reset}`); + pendingAgreement = null; + break; + } + } + }; +} + +function evalHandler(): ViewHandler { + return (ev) => { + if (ev.type !== 'eval:done') return; + emit('convergence', { converged: ev.converged, evalTokens: ev.tokenCount }); + const verdict = ev.converged === true ? `${c.green}yes${c.reset}` + : ev.converged === false ? 
`${c.red}no${c.reset}` + : `${c.yellow}unknown${c.reset}`; + log(`\n ${c.green}\u25cf${c.reset} ${c.bold}Eval${c.reset} ${c.dim}${ev.tokenCount} tok \u00b7 ${(ev.timeMs / 1000).toFixed(1)}s${c.reset}`); + log(` Converged: ${verdict}`); + }; +} + +function answerHandler(): ViewHandler { + return (ev) => { + if (ev.type !== 'answer') return; + log(`\n ${c.dim}${'\u2500'.repeat(58)}${c.reset}\n`); + const prose = ev.text.trim() + .replace(/\*\*(.+?)\*\*/g, `${c.bold}$1${c.reset}`) + .split('\n').map((l: string) => ` ${l}`).join('\n'); + log(prose); + }; +} + +function responseHandler(): ViewHandler { + return (ev) => { + switch (ev.type) { + case 'response:start': + process.stdout.write(` ${c.dim}<${c.reset} `); + break; + case 'response:text': + process.stdout.write(ev.text); + break; + case 'response:done': + console.log('\n'); + break; + } + }; +} + +function statsHandler(): ViewHandler { + return (ev) => { + if (ev.type !== 'stats') return; + const { timings, kvLine, ctxPct, ctxPos, ctxTotal } = ev; + const totalTokens = timings.reduce((s, p) => s + p.tokens, 0); + const totalMs = timings.reduce((s, p) => s + p.timeMs, 0); + + log(`\n ${c.dim}${'\u2501'.repeat(58)}${c.reset}`); + for (const p of timings) { + const left = `${p.label.padEnd(10)} ${pad(p.tokens, 5)} tok`; + const detail = p.detail ? ` ${p.detail}` : ''; + const right = p.timeMs > 0 ? 
`${pad((p.timeMs / 1000).toFixed(1), 6)}s` : ''; + log(` ${c.dim}${left}${detail}${' '.repeat(Math.max(1, 58 - left.length - detail.length - right.length))}${right}${c.reset}`); + } + log(` ${c.dim}${'\u2501'.repeat(58)}${c.reset}`); + log(` ${c.bold}Total${c.reset} ${c.bold}${pad(totalTokens, 5)}${c.reset} tok ${c.bold}${pad((totalMs / 1000).toFixed(1), 6)}s${c.reset}`); + if (kvLine) log(` ${c.dim}${kvLine}${c.reset}`); + if (ctxPct != null && ctxPos != null && ctxTotal != null) { + const ctxStr = `ctx: ${ctxPct}% (${ctxPos.toLocaleString()}/${ctxTotal.toLocaleString()})`; + log(` ${c.dim}${'\u2501'.repeat(58)}${c.reset}`); + log(` ${c.dim}${' '.repeat(58 - ctxStr.length)}${ctxStr}${c.reset}`); + } + log(); + }; +} + +function completeHandler(): ViewHandler { + return (ev) => { + if (ev.type !== 'complete') return; + emit('complete', ev.data); + }; +} + +// ── createView β€” composable view factory ───────────────────────── + +export interface ViewOpts { + model: string; + reranker: string; + agentCount: number; + verifyCount: number; + chunkCount: number; +} + +export function createView(opts: ViewOpts) { + const state: ViewState = { + agentLabel: new Map(), + nextLabel: 0, + agentText: new Map(), + agentStatus: new Map(), + agentParent: new Map(), + traceQuery: '', + }; + + const handlers: ViewHandler[] = [ + queryHandler(state, opts), + planHandler(), + agentHandler(state), + researchSummaryHandler(state), + verifyHandler(), + evalHandler(), + answerHandler(), + responseHandler(), + statsHandler(), + completeHandler(), + ]; + + return { + *subscribe(events: Channel): Operation { + for (const ev of yield* each(events)) { + for (const h of handlers) h(ev); + yield* each.next(); + } + }, + }; +} diff --git a/examples/embed/embed.mjs b/examples/embed/embed.ts similarity index 84% rename from examples/embed/embed.mjs rename to examples/embed/embed.ts index ce79da0..bf4ac09 100644 --- a/examples/embed/embed.mjs +++ b/examples/embed/embed.ts @@ -3,9 +3,9 @@ * 
Embedding extraction example using lloyal.node * * Usage: - * node embed.mjs /path/to/embedding-model.gguf # Human-readable output - * node embed.mjs /path/to/embedding-model.gguf --jsonl # JSONL output for testing - * node embed.mjs # uses default nomic-embed model path + * npx tsx embed.ts /path/to/embedding-model.gguf # Human-readable output + * npx tsx embed.ts /path/to/embedding-model.gguf --jsonl # JSONL output for testing + * npx tsx embed.ts # uses default nomic-embed model path * * This example demonstrates: * - Creating an embedding context with pooling enabled @@ -14,10 +14,8 @@ */ import * as path from 'node:path'; -import { fileURLToPath } from 'node:url'; -import { createContext } from '../../lib/index.js'; - -const __dirname = path.dirname(fileURLToPath(import.meta.url)); +import { createContext, PoolingType } from '../../dist/index.js'; +import type { SessionContext } from '../../dist/index.js'; // Default to nomic-embed-text model in fixtures const DEFAULT_MODEL = path.resolve( @@ -31,24 +29,16 @@ const jsonlMode = args.includes('--jsonl'); const modelPath = args.find(a => !a.startsWith('--')) || DEFAULT_MODEL; /** Emit output - JSONL or human-readable */ -function emit(event, data) { +function emit(event: string, data: Record): void { if (jsonlMode) { console.log(JSON.stringify({ event, ...data })); } } -// Pooling types (matches llama.cpp LLAMA_POOLING_TYPE_*) -const PoolingType = { - NONE: 0, - MEAN: 1, - CLS: 2, - LAST: 3, -}; - /** * Compute cosine similarity between two vectors */ -function cosineSimilarity(a, b) { +function cosineSimilarity(a: Float32Array, b: Float32Array): number { if (a.length !== b.length) { throw new Error('Vectors must have same dimension'); } @@ -73,7 +63,7 @@ function cosineSimilarity(a, b) { /** * Get embedding for a text */ -async function getEmbedding(ctx, text) { +async function getEmbedding(ctx: SessionContext, text: string): Promise { // Tokenize the text const tokens = await ctx.tokenize(text); @@ -89,7 +79,7 
@@ async function getEmbedding(ctx, text) { return embedding; } -async function main() { +async function main(): Promise { if (!jsonlMode) { console.log('='.repeat(60)); console.log('lloyal.node Embedding Example'); @@ -134,7 +124,7 @@ async function main() { } // Get embeddings for all texts - const embeddings = []; + const embeddings: { text: string; embedding: Float32Array }[] = []; for (const text of texts) { const start = performance.now(); const embedding = await getEmbedding(ctx, text); @@ -167,7 +157,7 @@ async function main() { emit('similarity', { i, j, similarity: sim }); if (!jsonlMode) { - const bar = 'β–ˆ'.repeat(Math.round(sim * 20)); + const bar = '\u2588'.repeat(Math.round(sim * 20)); console.log(` [${i}] vs [${j}]: ${sim.toFixed(4)} ${bar}`); console.log(` "${texts[i].substring(0, 30)}..."`); console.log(` "${texts[j].substring(0, 30)}..."`); @@ -204,7 +194,7 @@ async function main() { if (!jsonlMode) { console.log('Results (ranked by similarity):\n'); ranked.forEach((result, i) => { - const bar = 'β–ˆ'.repeat(Math.round(result.similarity * 20)); + const bar = '\u2588'.repeat(Math.round(result.similarity * 20)); console.log(` ${i + 1}. 
${result.similarity.toFixed(4)} ${bar}`); console.log(` "${result.text}"`); console.log(); @@ -222,7 +212,7 @@ async function main() { } main().catch((err) => { - console.error('Error:', err.message); - console.error(err.stack); + console.error('Error:', (err as Error).message); + console.error((err as Error).stack); process.exit(1); }); diff --git a/examples/entropy/entropy.mjs b/examples/entropy/entropy.ts similarity index 84% rename from examples/entropy/entropy.mjs rename to examples/entropy/entropy.ts index 7618567..cdfd5fd 100644 --- a/examples/entropy/entropy.mjs +++ b/examples/entropy/entropy.ts @@ -14,15 +14,14 @@ * * * Usage: - * node entropy.mjs [model-path] # Human-readable output - * node entropy.mjs [model-path] --jsonl # JSONL output for testing + * npx tsx entropy.ts [model-path] # Human-readable output + * npx tsx entropy.ts [model-path] --jsonl # JSONL output for testing */ import * as path from 'node:path'; -import { fileURLToPath } from 'node:url'; -import { createContext, Branch } from '../../lib/index.js'; +import { createContext, Branch } from '../../dist/index.js'; +import type { SessionContext } from '../../dist/index.js'; -const __dirname = path.dirname(fileURLToPath(import.meta.url)); const DEFAULT_MODEL = path.resolve( __dirname, '../../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf' @@ -34,7 +33,7 @@ const jsonlMode = args.includes('--jsonl'); const modelPath = args.find(a => !a.startsWith('--')) || DEFAULT_MODEL; /** Emit output - JSONL or human-readable */ -function emit(event, data) { +function emit(event: string, data: Record): void { if (jsonlMode) { console.log(JSON.stringify({ event, ...data })); } @@ -48,7 +47,7 @@ const THETA = 1.5; // Scale factor /** * Calculate EDT temperature from entropy */ -function edtTemperature(entropy) { +function edtTemperature(entropy: number): number { const safeEntropy = Math.max(entropy, 0.1); return T0 * Math.pow(N, THETA / safeEntropy); } @@ -59,7 +58,7 @@ function edtTemperature(entropy) { * Uses 
Branch API with per-token setSamplerParams() for EDT adaptation. * Each token gets a temperature computed from the current logit entropy. */ -async function generate(ctx, prompt, strategy, strategyName, maxTokens = 50) { +async function generate(ctx: SessionContext, prompt: string, strategy: number | 'edt', strategyName: string, maxTokens: number = 50): Promise<{text: string; avgEntropy: number; avgTemp: number; tokenCount: number; temps: number[]; entropies: number[]}> { const messages = [{ role: 'user', content: prompt }]; const { prompt: formatted } = await ctx.formatChat(JSON.stringify(messages)); const tokens = await ctx.tokenize(formatted); @@ -68,9 +67,9 @@ async function generate(ctx, prompt, strategy, strategyName, maxTokens = 50) { const branch = Branch.create(ctx, 0, { temperature: baseTemp, topP: 0.9 }); await branch.prefill(tokens); - const output = []; - const temps = []; - const entropies = []; + const output: number[] = []; + const temps: number[] = []; + const entropies: number[] = []; for (let i = 0; i < maxTokens; i++) { const entropy = branch.modelEntropy('nats'); @@ -100,10 +99,12 @@ async function generate(ctx, prompt, strategy, strategyName, maxTokens = 50) { return { text, avgEntropy, avgTemp, tokenCount: output.length, temps, entropies }; } +type GenerateResult = Awaited>; + /** * Run comparison for a single prompt */ -async function compareStrategies(ctx, prompt, label) { +async function compareStrategies(ctx: SessionContext, prompt: string, label: string): Promise<{fixed: GenerateResult; edt: GenerateResult}> { if (!jsonlMode) { console.log(`\n${'='.repeat(70)}`); console.log(`${label}: "${prompt}"`); @@ -152,7 +153,7 @@ async function compareStrategies(ctx, prompt, label) { return { fixed, edt }; } -async function main() { +async function main(): Promise { if (!jsonlMode) { console.log('EDT vs Fixed Temperature Comparison'); console.log('Based on Zhang et al. 
2024: https://arxiv.org/abs/2403.14541\n'); @@ -209,7 +210,7 @@ don't add randomness - let it output what it knows. } main().catch((err) => { - console.error('Error:', err.message); - console.error(err.stack); + console.error('Error:', (err as Error).message); + console.error((err as Error).stack); process.exit(1); }); diff --git a/examples/grammar/README.md b/examples/grammar/README.md deleted file mode 100644 index 57ac23a..0000000 --- a/examples/grammar/README.md +++ /dev/null @@ -1,77 +0,0 @@ -# Grammar-Constrained Generation with Branch Forking - -Demonstrates grammar-constrained generation using the Branch API with automatic grammar cloning on fork. - -## Run It - -```bash -node grammar.mjs -``` - -## What You'll See - -``` -Generating until "city" field... - { - "name": "John Doe", - "age": 30, - "city": - -Forking into 3 branches at branch point... - - [NYC branch]: { "name": "John Doe", "age": 30, "city": "Seattle" } - [LA branch]: { "name": "John Doe", "age": 30, "city": "Chicago" } - [Chicago branch]: { "name": "John Doe", "age": 30, "city": "LA" } -``` - -## The Branch Fork Pattern - -Grammar state is integrated into the branch and cloned automatically on fork: - -```javascript -// Create root branch with grammar constraint -const grammar = await ctx.jsonSchemaToGrammar(JSON.stringify(schema)); -const root = Branch.create(ctx, 0, params, undefined, grammar); -await root.prefill(promptTokens); - -// Generate until branch point -for (let i = 0; i < 100; i++) { - const { token, text, isStop } = await root.produce(); - if (isStop) break; - await root.commit(token); - if (accumulated.includes('"city"')) break; -} - -// Fork β€” grammar state cloned automatically -for (const city of cities) { - const child = await root.fork(); - child.reseedSampler(seed++); - - for await (const { text } of child) { - // Each branch generates independently with its own grammar state - } - await child.prune(); -} -await root.prune(); -``` - -## Why Branch Fork Here? 
- -For grammar-constrained branching, fork handles everything atomically: -- **KV cache**: Shared prefix, divergent-only storage per branch -- **Grammar state**: Parser position cloned automatically -- **Sampler chain**: Penalties and PRNG cloned and reseeded - -No manual KV save/load or grammar cloning needed β€” `fork()` is a single operation. - -## Key APIs - -| Method | Description | -|--------|-------------| -| `Branch.create(ctx, pos, params, nBatch, grammar)` | Create branch with grammar constraint | -| `branch.fork()` | Clone branch: KV prefix + grammar + sampler | -| `branch.reseedSampler(seed)` | Diversify forked branch's PRNG | -| `branch.produce()` | Sample grammar-valid token | -| `branch.commit(token)` | Advance grammar + KV state | -| `branch.prune()` | Clean up branch resources | -| `ctx.jsonSchemaToGrammar(json)` | Convert JSON schema to GBNF grammar | diff --git a/examples/grammar/grammar.mjs b/examples/grammar/grammar.mjs deleted file mode 100644 index 6f96f2c..0000000 --- a/examples/grammar/grammar.mjs +++ /dev/null @@ -1,177 +0,0 @@ -#!/usr/bin/env node -/** - * Grammar-constrained generation with forkable state - * - * Uses Branch API for grammar-constrained generation with tree branching. - * Grammar state is automatically cloned on fork(), so each branch can - * diverge independently while maintaining valid JSON output. 
- * - * Usage: - * node grammar.mjs [model-path] # Human-readable output - * node grammar.mjs [model-path] --jsonl # JSONL output for testing - */ - -import * as path from 'node:path'; -import { fileURLToPath } from 'node:url'; -import { createContext, Branch } from '../../lib/index.js'; - -const __dirname = path.dirname(fileURLToPath(import.meta.url)); -const DEFAULT_MODEL = path.resolve( - __dirname, - '../../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf' -); - -// Parse args -const args = process.argv.slice(2); -const jsonlMode = args.includes('--jsonl'); -const modelPath = args.find(a => !a.startsWith('--')) || DEFAULT_MODEL; - -/** Emit output - JSONL or human-readable */ -function emit(event, data) { - if (jsonlMode) { - console.log(JSON.stringify({ event, ...data })); - } -} - -async function main() { - if (!jsonlMode) { - console.log(`Loading model: ${path.basename(modelPath)}`); - } - - emit('start', { model: path.basename(modelPath) }); - - const nCtx = parseInt(process.env.LLAMA_CTX_SIZE || '2048', 10); - const ctx = await createContext({ - modelPath, - nCtx, - nSeqMax: 4, - }); - - // JSON schema with enum for branching demo - const schema = { - type: 'object', - properties: { - name: { type: 'string' }, - age: { type: 'number' }, - city: { enum: ['NYC', 'LA', 'Chicago', 'Seattle'] }, - }, - required: ['name', 'age', 'city'], - }; - - if (!jsonlMode) { - console.log('\nJSON Schema:'); - console.log(JSON.stringify(schema, null, 2)); - } - - const grammar = await ctx.jsonSchemaToGrammar(JSON.stringify(schema)); - if (!jsonlMode) { - console.log('\nGBNF Grammar (first 200 chars):'); - console.log(grammar.slice(0, 200) + '...\n'); - } - - const prompt = 'Generate a person as JSON:\n'; - if (!jsonlMode) { - console.log(`Prompt: "${prompt}"`); - } - - const tokens = await ctx.tokenize(prompt); - - // Root branch with grammar constraint β€” grammar state cloned automatically on fork() - const root = Branch.create(ctx, 0, { temperature: 0.7, topP: 0.9 }, undefined, 
grammar); - await root.prefill(tokens); - - // ===== PHASE 1: Generate until we see "city" key ===== - if (!jsonlMode) { - console.log('\nGenerating until "city" field...'); - process.stdout.write(' '); - } - - let accumulated = ''; - - for (let i = 0; i < 100; i++) { - const { token, text, isStop } = await root.produce(); - if (isStop) break; - - accumulated += text; - if (!jsonlMode) { - process.stdout.write(text); - } - emit('token', { phase: 'prefix', token, text }); - - await root.commit(token); - - // Stop when we see "city": - we want to branch here - if (accumulated.includes('"city"')) { - break; - } - } - if (!jsonlMode) { - console.log('\n'); - } - - // ===== PHASE 2: Fork and complete with different branches ===== - const cities = ['NYC', 'LA', 'Chicago']; - if (!jsonlMode) { - console.log(`Forking into ${cities.length} branches at branch point...\n`); - } - - emit('branch_point', { prefix: accumulated, position: root.position }); - - const results = []; - for (const city of cities) { - const child = await root.fork(); - child.reseedSampler(results.length + 42); - - let branchText = ''; - for (let i = 0; i < 30; i++) { - const { token, text, isStop } = await child.produce(); - if (isStop) break; - - branchText += text; - emit('token', { phase: 'branch', city, token, text }); - - await child.commit(token); - } - - const fullOutput = accumulated + branchText; - results.push({ city, output: fullOutput }); - - if (!jsonlMode) { - console.log(` [${city} branch]: ${fullOutput}`); - } - emit('branch_complete', { city, output: fullOutput }); - - await child.prune(); - } - - await root.prune(); - - // Validate JSON outputs - let validJsonCount = 0; - for (const b of results) { - try { - JSON.parse(b.output); - validJsonCount++; - } catch { - // Invalid JSON - } - } - - emit('complete', { - branchCount: results.length, - validJsonCount, - branches: results.map(b => ({ city: b.city, output: b.output })), - }); - - ctx.dispose(); - - if (!jsonlMode) { - 
console.log('\nDone.'); - } -} - -main().catch((err) => { - console.error('Error:', err.message); - console.error(err.stack); - process.exit(1); -}); diff --git a/examples/speculative/README.md b/examples/speculative/README.md deleted file mode 100644 index 7433c77..0000000 --- a/examples/speculative/README.md +++ /dev/null @@ -1,117 +0,0 @@ -# Speculative Decoding with Branch API - -Demonstrates speculative decoding using the Branch primitive: fork a draft, verify, accept/reject, sample bonus token. - -## Run It - -```bash -node speculative.mjs -``` - -## What You'll See - -``` -Prompt: "The quick brown fox" - -Generating 30 tokens with speculative decoding... - -The quick brown fox jumps over the lazy dog. The dog... - -================================================== -Statistics -================================================== - Iterations: 13 - Tokens drafted: 48 - Tokens accepted: 6 - Accept rate: 12.5% - Output tokens: 30 -``` - -## How It Works - -| Phase | What Happens | -|-------|--------------| -| **1. MAIN** | Create main branch tracking committed state | -| **2. FORK** | Fork draft branch (shares KV prefix with main) | -| **3. DRAFT** | produce/commit N tokens on draft branch | -| **4. VERIFY** | Check draft confidence (entropy threshold) | -| **5. PRUNE** | Remove draft branch (cleans up divergent KV) | -| **6. ACCEPT** | Commit accepted tokens to main branch | -| **7. 
BONUS** | Sample one token from main at rejection point | - -## Key Pattern: Fork/Draft/Verify with Branch API - -```javascript -// Main branch tracks committed state -const main = Branch.create(ctx, 0, { temperature: 0.7 }); -await main.prefill(promptTokens); - -while (output.length < maxTokens) { - // Fork draft from main β€” shares KV prefix - const draft = await main.fork(); - draft.reseedSampler(iteration); - - // Draft N tokens - const drafts = []; - for (let i = 0; i < N; i++) { - const entropy = ctx.modelEntropy('nats', draft.getLogits()); - const { token, text, isStop } = draft.produceSync(); - if (isStop) break; - drafts.push({ token, text, entropy }); - await draft.commit(token); - } - - // Verify and prune draft - const acceptedCount = verify(drafts); - await draft.prune(); - - // Commit accepted tokens to main - for (const d of drafts.slice(0, acceptedCount)) { - await main.commit(d.token); - } - - // Bonus token from main at rejection point - if (acceptedCount < drafts.length) { - const { token } = main.produceSync(); - await main.commit(token); - } -} -await main.prune(); -``` - -## Why Branch API? 
- -The produce/commit separation is what makes speculative decoding natural: - -- **produce()** samples without writing to KV — inspect before deciding -- **commit()** accepts + decodes — advance state only for accepted tokens -- **fork()** shares KV prefix — draft branch doesn't duplicate the prompt -- **prune()** removes divergent KV — clean rejection without manual bookkeeping - -## Key APIs - -| Method | Description | -|--------|-------------| -| `Branch.create(ctx, pos, params)` | Create branch at position | -| `branch.fork()` | Fork: shared KV prefix + cloned sampler | -| `branch.produce()` | Sample without KV write | -| `branch.commit(token)` | Accept + decode into KV | -| `branch.prune()` | Remove divergent KV entries | -| `branch.reseedSampler(seed)` | Diversify forked branch | -| `ctx.modelEntropy('nats', logits)` | Check draft confidence | - -## Accept Rate - -The accept rate determines speedup: - -| Accept Rate | Meaning | -|-------------|---------| -| High (>70%) | Draft model matches target well - good speedup | -| Low (<30%) | Draft model diverges - minimal speedup | - -This example uses entropy-based simulation (not a real draft model), so accept rates are low. With a properly trained draft model, rates of 60-80% are achievable. - -## References - -- [Leviathan et al. 2023](https://arxiv.org/abs/2211.17192) - "Fast Inference from Transformers via Speculative Decoding" -- [Chen et al. 
2023](https://arxiv.org/abs/2302.01318) - "Accelerating LLM Decoding with Speculative Sampling" diff --git a/examples/speculative/speculative.mjs b/examples/speculative/speculative.mjs deleted file mode 100644 index 93bc111..0000000 --- a/examples/speculative/speculative.mjs +++ /dev/null @@ -1,271 +0,0 @@ -#!/usr/bin/env node -/** - * Speculative Decoding with Branch API - * - * This example demonstrates speculative decoding using the Branch primitive: - * - Main branch tracks committed state - * - Fork a draft branch for speculative generation - * - Prune draft on rejection, commit accepted tokens to main - * - Sample bonus token from main at rejection point - * - * Real speculative decoding uses a small "draft" model and large "target" model. - * This example uses the same model for both (demonstrating the mechanics, not speedup). - * - * Branch API Benefits: - * - Atomic fork: KV + logits + sampler + perplexity cloned together - * - produce/commit separation: sample without KV write, then commit - * - Shared prefix: forked branches share KV for common prefix - * - Clean cleanup: prune() removes divergent KV entries - * - * References: - * - Leviathan et al. 2023 "Fast Inference from Transformers via Speculative Decoding" - * - Chen et al. 
2023 "Accelerating Large Language Model Decoding with Speculative Sampling" - * - * Usage: - * node speculative.mjs [model-path] # Human-readable output - * node speculative.mjs [model-path] --jsonl # JSONL output for testing - */ - -import * as path from 'node:path'; -import { fileURLToPath } from 'node:url'; -import { createContext, Branch } from '../../lib/index.js'; - -const __dirname = path.dirname(fileURLToPath(import.meta.url)); -const DEFAULT_MODEL = path.resolve( - __dirname, - '../../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf' -); - -// Parse args -const args = process.argv.slice(2); -const jsonlMode = args.includes('--jsonl'); -const modelPath = args.find((a) => !a.startsWith('--')) || DEFAULT_MODEL; - -/** Emit output - JSONL or human-readable */ -function emit(event, data) { - if (jsonlMode) { - console.log(JSON.stringify({ event, ...data })); - } -} - -/** - * Simulate speculative decoding verification - * - * In real speculative decoding: - * - Draft model generates N tokens quickly (small model or n-gram) - * - Target model scores all N tokens in a single batch - * - Compare: if target agrees with draft, accept; else reject and use target's token - * - * Here we simulate by accepting tokens with probability based on draft confidence. 
- */ -function simulateVerification(drafts) { - // In production: compare draft probabilities to target probabilities - // Here: accept high-confidence drafts (low entropy), reject uncertain ones - let accepted = 0; - - for (const draft of drafts) { - // Simulate: accept if draft was "confident" (entropy < threshold) - // Real implementation would compare P_target(token) vs P_draft(token) - if (draft.entropy < 2.0) { - accepted++; - } else { - break; // First rejection stops the chain - } - } - - return accepted; -} - -async function main() { - const DRAFT_COUNT = 4; - const GENERATION_LENGTH = 30; - - if (!jsonlMode) { - console.log('Speculative Decoding Demo (Branch API)'); - console.log('======================================\n'); - console.log(`Loading model: ${path.basename(modelPath)}`); - } - - emit('start', { - model: path.basename(modelPath), - draftCount: DRAFT_COUNT, - generationLength: GENERATION_LENGTH, - }); - - const nCtx = parseInt(process.env.LLAMA_CTX_SIZE || '2048', 10); - const ctx = await createContext({ - modelPath, - nCtx, - nSeqMax: 4, // Enable multi-sequence for fork/verify pattern - }); - - const prompt = 'The quick brown fox'; - if (!jsonlMode) { - console.log(`\nPrompt: "${prompt}"`); - } - - // Prefill prompt via main branch - const promptTokens = await ctx.tokenize(prompt); - - const main = Branch.create(ctx, 0, { - temperature: 0.7, // For bonus token sampling - }); - await main.prefill(promptTokens); - - const output = []; - let totalDrafted = 0; - let totalAccepted = 0; - let iterations = 0; - - if (!jsonlMode) { - console.log( - `\nGenerating ${GENERATION_LENGTH} tokens with speculative decoding...\n` - ); - process.stdout.write(prompt); - } - - while (output.length < GENERATION_LENGTH) { - iterations++; - - // === DRAFT PHASE === - // Fork main branch for speculative drafting - // Draft branch shares KV prefix with main, diverges as it generates - const draft = await main.fork(); - draft.reseedSampler(iterations); // Different 
seed each iteration for diversity - - const drafts = []; - - for (let i = 0; i < DRAFT_COUNT && output.length + drafts.length < GENERATION_LENGTH; i++) { - // Get entropy BEFORE sampling (from draft branch's logits snapshot) - const entropy = draft.modelEntropy('nats'); - - // produce() samples from captured logits (no KV write yet) - const { token, text, isStop } = draft.produceSync(); - - if (isStop) break; - - drafts.push({ token, text, entropy }); - - // commit() accepts token + decodes + captures new logits - await draft.commit(token); - } - - if (drafts.length === 0) { - await draft.prune(); - break; - } - totalDrafted += drafts.length; - - // === VERIFY PHASE === - // Simulate verification - in production this compares draft vs target distributions - const acceptedCount = simulateVerification(drafts); - totalAccepted += acceptedCount; - - // === CLEANUP DRAFT === - // Prune draft branch - removes its divergent KV entries - // Main branch is unchanged (still at pre-draft position) - await draft.prune(); - - // === ACCEPT PHASE === - // Commit accepted tokens to main branch - const accepted = drafts.slice(0, acceptedCount); - for (const d of accepted) { - await main.commit(d.token); - if (!jsonlMode) { - process.stdout.write(d.text); - } - emit('token', { - token: d.token, - text: d.text, - entropy: d.entropy, - accepted: true, - }); - output.push(d.token); - } - - // === BONUS TOKEN === - // If we rejected some drafts, sample a bonus token from main - // Main is now at the accepted position with fresh logits - const rejected = drafts.slice(acceptedCount); - if (rejected.length > 0) { - // produce() samples from main's current logits (at rejection point) - const { token: bonusToken, text: bonusText, isStop } = main.produceSync(); - - if (!isStop) { - await main.commit(bonusToken); - if (!jsonlMode) { - process.stdout.write(bonusText); - } - emit('token', { token: bonusToken, text: bonusText, bonus: true }); - output.push(bonusToken); - } - } - - 
emit('iteration', { - iteration: iterations, - drafted: drafts.length, - accepted: acceptedCount, - rejected: rejected.length, - hasBonus: rejected.length > 0, - }); - - // Check for natural stopping - if (output.length > 0 && ctx.isStopToken(output[output.length - 1])) { - break; - } - } - - // Cleanup main branch - await main.prune(); - - // Statistics - const acceptRate = totalDrafted > 0 ? totalAccepted / totalDrafted : 0; - - emit('complete', { - iterations, - totalDrafted, - totalAccepted, - acceptRate, - outputTokens: output.length, - }); - - if (!jsonlMode) { - console.log('\n'); - console.log('='.repeat(50)); - console.log('Statistics'); - console.log('='.repeat(50)); - console.log(` Iterations: ${iterations}`); - console.log(` Tokens drafted: ${totalDrafted}`); - console.log(` Tokens accepted: ${totalAccepted}`); - console.log(` Accept rate: ${(acceptRate * 100).toFixed(1)}%`); - console.log(` Output tokens: ${output.length}`); - - console.log('\n' + '='.repeat(50)); - console.log('How Speculative Decoding Works (Branch API)'); - console.log('='.repeat(50)); - console.log(` - 1. MAIN: Create main branch tracking committed state - 2. FORK: Fork draft branch (shares KV prefix with main) - 3. DRAFT: produce/commit N tokens on draft branch - 4. VERIFY: Check draft confidence (entropy threshold) - 5. PRUNE: Remove draft branch (cleans up divergent KV) - 6. COMMIT: Apply accepted tokens to main branch - 7. BONUS: Sample one token from main at rejection point - 8. 
REPEAT: Continue from main's new position - - Branch API Advantages: - - Atomic fork: KV + logits + sampler copied together - - Shared prefix: Only divergent KV uses extra memory - - Clean separation: produce() samples, commit() writes - - Easy cleanup: prune() handles KV removal -`); - } - - ctx.dispose(); -} - -main().catch((err) => { - console.error('Error:', err.message); - console.error(err.stack); - process.exit(1); -}); diff --git a/examples/streaming/README.md b/examples/streaming/README.md deleted file mode 100644 index 2352f60..0000000 --- a/examples/streaming/README.md +++ /dev/null @@ -1,217 +0,0 @@ -# Streaming Examples - -Advanced streaming patterns for long-form generation with quality preservation. - -## Examples Overview - -| Example | Purpose | Key Pattern | -|---------|---------|-------------| -| `streaming.mjs` | Infinite context generation | BlinkKV reseeding | -| `streaming-tsampler.mjs` | TypeScript sampling with N-gram tracking | TTA (Test-Time Alignment) | -| `streaming-summary.mjs` | Dynamic summary sinks | BlinkKV + summary sidecar | - ---- - -## streaming.mjs - BlinkKV Infinite Context - -Demonstrates generating beyond the context window limit using the BlinkKV reseeding pattern. - -### Usage - -```bash -node streaming.mjs /path/to/model.gguf -``` - -### Parameters (from BlinkKV paper) - -| Parameter | Value | Description | -|-----------|-------|-------------| -| Context size | 2048 | Model's context window | -| Sink tokens | prompt | Structural anchor (entire prompt) | -| Tail size | 256 | Most recent tokens to retain | - -### BlinkKV Pattern - -When the KV cache fills: -1. **Clear** the entire KV cache -2. **Re-decode sinks** (prompt tokens) at positions [0..N] -3. **Re-decode tail** (256 most recent) at positions [N+1..N+256] -4. **Continue** from position N+257 - -This maintains cache-local position contiguity, which is necessary and sufficient for streaming quality. 
- -### Key APIs - -| Method | Description | -|--------|-------------| -| `clearAndReseed(sinks, tail)` | Clear cache, re-decode at local positions | -| `modelSurprisal(token)` | Measure prediction error | -| `createPerplexityTracker()` | Track quality across stream | - ---- - -## streaming-tsampler.mjs - TypeScript Sampling with N-gram Tracking - -Demonstrates using tsampler (TypeScript sampling library) with N-gram sequence tracking for repetition detection. - -### Usage - -```bash -node streaming-tsampler.mjs /path/to/model.gguf -``` - -### Architecture - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ Native Context (llama.cpp) β”‚ -β”‚ - KV cache management β”‚ -β”‚ - Logits computation via decode() β”‚ -β”‚ - BlinkKV reseeding β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - β”‚ ctx.getLogits() - β–Ό -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ tsampler (TypeScript) β”‚ -β”‚ - sampleWithStrategy() for token selection β”‚ -β”‚ - Temperature, top-p, top-k filtering β”‚ -β”‚ - Xoroshiro128Plus PRNG for reproducibility β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - β”‚ sampled token - β–Ό -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ NgramTracker (App-level) β”‚ -β”‚ - Tracks N-gram sequences (configurable N) β”‚ -β”‚ - Threshold-based blocking (block after K repeats) β”‚ -β”‚ - Logit 
steering: blocked token β†’ -Infinity β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -### Key Insight: Token vs Sequence Penalties - -llama.cpp's built-in repetition penalties operate at the **token level**, penalizing individual words regardless of context. This degrades prose quality over long generations as common words ("the", "is", "a") accumulate penalties. - -Instead, tsampler + N-gram tracking operates at the **sequence level**: -- Only blocks when an exact N-token sequence repeats -- Threshold-based: only blocks after K occurrences (not first occurrence) -- Preserves natural word reuse while preventing actual loops - -### tsampler Integration - -```javascript -import { - sampleWithStrategy, - Xoroshiro128Plus, - SamplerWorkspace, -} from 'tsampler'; - -const prng = new Xoroshiro128Plus(42); // Deterministic seed -const workspace = new SamplerWorkspace(256); - -// Get logits from native layer -const logits = new Float32Array(ctx.getLogits()); - -// Apply N-gram blocking before sampling -const blockedToken = ngramTracker.getBlockedToken(); -if (blockedToken !== null) { - logits[blockedToken] = -Infinity; -} - -// Sample with tsampler -const token = sampleWithStrategy(logits, { - params: { temperature: 0.8, topP: 0.9 }, - workspace, - prng, -}); -``` - -### Configuration - -| Parameter | Default | Description | -|-----------|---------|-------------| -| `NGRAM_SIZE` | 6 | N-gram length for sequence tracking | -| `BLOCK_THRESHOLD` | 2 | Block after K occurrences of same pattern | - ---- - -## streaming-summary.mjs - Dynamic Summary Sinks - -Extends BlinkKV with a slim-summary sidecar that generates cumulative summaries of evicted content. Summaries become sink tokens on reseed, giving the model compressed semantic memory of what it generated beyond the visible tail. 
- -### Usage - -```bash -node streaming-summary.mjs /path/to/model.gguf -node streaming-summary.mjs /path/to/model.gguf --jsonl -``` - -### Architecture - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ Main Context (llama.cpp) β”‚ -β”‚ - KV cache management + BlinkKV reseeding β”‚ -β”‚ - Token generation loop β”‚ -β”‚ - clearAndReseed(sinks, tail) with dynamic sinks β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - β”‚ evicted text β”‚ reseed - β–Ό β–² sink tokens -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ Summary Sidecar (slim-summary) β”‚β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -β”‚ - slim-summarize.gguf (1.7GB) β”‚ -β”‚ - Prompt: / β”‚ -β”‚ - Output: Python-style list β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - -After reseed, KV cache layout: -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ anchor β”‚ summary β”‚ tail β”‚ -β”‚ (prompt) β”‚ (evictedβ†’) β”‚ (256 recent) β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -### Sidecar Prompt Format - -The slim-summarize model uses a specific prompt format: - -``` -: {text} - key points (5) -: -``` - -Output is a Python-style list: `['point1', 'point2', 'point3']` - -When budget is tight, uses `brief description (1)` for a single cohesive summary. 
- -### Budget Management - -| Concept | Formula | -|---------|---------| -| Max sink tokens | `nCtx * sinkBudgetRatio` (default 0.4 = 819 tokens) | -| Summary budget | `maxSinkTokens - anchorTokens.length` | -| Over budget? | Re-summarize with `brief description (1)`, maxTokens=100 | - -### Configuration - -| Parameter | Default | Description | -|-----------|---------|-------------| -| `TAIL_SIZE` | 256 | Most recent tokens to retain | -| `TARGET_TOKENS` | 5000 | Total tokens to generate | -| `sinkBudgetRatio` | 0.4 | Fraction of context allocated to sinks | -| `summaryMaxTokens` | 200 | Max tokens for summary generation | - -### Key APIs - -| Method | Description | -|--------|-------------| -| `clearAndReseed(sinks, tail)` | Clear cache, re-decode sinks + tail | -| `tokenize(text)` | Tokenize summary text for sink injection | -| `kvCacheClear()` | Clear sidecar KV before each summary | -| `formatChat(messages)` | Format anchor message with chat template | - ---- - -## References - -1. Han et al. 
2024 - "LM-Infinite: Zero-Shot Extreme Length Generalization" (BlinkKV) diff --git a/examples/streaming/streaming-summary.mjs b/examples/streaming/streaming-summary.mjs deleted file mode 100644 index 27221ca..0000000 --- a/examples/streaming/streaming-summary.mjs +++ /dev/null @@ -1,552 +0,0 @@ -#!/usr/bin/env node -/** - * Infinite context generation with dynamic summary sinks - * - * Usage: - * node streaming-summary.mjs [model-path] # Self-summary (default) - * node streaming-summary.mjs [model-path] --sidecar # Use slim-summarize sidecar - * node streaming-summary.mjs [model-path] --jsonl # JSONL output for testing - * - * This example demonstrates: - * - BlinkKV reseeding with ghostwritten progress sinks - * - Self-summary: main model summarizes its own evicted content (default) - * - Sidecar mode: optional slim-summarize model for summarization (--sidecar) - * - Outline detection with structural progress tracking - * - Pattern matching (not instruction following) to guide continuation - * - Branch API for generation (produce/commit loop) - * - * After reseed, KV cache contains: [progress][tail] - * - progress = minimal anchor + checklist of done/current sections + summary - * - tail = recent 256 tokens for continuity - * - * The progress sink uses "done" / "continue from here" markers that the - * model pattern-matches against, rather than relying on instruction following. 
- */ - -import * as fs from 'node:fs'; -import * as path from 'node:path'; -import { fileURLToPath } from 'node:url'; -import { createContext, Branch } from '../../lib/index.js'; - -const __dirname = path.dirname(fileURLToPath(import.meta.url)); -const DEFAULT_MODEL = path.resolve( - __dirname, - '../../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf' -); -const SUMMARY_MODEL = path.resolve( - __dirname, - '../../models/slim-summarize.gguf' -); - -// Parse args -const args = process.argv.slice(2); -const jsonlMode = args.includes('--jsonl'); -const useSidecar = args.includes('--sidecar'); -const modelPath = args.find(a => !a.startsWith('--')) || DEFAULT_MODEL; - -// Parse --max-tokens for CI (default 5000) -const maxTokensArg = args.find(a => a.startsWith('--max-tokens=')); -const TARGET_TOKENS = maxTokensArg ? parseInt(maxTokensArg.split('=')[1], 10) : 5000; - -/** Emit output - JSONL or human-readable */ -function emit(event, data) { - if (jsonlMode) { - console.log(JSON.stringify({ event, ...data })); - } -} - -/** - * Parse slim-summarize output (Python-style list) into readable text - */ -function parseSummaryOutput(raw) { - // Output is Python-style list: ['point1', 'point2', ...] - // Items may contain apostrophes (e.g., "It's"), so we can't match between quotes. 
- // Instead, strip outer brackets + quotes, then split on the item boundary: ', ' - let inner = raw.trim(); - if (inner.startsWith('[')) inner = inner.slice(1); - if (inner.endsWith(']')) inner = inner.slice(0, -1); - inner = inner.trim(); - if (inner.startsWith("'") || inner.startsWith('"')) inner = inner.slice(1); - if (inner.endsWith("'") || inner.endsWith('"')) inner = inner.slice(0, -1); - - if (!inner) return raw.trim(); - - // Split on quote-comma-quote boundaries (handles apostrophes within items) - const items = inner.split(/['"]\s*,\s*['"]/) - .map(s => s.trim()) - .filter(Boolean); - - if (items.length > 0) return items.join('\n'); - return inner; -} - -/** - * Generate a summary using sidecar context - * @param {object} summaryCtx - Context to use for summarization - * @param {string} text - Text to summarize - * @param {object} options - Options: maxTokens, brief, format ('self' | 'slim-summarize') - */ -async function generateSummary(summaryCtx, text, options = {}) { - const maxTokens = options.maxTokens || 200; - const format = options.format || 'self'; - - let tokens; - - if (format === 'self') { - // Self-summary: use model's chat template via formatChat() - const { prompt } = await summaryCtx.formatChat( - JSON.stringify([ - { - role: 'system', - content: 'Summarize the following text concisely. List the key points.', - }, - { role: 'user', content: text.slice(-10000) }, - ]) - ); - tokens = await summaryCtx.tokenize(prompt); - } else { - // slim-summarize prompt format - const paramStr = options.brief ? 
'brief description (1)' : 'key points (5)'; - const prompt = ` ${text.slice(-10000)}\n ${paramStr}\n:`; - tokens = await summaryCtx.tokenize(prompt); - } - - await summaryCtx.kvCacheClear(); - const branch = Branch.create(summaryCtx, 0, { temperature: 0.3 }); - await branch.prefill(tokens); - - let response = ''; - for (let i = 0; i < maxTokens; i++) { - const { token, text: t, isStop } = await branch.produce(); - if (isStop) break; - response += t; - await branch.commit(token); - } - await branch.prune(); - - // Only parse slim-summarize Python-style list format - return format === 'slim-summarize' - ? parseSummaryOutput(response.trim()) - : response.trim(); -} - -/** - * Parse numbered outline items from prompt text. - */ -function parseOutline(text) { - const items = []; - const regex = /^\s*(\d+)\.\s+(.+?)(?:\s*[-–—:]\s*.*)?$/gm; - let match; - while ((match = regex.exec(text)) !== null) { - items.push({ - number: parseInt(match[1]), - title: match[2].trim(), - }); - } - return items; -} - -/** - * Extract instruction part of prompt, before any numbered outline. - */ -function extractMinimalAnchor(text) { - const listMatch = text.match(/^\s*1\.\s/m); - if (listMatch && listMatch.index > 0) { - return text.slice(0, listMatch.index).trim(); - } - return text.slice(0, 200).trim(); -} - -/** - * Build ghostwritten progress sink. - * Completed items show "- done", current shows "- continue from here". - * Model pattern-matches to continue from the right section. - */ -function buildProgressSink(anchor, outline, allGeneratedText, summaryChain) { - const lower = allGeneratedText.toLowerCase(); - - let lastCoveredIdx = -1; - for (let i = outline.length - 1; i >= 0; i--) { - if (lower.includes(outline[i].title.toLowerCase())) { - lastCoveredIdx = i; - break; - } - } - - let text = `${anchor}\n\n`; - - for (let i = 0; i < outline.length; i++) { - const item = outline[i]; - if (i < lastCoveredIdx) { - text += `${item.number}. 
${item.title} - done\n`; - } else if (i === lastCoveredIdx) { - text += `${item.number}. ${item.title} - continue from here\n`; - } else { - text += `${item.number}. ${item.title}\n`; - } - } - - if (summaryChain) { - text += `\nKey points so far:\n${summaryChain}\n`; - } - - return text; -} - -async function main() { - // Constants - const nCtx = parseInt(process.env.LLAMA_CTX_SIZE || '2048', 10); - const TAIL_SIZE = 256; - const MAX_SINK_RATIO = 0.4; - const MAX_SINK_TOKENS = Math.floor(nCtx * MAX_SINK_RATIO); - const SUMMARY_MAX_TOKENS = 200; - - // Determine summary mode before emitting start event - const summaryFormat = useSidecar ? 'slim-summarize' : 'self'; - - if (!jsonlMode) { - console.log(`Loading model: ${modelPath}`); - console.log(`Summary mode: ${summaryFormat}`); - } - - emit('start', { - model: path.basename(modelPath), - nCtx, - tailSize: TAIL_SIZE, - maxSinkTokens: MAX_SINK_TOKENS, - targetTokens: TARGET_TOKENS, - summaryMode: summaryFormat, - }); - - const ctx = await createContext({ - modelPath, - nCtx, - }); - - // Summary sidecar β€” preload in background (overlaps with prompt decode + generation) - // Default: "self" mode - second context from same model (weights shared via model_registry) - // --sidecar flag: use slim-summarize.gguf instead - let summaryCtx = null; - let summaryCtxPromise = null; - let actualSummaryFormat = summaryFormat; - - if (useSidecar) { - // Sidecar mode: use slim-summarize.gguf - const summaryModelAvailable = fs.existsSync(SUMMARY_MODEL); - if (summaryModelAvailable) { - summaryCtxPromise = createContext({ modelPath: SUMMARY_MODEL, nCtx: 4096 }); - } else { - if (!jsonlMode) { - console.log('Sidecar model not found - falling back to self-summary'); - } - emit('sidecar_missing', { message: 'slim-summarize.gguf not found, using self-summary' }); - // Fall back to self mode - summaryCtxPromise = createContext({ modelPath, nCtx: 4096 }); - actualSummaryFormat = 'self'; - } - } else { - // Self mode (default): second 
context from same model - // Weights are shared via model_registry β€” only KV cache is duplicated - summaryCtxPromise = createContext({ modelPath, nCtx: 4096 }); - } - - const prompt = `Write a comprehensive guide to machine learning, covering the following topics in extreme detail with examples, code snippets, and mathematical formulas: - -1. Linear Regression - derivation, implementation, regularization -2. Logistic Regression - binary and multiclass -3. Neural Networks - backpropagation, activation functions -4. Convolutional Neural Networks - architectures, pooling, stride -5. Recurrent Neural Networks - LSTM, GRU, attention -6. Transformers - self-attention, positional encoding -7. Optimization - SGD, Adam, learning rate schedules -8. Regularization - dropout, batch normalization, weight decay - -Begin: - -# Comprehensive Machine Learning Guide - -## Chapter 1: Linear Regression - -`; - - // Parse outline for ghostwritten progress sinks - const outline = parseOutline(prompt); - const minimalAnchor = outline.length > 0 - ? extractMinimalAnchor(prompt) - : null; - - if (!jsonlMode) { - console.log(`\nPrompt: "${prompt.slice(0, 100)}..."`); - if (outline.length > 0) { - console.log(`Outline detected: ${outline.length} sections`); - console.log(`Minimal anchor: "${minimalAnchor}"`); - } - } - - const promptTokens = await ctx.tokenize(prompt); - - // Fallback anchor for prompts without outlines - let anchorTokens = null; - if (outline.length === 0) { - anchorTokens = [...promptTokens]; - } - - const summaryBudget = outline.length > 0 - ? 
MAX_SINK_TOKENS - : MAX_SINK_TOKENS - (anchorTokens?.length || 0); - - const samplingParams = { temperature: 0.8, topP: 0.9 }; - let branch = Branch.create(ctx, 0, samplingParams); - await branch.prefill(promptTokens); - - if (!jsonlMode) { - console.log(`\nContext size: ${nCtx}`); - console.log(`Target tokens: ${TARGET_TOKENS}`); - console.log(`Sink budget: ${MAX_SINK_TOKENS} tokens`); - console.log(`Tail size: ${TAIL_SIZE}`); - console.log(`\nGenerating...\n`); - process.stdout.write(prompt); - } - - const allTokens = [...promptTokens]; - // Manual PPL tracking (persists across branch reseeds) - let nllSum = 0, nllCount = 0; - let reseedCount = 0; - let currentSegmentText = ''; - let allGeneratedText = ''; - const summaries = []; - let pendingSummaryTokens = []; - - for (let t = 0; t < TARGET_TOKENS; t++) { - const { token, isStop } = await branch.produce(); - - if (isStop) { - if (!jsonlMode) { - console.log('\n[EOS token reached]'); - } - emit('eos', { tokenIndex: t }); - break; - } - - const surprisal = branch.modelSurprisal(token, 'nats'); - nllSum += Math.max(0, surprisal); - nllCount++; - - const text = ctx.tokenToText(token); - if (!jsonlMode) { - process.stdout.write(text); - } - emit('token', { source: 'main', index: t, token, text, surprisal }); - - currentSegmentText += text; - allGeneratedText += text; - allTokens.push(token); - await branch.commit(token); - - // Cache full? Reseed with dynamic sinks - if (branch.position >= nCtx) { - // Estimate evicted portion of current segment only - const tailCharsEstimate = TAIL_SIZE * 4; - const evictedFromSegment = currentSegmentText.length > tailCharsEstimate - ? currentSegmentText.slice(0, -tailCharsEstimate) - : ''; - - let sinks; - - // Resolve preloaded summary context (should already be loaded by now) - if (summaryCtxPromise && !summaryCtx) { - summaryCtx = await summaryCtxPromise; - const summaryModelName = actualSummaryFormat === 'self' ? 
path.basename(modelPath) : 'slim-summarize.gguf'; - if (!jsonlMode) { - console.log(`\n [Summary context loaded: ${summaryModelName} (${actualSummaryFormat} mode)]`); - } - emit('summary_loaded', { model: summaryModelName, mode: actualSummaryFormat }); - } - - // Run summary sidecar if available - let chainText = null; - if (summaryCtx && evictedFromSegment.length > 0) { - emit('summary_start', { reseedCount: reseedCount + 1 }); - const summaryStartTime = Date.now(); - - if (!jsonlMode) { - process.stdout.write(`\n [Summarizing ${evictedFromSegment.length} evicted chars (page ${summaries.length + 1})...`); - } - - const newPage = await generateSummary(summaryCtx, evictedFromSegment, { - maxTokens: SUMMARY_MAX_TOKENS, - format: actualSummaryFormat, - }); - summaries.push(newPage); - chainText = summaries.join('\n'); - - // Fold oldest pages if chain is getting large - let testTokens = await ctx.tokenize(chainText); - if (testTokens.length > summaryBudget * 0.6) { - if (!jsonlMode) { - process.stdout.write(' (folding oldest pages)'); - } - - const foldCount = Math.max(1, Math.ceil(summaries.length / 2)); - const toFold = summaries.splice(0, foldCount); - const folded = await generateSummary(summaryCtx, toFold.join('\n'), { - brief: true, - maxTokens: 100, - format: actualSummaryFormat, - }); - summaries.unshift(folded); - chainText = summaries.join('\n'); - } - - const compressionRatio = evictedFromSegment.length > 0 - ? 
(evictedFromSegment.length / newPage.length).toFixed(1) - : '0'; - const durationMs = Date.now() - summaryStartTime; - - emit('summary_complete', { - reseedCount: reseedCount + 1, - summary: newPage, - summaryTokens: (await ctx.tokenize(chainText)).length, - compressionRatio: parseFloat(compressionRatio), - durationMs, - pages: summaries.length, - }); - - if (!jsonlMode) { - process.stdout.write(` ${compressionRatio}x, ${summaries.length} pages]`); - } - } - - // Build sinks β€” progress mode (outline detected) or fallback - if (outline.length > 0) { - const progressText = buildProgressSink( - minimalAnchor, outline, allGeneratedText, chainText - ); - let progressTokens = await ctx.tokenize(progressText); - - if (progressTokens.length > MAX_SINK_TOKENS) { - // Drop summary details to fit budget - const trimmedText = buildProgressSink( - minimalAnchor, outline, allGeneratedText, null - ); - progressTokens = await ctx.tokenize(trimmedText); - } - - sinks = progressTokens; - pendingSummaryTokens = progressTokens; - - if (!jsonlMode) { - console.log(`\n [Progress sink: ${progressTokens.length} tok]`); - // Show progress state - const lower = allGeneratedText.toLowerCase(); - let lastIdx = -1; - for (let i = outline.length - 1; i >= 0; i--) { - if (lower.includes(outline[i].title.toLowerCase())) { - lastIdx = i; break; - } - } - if (lastIdx >= 0) { - console.log(` [Sections done: ${lastIdx}, continuing: ${outline[lastIdx].title}]`); - } - } - - emit('sink_update', { - anchorTokens: 0, - summaryTokens: progressTokens.length, - totalSinkTokens: progressTokens.length, - budgetUsed: ((progressTokens.length / MAX_SINK_TOKENS) * 100).toFixed(1), - budgetMax: MAX_SINK_TOKENS, - pages: summaries.length, - mode: 'progress', - }); - } else if (chainText) { - const wrapped = `Previously:\n${chainText}\n`; - const summaryTokens = await ctx.tokenize(wrapped); - sinks = [...anchorTokens, ...summaryTokens]; - pendingSummaryTokens = summaryTokens; - - if (!jsonlMode) { - 
process.stdout.write(` ${summaryTokens.length} summary tok]`); - } - - emit('sink_update', { - anchorTokens: anchorTokens.length, - summaryTokens: summaryTokens.length, - totalSinkTokens: sinks.length, - budgetUsed: ((sinks.length / MAX_SINK_TOKENS) * 100).toFixed(1), - budgetMax: MAX_SINK_TOKENS, - pages: summaries.length, - mode: 'anchor', - }); - } else { - sinks = [...(anchorTokens || [])]; - } - - const tail = allTokens.slice(-TAIL_SIZE); - - // Destroy current branch, clear KV, create fresh branch with re-prefill - await branch.prune(); - await ctx.kvCacheClear(); - branch = Branch.create(ctx, 0, samplingParams); - await branch.prefill([...sinks, ...tail]); - - reseedCount++; - - const ppl = nllCount > 0 ? Math.exp(nllSum / nllCount) : 1; - emit('reseed', { - count: reseedCount, - tokenIndex: t + 1, - ppl, - sinkTokens: sinks.length, - tailTokens: TAIL_SIZE, - summaryPages: summaries.length, - summaryPreview: summaries[summaries.length - 1]?.slice(0, 100) || '', - }); - - if (!jsonlMode) { - console.log(` [Reseed ${reseedCount} at token ${t + 1}/${TARGET_TOKENS} | PPL: ${ppl.toFixed(2)} | Sinks: ${sinks.length} tok | Pages: ${summaries.length}]`); - } - - currentSegmentText = ''; - } - - // Progress indicator every 1000 tokens - if ((t + 1) % 1000 === 0 && reseedCount === 0 && !jsonlMode) { - console.log(`\n [${t + 1}/${TARGET_TOKENS} tokens]`); - } - } - - const finalPpl = nllCount > 0 ? 
Math.exp(nllSum / nllCount) : 1; - await branch.prune(); - - const generatedTokens = allTokens.length - promptTokens.length; - const finalChain = summaries.join('\n'); - emit('complete', { - generatedTokens, - reseeds: reseedCount, - finalPpl, - finalSummary: finalChain.slice(0, 300), - finalSummaryTokens: pendingSummaryTokens.length, - summaryPages: summaries.length, - }); - - if (!jsonlMode) { - console.log('\n\n' + '='.repeat(50)); - console.log(`Generated: ${generatedTokens} tokens`); - console.log(`Reseeds: ${reseedCount}`); - console.log(`Final perplexity: ${finalPpl.toFixed(2)}`); - if (summaries.length > 0) { - console.log(`Summary pages: ${summaries.length}`); - console.log(`Final chain (${pendingSummaryTokens.length} tok): ${finalChain.slice(0, 200)}`); - } - console.log('='.repeat(50)); - } - - ctx.dispose(); - if (summaryCtx) summaryCtx.dispose(); -} - -main().catch((err) => { - console.error('Error:', err.message); - process.exit(1); -}); diff --git a/examples/streaming/streaming-tsampler.mjs b/examples/streaming/streaming-tsampler.mjs deleted file mode 100644 index ec41ec5..0000000 --- a/examples/streaming/streaming-tsampler.mjs +++ /dev/null @@ -1,326 +0,0 @@ -#!/usr/bin/env node -/** - * Infinite context generation with BlinkKV + tsampler N-gram deduplication - * - * This example demonstrates: - * - TypeScript sampling via tsampler (TTA pattern) - * - N-gram tracking to detect sequence repetition - * - Logit steering to prevent repeated sequences - * - Branch API for KV management (prefill/commit) - * - KV cache clear + re-prefill for infinite context - * - * The key insight: llama.cpp's token-level penalties degrade prose quality. - * Instead, we track N-grams at the app level and steer away from repeats. 
- * - * Usage: - * node streaming-tsampler.mjs [model-path] # Human-readable output - * node streaming-tsampler.mjs [model-path] --jsonl # JSONL output for testing - */ - -import * as path from 'node:path'; -import { fileURLToPath } from 'node:url'; -import { createContext, Branch } from '../../lib/index.js'; - -// Import tsampler from npm package -import { - sampleWithStrategy, - // TokenHistoryTracker, // Disabled - matching baseline - Xoroshiro128Plus, - SamplerWorkspace, -} from '@lloyal-labs/tsampler'; - -const __dirname = path.dirname(fileURLToPath(import.meta.url)); -const DEFAULT_MODEL = path.resolve( - __dirname, - '../../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf' -); - -// Parse args -const args = process.argv.slice(2); -const jsonlMode = args.includes('--jsonl'); -const modelPath = args.find(a => !a.startsWith('--')) || DEFAULT_MODEL; - -// Parse --max-tokens for CI (default 5000) -const maxTokensArg = args.find(a => a.startsWith('--max-tokens=')); -const TARGET_TOKENS = maxTokensArg ? parseInt(maxTokensArg.split('=')[1], 10) : 5000; - -/** Emit output - JSONL or human-readable */ -function emit(event, data) { - if (jsonlMode) { - console.log(JSON.stringify({ event, ...data })); - } -} - -/** - * N-gram tracker for sequence-level repetition detection (threshold-based) - * - * Tracks N-grams and their followers. Only blocks when the SAME N-gram β†’ follower - * pattern is seen K times (threshold), indicating true looping behavior rather - * than coincidental reuse. 
- */ -class NgramTracker { - constructor(n = 4, threshold = 2) { - this.n = n; - this.threshold = threshold; // Block after seeing same pattern K times - this.ngrams = new Map(); // ngram key -> Map - this.recentTokens = []; - } - - /** - * Record a token and update N-gram history - */ - accept(token) { - this.recentTokens.push(token); - - // Once we have enough tokens, record the N-gram and what followed - if (this.recentTokens.length > this.n) { - const ngramTokens = this.recentTokens.slice(-this.n - 1, -1); - const ngramKey = ngramTokens.join(','); - - // Get or create follower counts for this N-gram - if (!this.ngrams.has(ngramKey)) { - this.ngrams.set(ngramKey, new Map()); - } - const followers = this.ngrams.get(ngramKey); - - // Increment count for this follower - const count = followers.get(token) || 0; - followers.set(token, count + 1); - } - } - - /** - * Check if current context would repeat an N-gram above threshold - * @returns {number|null} Token to block, or null if below threshold - */ - getBlockedToken() { - if (this.recentTokens.length < this.n) { - return null; - } - - const currentNgram = this.recentTokens.slice(-this.n); - const ngramKey = currentNgram.join(','); - - const followers = this.ngrams.get(ngramKey); - if (!followers) { - return null; - } - - // Find follower that has hit threshold (true loop) - for (const [follower, count] of followers) { - if (count >= this.threshold) { - return follower; - } - } - - return null; - } - - /** - * Get stats for logging - */ - stats() { - let totalPatterns = 0; - for (const followers of this.ngrams.values()) { - totalPatterns += followers.size; - } - return { - uniqueNgrams: this.ngrams.size, - totalPatterns, - totalTokens: this.recentTokens.length, - }; - } -} - -async function main() { - // BlinkKV parameters - const nCtx = parseInt(process.env.LLAMA_CTX_SIZE || '2048', 10); - const TAIL_SIZE = 256; - const NGRAM_SIZE = 6; // Track 6-grams for sequence detection - const BLOCK_THRESHOLD = 2; // Only 
block after seeing same pattern K times - - if (!jsonlMode) { - console.log(`Loading model: ${modelPath}`); - } - - emit('start', { model: path.basename(modelPath), nCtx, tailSize: TAIL_SIZE, targetTokens: TARGET_TOKENS, ngramSize: NGRAM_SIZE, blockThreshold: BLOCK_THRESHOLD }); - - const ctx = await createContext({ - modelPath, - nCtx, - }); - - const prompt = `Write a comprehensive guide to machine learning, covering the following topics in extreme detail with examples, code snippets, and mathematical formulas: - -1. Linear Regression - derivation, implementation, regularization -2. Logistic Regression - binary and multiclass -3. Neural Networks - backpropagation, activation functions -4. Convolutional Neural Networks - architectures, pooling, stride -5. Recurrent Neural Networks - LSTM, GRU, attention -6. Transformers - self-attention, positional encoding -7. Optimization - SGD, Adam, learning rate schedules -8. Regularization - dropout, batch normalization, weight decay - -Begin: - -# Comprehensive Machine Learning Guide - -## Chapter 1: Linear Regression - -`; - if (!jsonlMode) { - console.log(`\nPrompt: "${prompt.slice(0, 100)}..."`); - } - - const promptTokens = await ctx.tokenize(prompt); - - // Track all generated tokens - const allTokens = [...promptTokens]; - const sinks = [...promptTokens]; // Sink the entire prompt - - // tsampler setup - const prng = new Xoroshiro128Plus(42); // Fixed seed for reproducibility - // const tokenHistory = new TokenHistoryTracker(32); // Disabled - matching baseline - const workspace = new SamplerWorkspace(256); - - // N-gram tracker for sequence-level deduplication - const ngramTracker = new NgramTracker(NGRAM_SIZE, BLOCK_THRESHOLD); - - // Seed N-gram tracker with prompt tokens - for (const token of promptTokens) { - ngramTracker.accept(token); - } - - if (!jsonlMode) { - console.log(`\nContext size: ${nCtx}`); - console.log(`Target tokens: ${TARGET_TOKENS}`); - console.log(`Sink tokens (prompt): ${sinks.length}`); - 
console.log(`Tail size: ${TAIL_SIZE}`); - console.log(`N-gram size: ${NGRAM_SIZE}, block threshold: ${BLOCK_THRESHOLD}`); - console.log(`\nGenerating with tsampler + N-gram deduplication (threshold-based)...\n`); - process.stdout.write(prompt); - } - - // Branch used purely for KV management β€” sampling done externally via tsampler - let branch = Branch.create(ctx, 0, { temperature: 0 }); - await branch.prefill(promptTokens); - - // Manual PPL tracking (persists across branch reseeds) - let nllSum = 0, nllCount = 0; - let reseedCount = 0; - let blockedCount = 0; - - for (let t = 0; t < TARGET_TOKENS; t++) { - // Get logits from branch snapshot - const originalLogits = branch.getLogits(); - const logits = new Float32Array(originalLogits); - - // N-gram deduplication: Check if we're about to repeat a sequence - const blockedToken = ngramTracker.getBlockedToken(); - const wasBlocked = blockedToken !== null && blockedToken < logits.length; - if (wasBlocked) { - // Steer away from the repeat by setting logit to -Infinity - logits[blockedToken] = -Infinity; - blockedCount++; - } - - // Sample with tsampler (TTA pattern) - // Match baseline params exactly: temp 0.8, topP 0.9, no topK, no penalties - const token = sampleWithStrategy(logits, { - params: { - temperature: 0.8, - topP: 0.9, - }, - workspace, - prng, - }); - - // Check for EOS - if (ctx.isStopToken(token)) { - if (!jsonlMode) { - console.log('\n[EOS token reached]'); - } - emit('eos', { tokenIndex: t }); - break; - } - - // Accept token into trackers - // tokenHistory.accept(token); // Disabled - matching baseline - ngramTracker.accept(token); - - // Track surprisal from branch's logits snapshot (before N-gram steering) - const surprisal = branch.modelSurprisal(token, 'nats'); - nllSum += Math.max(0, surprisal); - nllCount++; - - // Output token - const text = ctx.tokenToText(token); - if (!jsonlMode) { - process.stdout.write(text); - } - emit('token', { index: t, token, text, surprisal, blocked: wasBlocked 
}); - - // Store and advance KV (no sampler accept β€” we're using tsampler externally) - allTokens.push(token); - await branch.commit(token); - - // Cache full? Reseed at boundary - if (branch.position >= nCtx) { - const tail = allTokens.slice(-TAIL_SIZE); - - // Destroy current branch, clear KV, create fresh branch with re-prefill - await branch.prune(); - await ctx.kvCacheClear(); - branch = Branch.create(ctx, 0, { temperature: 0 }); - await branch.prefill([...sinks, ...tail]); - - reseedCount++; - - const ppl = nllCount > 0 ? Math.exp(nllSum / nllCount) : 1; - const stats = ngramTracker.stats(); - - emit('reseed', { count: reseedCount, tokenIndex: t + 1, ppl, blockedCount, uniqueNgrams: stats.uniqueNgrams }); - - if (!jsonlMode) { - console.log(`\n [Reseed ${reseedCount} at token ${t + 1}/${TARGET_TOKENS} | PPL: ${ppl.toFixed(2)} | Blocked: ${blockedCount} | Unique ${NGRAM_SIZE}-grams: ${stats.uniqueNgrams}]`); - } - } - - // Progress every 1000 tokens - if ((t + 1) % 1000 === 0 && branch.position < nCtx && !jsonlMode) { - const stats = ngramTracker.stats(); - console.log(`\n [${t + 1}/${TARGET_TOKENS} | Blocked repeats: ${blockedCount} | Unique ${NGRAM_SIZE}-grams: ${stats.uniqueNgrams}]`); - } - } - - const finalPpl = nllCount > 0 ? 
Math.exp(nllSum / nllCount) : 1; - const finalStats = ngramTracker.stats(); - await branch.prune(); - - const generatedTokens = allTokens.length - promptTokens.length; - emit('complete', { - generatedTokens, - reseeds: reseedCount, - finalPpl, - blockedCount, - uniqueNgrams: finalStats.uniqueNgrams, - }); - - if (!jsonlMode) { - console.log('\n\n' + '='.repeat(60)); - console.log(`Generated: ${generatedTokens} tokens`); - console.log(`Reseeds: ${reseedCount}`); - console.log(`Final perplexity: ${finalPpl.toFixed(2)}`); - console.log(`Sequence repeats blocked: ${blockedCount}`); - console.log(`Unique ${NGRAM_SIZE}-grams tracked: ${finalStats.uniqueNgrams}`); - console.log('='.repeat(60)); - } - - ctx.dispose(); -} - -main().catch((err) => { - console.error('Error:', err.message); - console.error(err.stack); - process.exit(1); -}); diff --git a/examples/streaming/streaming.mjs b/examples/streaming/streaming.mjs deleted file mode 100644 index e877e64..0000000 --- a/examples/streaming/streaming.mjs +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env node -/** - * Infinite context generation with BlinkKV - * - * Usage: - * node streaming.mjs [model-path] # Human-readable output - * node streaming.mjs [model-path] --jsonl # JSONL output for testing - * - * This example demonstrates: - * - Generating tokens beyond context window limit - * - KV cache clear + re-prefill for cache-local position reindexing - * - Per-token perplexity measurement across reseeds - * - Branch API for generation (produce/commit loop) - * - * Parameters from BlinkKV paper: 2048 context, 4 sinks, 256 tail - */ - -import * as path from 'node:path'; -import { fileURLToPath } from 'node:url'; -import { createContext, Branch } from '../../lib/index.js'; - -const __dirname = path.dirname(fileURLToPath(import.meta.url)); -const DEFAULT_MODEL = path.resolve( - __dirname, - '../../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf' -); - -// Parse args -const args = process.argv.slice(2); -const jsonlMode = 
args.includes('--jsonl'); -const modelPath = args.find(a => !a.startsWith('--')) || DEFAULT_MODEL; - -// Parse --max-tokens for CI (default 5000) -const maxTokensArg = args.find(a => a.startsWith('--max-tokens=')); -const TARGET_TOKENS = maxTokensArg ? parseInt(maxTokensArg.split('=')[1], 10) : 5000; - -/** Emit output - JSONL or human-readable */ -function emit(event, data) { - if (jsonlMode) { - console.log(JSON.stringify({ event, ...data })); - } -} - -async function main() { - // BlinkKV paper parameters: 2048 context, 4 sinks, 256 tail - const nCtx = parseInt(process.env.LLAMA_CTX_SIZE || '2048', 10); - const SINK_COUNT = 4; - const TAIL_SIZE = 256; - - if (!jsonlMode) { - console.log(`Loading model: ${modelPath}`); - } - - emit('start', { model: path.basename(modelPath), nCtx, sinkCount: SINK_COUNT, tailSize: TAIL_SIZE, targetTokens: TARGET_TOKENS }); - - const ctx = await createContext({ - modelPath, - nCtx, - }); - - const prompt = `Write a comprehensive guide to machine learning, covering the following topics in extreme detail with examples, code snippets, and mathematical formulas: - -1. Linear Regression - derivation, implementation, regularization -2. Logistic Regression - binary and multiclass -3. Neural Networks - backpropagation, activation functions -4. Convolutional Neural Networks - architectures, pooling, stride -5. Recurrent Neural Networks - LSTM, GRU, attention -6. Transformers - self-attention, positional encoding -7. Optimization - SGD, Adam, learning rate schedules -8. 
Regularization - dropout, batch normalization, weight decay - -Begin: - -# Comprehensive Machine Learning Guide - -## Chapter 1: Linear Regression - -`; - if (!jsonlMode) { - console.log(`\nPrompt: "${prompt.slice(0, 100)}..."`); - } - - const promptTokens = await ctx.tokenize(prompt); - - // Track all generated tokens (needed for reseeding) - const allTokens = [...promptTokens]; - // Sink the entire prompt - it's the structural anchor - const sinks = [...promptTokens]; - - if (!jsonlMode) { - console.log(`\nContext size: ${nCtx}`); - console.log(`Target tokens: ${TARGET_TOKENS}`); - console.log(`Sink tokens (prompt): ${sinks.length}`); - console.log(`Tail size: ${TAIL_SIZE}`); - console.log(`Cache size after reseed: ${sinks.length + TAIL_SIZE}`); - console.log(`\nGenerating...\n`); - process.stdout.write(prompt); - } - - const samplingParams = { temperature: 0.8, topP: 0.9 }; - let branch = Branch.create(ctx, 0, samplingParams); - await branch.prefill(promptTokens); - - // Manual PPL tracking (persists across branch reseeds) - let nllSum = 0, nllCount = 0; - let reseedCount = 0; - - for (let t = 0; t < TARGET_TOKENS; t++) { - // NOTE: Token-level repeat penalties are NOT used for long-form generation. - // llama.cpp's penalty system penalizes individual tokens (not sequences), - // which degrades prose quality over 100+ tokens as common words accumulate - // in the penalty buffer. For sequence-level deduplication, use N-gram - // tracking with logit steering (TTA pattern) instead. 
- const { token, isStop } = await branch.produce(); - if (isStop) { - if (!jsonlMode) { - console.log('\n[EOS token reached]'); - } - emit('eos', { tokenIndex: t }); - break; - } - - // Track surprisal from the logits used by produce() - const surprisal = branch.modelSurprisal(token, 'nats'); - nllSum += Math.max(0, surprisal); - nllCount++; - - // Output token - const text = ctx.tokenToText(token); - if (!jsonlMode) { - process.stdout.write(text); - } - emit('token', { index: t, token, text, surprisal }); - - // Store token and commit (decode + capture new logits) - allTokens.push(token); - await branch.commit(token); - - // Cache full? Reseed at boundary - if (branch.position >= nCtx) { - const tail = allTokens.slice(-TAIL_SIZE); - - // Destroy current branch, clear KV, create fresh branch with re-prefill - await branch.prune(); - await ctx.kvCacheClear(); - branch = Branch.create(ctx, 0, samplingParams); - await branch.prefill([...sinks, ...tail]); - - reseedCount++; - - const ppl = nllCount > 0 ? Math.exp(nllSum / nllCount) : 1; - emit('reseed', { count: reseedCount, tokenIndex: t + 1, ppl }); - - if (!jsonlMode) { - console.log(`\n [Reseed ${reseedCount} at token ${t + 1}/${TARGET_TOKENS} | PPL: ${ppl.toFixed(2)}]`); - } - } - - // Progress indicator every 1000 tokens - if ((t + 1) % 1000 === 0 && reseedCount === 0 && !jsonlMode) { - console.log(`\n [${t + 1}/${TARGET_TOKENS} tokens]`); - } - } - - const finalPpl = nllCount > 0 ? 
Math.exp(nllSum / nllCount) : 1; - await branch.prune(); - - const generatedTokens = allTokens.length - promptTokens.length; - emit('complete', { generatedTokens, reseeds: reseedCount, finalPpl }); - - if (!jsonlMode) { - console.log('\n\n' + '='.repeat(50)); - console.log(`Generated: ${generatedTokens} tokens`); - console.log(`Reseeds: ${reseedCount}`); - console.log(`Final perplexity: ${finalPpl.toFixed(2)}`); - console.log('='.repeat(50)); - } - - ctx.dispose(); -} - -main().catch((err) => { - console.error('Error:', err.message); - process.exit(1); -}); diff --git a/lib/Branch.js b/lib/Branch.js deleted file mode 100644 index b7ee396..0000000 --- a/lib/Branch.js +++ /dev/null @@ -1,471 +0,0 @@ -/** - * Branch - Forkable inference handle for covalent generation - * - * A Branch owns everything needed for independent generation: a KV cache - * sequence, sampler chain, logits snapshot, and perplexity tracker. - * - * Forking is cheap β€” the KV prefix is shared in memory (metadata-only operation under unified KV β€” - * no KV tensor buffers are copied), so sibling branches read from the same physical KV entries. - * Only tokens decoded after the fork point are exclusive to each branch. - * This is the covalent property: branches share a bond (common prefix) - * while diverging independently. - * - * Branches form trees, not just flat lists. Fork from root for best-of-N, - * fork from children for tree search/beam search, fork from a draft for speculative - * decoding. - * - * The produce/commit protocol separates sampling from state advancement: - * produce() samples without writing to KV, letting you inspect the result - * before deciding to commit(). This two-phase split is what makes speculative - * verification and tree search natural. 
- * - * @example Best-of-N with perplexity selection - * ```js - * const root = Branch.create(ctx, tokens.length, { temperature: 0.8 }); - * await root.prefill(tokens); - * - * const results = []; - * for (let i = 0; i < 5; i++) { - * const branch = await root.fork(); - * branch.reseedSampler(1000 + i); - * const tokens = []; - * for await (const { token } of branch) tokens.push(token); - * results.push({ branch, tokens, ppl: branch.perplexity }); - * } - * - * const best = results.reduce((a, b) => a.ppl < b.ppl ? a : b); - * for (const r of results) { if (r !== best) await r.branch.prune(); } - * ``` - */ - -class Branch { - /** - * @param {SessionContext} ctx - * @param {number} handle - */ - constructor(ctx, handle) { - this._ctx = ctx; - this._handle = handle; - this._disposed = false; - } - - /** - * Create a root branch at the given position - * - * The branch takes ownership of the sequence and creates its own sampler - * chain from the provided params. Call prefill() to decode prompt tokens - * and capture the logit distribution before forking. - * - * @param {SessionContext} ctx - SessionContext to create branch on - * @param {number} position - Starting position (typically prompt token count) - * @param {SamplingParams} [params] - Sampling parameters (temperature, topP, etc.) - * @param {number} [nBatch] - Per-branch batch size override (defaults to context nBatch). - * Controls chunk size for prefill(). Has no effect on - * single-token commit() which uses a zero-allocation fast path. Useful for tuning - * memory/throughput tradeoff on bulk token decode β€” e.g. smaller nBatch for cheap - * exploration branches, larger for the trunk. - * @param {string} [grammar] - GBNF grammar string for constrained generation. - * When provided, sample() returns only grammar-valid tokens. The grammar state - * is cloned on fork(), so sibling branches can diverge independently. 
- * @returns {Branch} New Branch instance - */ - static create(ctx, position, params, nBatch, grammar) { - const handle = ctx._branchCreate(position, params, nBatch, grammar); - return new Branch(ctx, handle); - } - - /** - * Fork this branch to a new sequence - * - * The child shares the parent's KV prefix in memory (metadata-only under unified KV, no KV buffer copy). - * Logits, sampler state, and perplexity tracker are cloned so the child - * can diverge independently. Fork from any branch β€” root or intermediate β€” - * to build arbitrarily deep trees. - * - * Call reseedSampler() on each child for stochastic diversity. - * - * @returns {Promise} New forked Branch - */ - async fork() { - this._ensureNotDisposed(); - const newHandle = this._ctx._branchFork(this._handle); - return new Branch(this._ctx, newHandle); - } - - /** - * Get a copy of this branch's captured logits snapshot - * - * Returns n_vocab floats β€” the raw logit distribution from the last - * prefill() or commit() call. Use for distributional analysis - * (KL divergence, entropy, top-k overlap) without crossing the - * sampling chain. - * - * @returns {Float32Array} Copy of the logits snapshot (n_vocab elements) - * @throws {Error} If no logits have been captured yet - */ - getLogits() { - this._ensureNotDisposed(); - return this._ctx._branchGetLogits(this._handle); - } - - /** - * Bulk-decode tokens into the branch's KV cache and capture logits - * - * Feeds an array of tokens through the model. tokens.length is the total - * count to process; the branch's nBatch (set at Branch.create) controls - * how many are sent per llama_decode call. For example, 500 tokens with - * nBatch=64 makes 8 llama_decode calls (7x64 + 1x52). With nBatch=512 - * it makes 1. - * - * Advances position by tokens.length and stores the final logits into - * the branch's internal snapshot. The next produce()/sample() call reads - * from that snapshot β€” logits never cross the JS boundary. 
- * - * Does NOT accept tokens into the sampler's repeat-penalty window β€” use - * this for external tokens (user input between turns), not model-generated - * tokens. For model output, use commit() which does accept + decode. - * - * The primary way to feed tokens into a branch's KV cache. - * - * @param {number[]} tokens - Token IDs to decode - * @returns {Promise} - */ - async prefill(tokens) { - this._ensureNotDisposed(); - await this._ctx._branchPrefill(this._handle, tokens); - } - - /** - * Sample next token from branch's logits snapshot - * - * Applies the branch's full sampler chain (top-k, top-p, temperature, - * repeat/presence penalties) to the captured logits. - * - * @returns {number} Sampled token ID - */ - sample() { - this._ensureNotDisposed(); - return this._ctx._branchSample(this._handle); - } - - /** - * Record token in the sampler's repeat/presence penalty window - * - * @param {number} token - Token to accept - */ - accept(token) { - this._ensureNotDisposed(); - this._ctx._branchAccept(this._handle, token); - } - - /** - * Discard this branch entirely β€” remove its KV entries and free the handle - * - * Use for losers: branches whose generation you want to erase completely. - * Only removes KV entries divergent from the shared prefix; sibling - * branches are unaffected. - * - * @returns {Promise} - */ - async prune() { - if (this._disposed) return; - this._ctx._branchPrune(this._handle); - this._disposed = true; - } - - /** - * Discard this branch and all its descendants β€” CASCADE delete - * - * Iterative post-order traversal: prunes children first, then this branch. - * Use when you want to tear down an entire subtree (e.g. abandoned search path). 
- * - * @returns {Promise} - */ - async pruneSubtree() { - if (this._disposed) return; - this._ctx._branchPruneSubtree(this._handle); - this._disposed = true; - } - - /** - * Reseed the sampler's PRNG for diversity after fork() - * - * CRITICAL for parallel generation: Without reseeding, all forked branches - * produce identical outputs because they share the same PRNG state. - * - * Only affects stochastic samplers (temperature > 0). Greedy samplers are unchanged. - * - * @param {number} seed - New seed for the PRNG - * - * @example - * ```js - * const root = Branch.create(ctx, pos, { temperature: 0.9 }); - * await root.prefill(promptTokens); - * - * // Fork and reseed for diversity - * const branches = []; - * for (let i = 0; i < 5; i++) { - * const branch = await root.fork(); - * branch.reseedSampler(1000 + i); // Each branch gets unique seed - * branches.push(branch); - * } - * ``` - */ - reseedSampler(seed) { - this._ensureNotDisposed(); - this._ctx._branchSamplerChainReseed(this._handle, seed); - } - - /** - * Apply dynamic logit adjustments for this branch only - * - * Unlike logit_bias (which is cloned on fork), steer biases are NOT inherited - * by child branches. Each branch manages its own steer state independently. - * - * Use cases: - * - tsampler: Block tokens that would create repeated N-grams (per-path history) - * - Tree search: Block already-explored actions at this node (not inherited by children) - * - * Applied during sample() in order: Grammar -> Logit Bias -> Steer -> Sampler Chain - * - * @param {Array<{token: number, bias: number}>} biases - Token adjustments. - * Use -Infinity to block a token, positive values to boost. 
- * - * @example Block tokens for N-gram deduplication - * ```js - * // Client computes blocked tokens based on generated text - * const blocked = computeNgramBlocks(generatedText); - * branch.steer(blocked.map(t => ({ token: t, bias: -Infinity }))); - * - * const { token } = await branch.produce(); // Blocked tokens won't be sampled - * await branch.commit(token); - * - * branch.clearSteer(); // Reset for next iteration - * ``` - */ - steer(biases) { - this._ensureNotDisposed(); - this._ctx._branchSteer(this._handle, biases); - } - - /** - * Clear all steer biases from this branch - * - * Removes any dynamic logit adjustments set by steer(). - */ - clearSteer() { - this._ensureNotDisposed(); - this._ctx._branchClearSteer(this._handle); - } - - /** - * Replace the sampler chain with new parameters (memoized) - * - * If the new params match the current chain's params, this is a no-op. - * Otherwise the old chain is freed and a new one is created. - * - * @param {SamplingParams} params - New sampling parameters - */ - setSamplerParams(params) { - this._ensureNotDisposed(); - this._ctx._branchSetSamplerParams(this._handle, params); - } - - /** - * Replace or remove the grammar constraint - * - * Pass a GBNF grammar string to constrain generation, or empty string / null - * to remove the constraint. The grammar state is cloned on fork(). - * - * @param {string} [grammarStr] - GBNF grammar string, or empty/null to remove - */ - setGrammar(grammarStr) { - this._ensureNotDisposed(); - this._ctx._branchSetGrammar(this._handle, grammarStr || ''); - } - - /** - * Sample the next token without advancing state (async) - * - * No KV write, no position update. Inspect the result before deciding - * to commit() β€” this separation is what enables speculative verification - * and conditional branching. - * - * Async contract: local branches resolve immediately; cloud branches - * may perform an HTTP round-trip. 
Use produceSync() when you know the - * branch is local and want zero-overhead sampling. - * - * @returns {Promise<{ token: number, text: string, isStop: boolean }>} - */ - async produce() { - return this.produceSync(); - } - - /** - * Sample the next token without advancing state (sync) - * - * Same as produce() but synchronous. Use when you know the branch is - * local and want to avoid the microtick overhead of a promise. - * - * @returns {{ token: number, text: string, isStop: boolean }} - */ - produceSync() { - this._ensureNotDisposed(); - const token = this.sample(); - return { - token, - text: this._ctx.tokenToText(token), - isStop: this._ctx.isStopToken(token), - }; - } - - /** - * Accept and decode β€” update branch state, then write token to KV - * - * Accepts the token into the sampler penalty window (for correct PPL - * measurement), then decodes (writing to KV cache) and captures the - * resulting logits for the next produce() call. Accept-first ordering - * with rollback: if decode throws, sampler/grammar/metrics are restored - * from clones taken before the accept. 
- * - * @param {number} token - Token to commit (from produce()) - * @returns {Promise} - */ - async commit(token) { - this._ensureNotDisposed(); - await this._ctx._storeCommit([this._handle], [token]); - } - - // ===== METRICS ===== - - /** - * Compute entropy of the branch's logits distribution - * - * @param {'nats'|'bits'} [base='nats'] - * @returns {number} - */ - modelEntropy(base = 'nats') { - this._ensureNotDisposed(); - return this._ctx._branchModelEntropy(this._handle, base); - } - - /** - * Compute surprisal for a specific token from the branch's logits - * - * @param {number} token - * @param {'nats'|'bits'} [base='nats'] - * @returns {number} - */ - modelSurprisal(token, base = 'nats') { - this._ensureNotDisposed(); - return this._ctx._branchModelSurprisal(this._handle, token, base); - } - - /** - * Sampling-level perplexity (from filtered distribution) - * - * @returns {number} - */ - get samplingPerplexity() { - this._ensureNotDisposed(); - return this._ctx._branchGetSamplingPerplexity(this._handle); - } - - /** - * Set static logit biases on this branch (cloned on fork) - * - * @param {Array<{token: number, bias: number}>} biases - */ - setLogitBias(biases) { - this._ensureNotDisposed(); - this._ctx._branchSetLogitBias(this._handle, biases); - } - - /** - * Clear all static logit biases from this branch - */ - clearLogitBias() { - this._ensureNotDisposed(); - this._ctx._branchClearLogitBias(this._handle); - } - - // ===== ACCESSORS ===== - - /** @returns {number} Branch's current position (number of tokens decoded) */ - get position() { - this._ensureNotDisposed(); - return this._ctx._branchGetPosition(this._handle); - } - - /** @returns {number} Branch's perplexity (exp of mean surprisal) */ - get perplexity() { - this._ensureNotDisposed(); - return this._ctx._branchGetPerplexity(this._handle); - } - - /** @returns {number} Internal handle (for debugging) */ - get handle() { - return this._handle; - } - - /** @returns {boolean} Whether this branch 
has been disposed */ - get disposed() { - return this._disposed; - } - - /** @returns {number|null} Parent branch handle, or null if root */ - get parent() { - this._ensureNotDisposed(); - const h = this._ctx._branchParent(this._handle); - return h === 0 ? null : h; - } - - /** @returns {number[]} Child branch handles */ - get children() { - this._ensureNotDisposed(); - return this._ctx._branchChildren(this._handle); - } - - /** @returns {boolean} True if this branch has no children */ - get isLeaf() { - this._ensureNotDisposed(); - return this._ctx._branchIsLeaf(this._handle); - } - - /** @returns {boolean} True if this branch holds a KV lease */ - get isActive() { - this._ensureNotDisposed(); - return this._ctx._branchIsActive(this._handle); - } - - // ===== ASYNC ITERATION ===== - - /** - * Async iterator β€” generate tokens until EOG - * - * Commit-before-yield: every yielded token is already written to KV and - * accepted into the sampler. Breaking out of the loop is clean β€” no - * orphaned uncommitted tokens, perplexity reflects all yielded tokens. - * - * For inspect-before-commit (speculative decoding, tree search), use - * the produce()/commit() protocol directly. - */ - async *[Symbol.asyncIterator]() { - while (!this._disposed) { - const { token, text, isStop } = await this.produce(); - if (isStop) return; - await this.commit(token); - yield { token, text }; - } - } - - // ===== INTERNAL ===== - - _ensureNotDisposed() { - if (this._disposed) { - throw new Error('Branch has been disposed'); - } - } -} - -module.exports = { Branch }; diff --git a/lib/BranchStore.js b/lib/BranchStore.js deleted file mode 100644 index 8b14030..0000000 --- a/lib/BranchStore.js +++ /dev/null @@ -1,43 +0,0 @@ -/** - * BranchStore - Batched multi-branch decode operations - * - * See index.d.ts for full API documentation. 
- */ -class BranchStore { - constructor(ctx) { - this._ctx = ctx; - } - - // entries: [branch, token][] β€” binding is structural, not positional - async commit(entries) { - const handles = [], tokens = []; - for (const [branch, token] of entries) { - if (branch.disposed) throw new Error('BranchStore.commit: branch is disposed'); - handles.push(branch.handle); - tokens.push(token); - } - await this._ctx._storeCommit(handles, tokens); - } - - // entries: [branch, tokens[]][] β€” binding is structural, not positional - async prefill(entries) { - const handles = [], tokenArrays = []; - for (const [branch, tokens] of entries) { - if (branch.disposed) throw new Error('BranchStore.prefill: branch is disposed'); - handles.push(branch.handle); - tokenArrays.push(tokens); - } - await this._ctx._storePrefill(handles, tokenArrays); - } - - async retainOnly(winner) { - if (winner.disposed) throw new Error('BranchStore.retainOnly: winner is disposed'); - this._ctx._storeRetainOnly(winner.handle); - } - - get available() { - return this._ctx._storeAvailable(); - } -} - -module.exports = { BranchStore }; diff --git a/lib/index.js b/lib/index.js deleted file mode 100644 index e7719bc..0000000 --- a/lib/index.js +++ /dev/null @@ -1,259 +0,0 @@ -/** - * liblloyal-node - Thin N-API wrapper over liblloyal - * - * Exposes raw llama.cpp inference primitives for Node.js. - * Primary use case: Integration testing for tsampler. 
- * - * @example - * ```js - * const { createContext } = require('@lloyal-labs/lloyal.node'); - * - * const ctx = await createContext({ - * modelPath: './model.gguf', - * nCtx: 2048, - * nThreads: 4 - * }); - * - * // Tokenize - * const tokens = await ctx.tokenize("Hello world"); - * - * // Generate via Branch API - * const branch = Branch.create(ctx, 0, { temperature: 0.7 }); - * await branch.prefill(tokens); - * for await (const { text } of branch) { - * process.stdout.write(text); - * } - * await branch.prune(); - * - * // Cleanup - * ctx.dispose(); - * ``` - * - * @example GPU variant selection - * ```js - * // Option 1: Environment variable (affects all contexts) - * // Set LLOYAL_GPU=cuda before running - * - * // Option 2: Per-context selection (recommended) - * const ctx = await createContext( - * { modelPath: './model.gguf', nCtx: 4096 }, - * { gpuVariant: 'cuda' } // Falls back to CPU if CUDA unavailable - * ); - * ``` - */ - -/** - * Platform package naming: @lloyal-labs/lloyal.node-{platform}-{arch}[-{gpu}] - * @param {string} [variant] - GPU variant: 'cuda', 'vulkan', or undefined for CPU - * @returns {string} Platform package name - */ -const getPlatformPackageName = (variant) => { - const platform = process.platform; - const arch = process.arch; - // cpu/metal/default = no suffix, cuda/vulkan = suffix - const noSuffix = !variant || variant === 'default' || variant === 'cpu' || variant === 'metal'; - const suffix = noSuffix ? '' : `-${variant}`; - return `@lloyal-labs/lloyal.node-${platform}-${arch}${suffix}`; -}; - -/** - * Try to load a platform package, return null on failure. - * Failures include: package not installed, missing GPU runtime libs (dlopen fails), - * or module doesn't export expected interface. 
- * @param {string} packageName - Package name to load - * @param {boolean} [verbose=false] - Log failure reasons - * @returns {object|null} The native binary module or null - */ -const tryLoadPackage = (packageName, verbose = false) => { - try { - const mod = require(packageName); - // Validate it's actually a native module with expected exports - if (mod && typeof mod.createContext === 'function') { - return mod; - } - if (verbose) { - console.warn(`[lloyal.node] ${packageName} loaded but missing createContext export`); - } - return null; - } catch (e) { - if (verbose) { - console.warn(`[lloyal.node] Failed to load ${packageName}: ${e.message}`); - } - return null; - } -}; - -/** - * Load the native binary with automatic fallback. - * - * **Loading Priority:** - * - * When `LLOYAL_LOCAL=1`: - * - Uses local build exclusively (`build/Release/lloyal.node`) - * - Throws error if not found (no fallback) - * - * Otherwise: - * 1. Requested GPU variant package (if `variant` param or `LLOYAL_GPU` env var specified) - * 2. Local build (`build/Release/lloyal.node`) β€” always fresher during development - * 3. Default platform package (`@lloyal-labs/lloyal.node-{platform}-{arch}`) - * - * **Environment Variables:** - * - `LLOYAL_LOCAL=1` β€” Use local build exclusively (`build/Release/lloyal.node`). - * Throws an error if local build not found. Use during development to test - * local changes without uninstalling npm packages. - * - `LLOYAL_GPU` β€” GPU variant to load: `'cuda'` or `'vulkan'`. Equivalent to - * passing the `variant` parameter. - * - `LLOYAL_NO_FALLBACK=1` β€” Disable fallback when GPU variant fails. Throws an - * error instead of silently falling back to CPU. Use in CI to ensure the - * specific GPU package loads correctly and catch missing runtime libraries. - * - * @param {string} [variant] - GPU variant: `'cuda'`, `'vulkan'`, or `undefined` for CPU. - * Overrides `LLOYAL_GPU` env var if specified. 
- * @returns {object} The native binary module with `createContext` and `SessionContext` - * @throws {Error} If no binary can be loaded for the current platform - * - * @example Development testing with local build - * ```bash - * # Build locally, then test without uninstalling npm packages - * npm run build - * LLOYAL_LOCAL=1 node my-script.js - * ``` - * - * @example GPU variant selection - * ```bash - * # Via environment variable - * LLOYAL_GPU=cuda node my-script.js - * - * # Or programmatically - * const binary = loadBinary('cuda'); - * ``` - * - * @example CI: Ensure GPU package loads (no silent fallback) - * ```bash - * LLOYAL_GPU=cuda LLOYAL_NO_FALLBACK=1 npm test - * ``` - */ -const loadBinary = (variant) => { - // Use env var if no variant specified - variant = variant ?? process.env.LLOYAL_GPU; - // LLOYAL_NO_FALLBACK=1 disables fallback (for CI testing specific packages) - const noFallback = process.env.LLOYAL_NO_FALLBACK === '1'; - // LLOYAL_LOCAL=1 forces local build first (development) - const useLocal = process.env.LLOYAL_LOCAL === '1'; - - // 0. Use local build if explicitly requested (no fallback) - if (useLocal) { - try { - return require('../build/Release/lloyal.node'); - } catch (e) { - throw new Error( - '[lloyal.node] LLOYAL_LOCAL=1 but local build not found. ' + - 'Run `npm run build` first.' - ); - } - } - - // 1. Try requested variant (if specified) - if (variant && variant !== 'default') { - const pkgName = getPlatformPackageName(variant); - const binary = tryLoadPackage(pkgName, true); // verbose=true to see errors - if (binary) return binary; - - if (noFallback) { - throw new Error( - `[lloyal.node] GPU variant "${variant}" failed to load. ` + - `Package: ${pkgName}. Check that runtime libraries are available.` - ); - } - console.warn(`[lloyal.node] GPU variant "${variant}" unavailable, falling back to CPU`); - } - - // 2. 
Try local build (always fresher than installed packages during development) - try { - return require('../build/Release/lloyal.node'); - } catch (e) { - // ignore β€” no local build - } - - // 3. Try default platform package (CPU) - const defaultPkg = getPlatformPackageName(); - const binary = tryLoadPackage(defaultPkg, true); // verbose=true - if (binary) return binary; - - throw new Error( - `No lloyal.node binary found for ${process.platform}-${process.arch}. ` + - `Tried: ${variant ? getPlatformPackageName(variant) + ', ' : ''}${defaultPkg}` - ); -}; - -// Default binary (loaded lazily on first use) -let _binary = null; -const getBinary = () => { - if (!_binary) { - _binary = loadBinary(process.env.LLOYAL_GPU); - } - return _binary; -}; - -const { Branch } = require('./Branch'); -const { BranchStore } = require('./BranchStore'); - -module.exports = { - /** - * Branch class for parallel generation - * @see Branch.create() - */ - Branch, - /** - * BranchStore class for batched multi-branch decode - * @see BranchStore - */ - BranchStore, - /** - * Create a new inference context - * - * @param {ContextOptions} options - Context configuration - * @param {LoadOptions} [loadOptions] - Binary loading options - * @returns {Promise} The inference context - * - * @example - * ```js - * // Basic usage - * const ctx = await createContext({ - * modelPath: './model.gguf', - * nCtx: 2048, - * nThreads: 4 - * }); - * - * // With GPU variant - * const ctx = await createContext( - * { modelPath: './model.gguf' }, - * { gpuVariant: 'cuda' } - * ); - * ``` - */ - createContext: async (options, loadOptions) => { - const variant = loadOptions?.gpuVariant || process.env.LLOYAL_GPU; - const binary = variant ? loadBinary(variant) : getBinary(); - return binary.createContext(options); - }, - - /** - * Load binary for a specific GPU variant. - * Useful for checking variant availability before creating context. 
- * - * @param {string} [variant] - 'cuda', 'vulkan', or undefined for CPU - * @returns {object} Native binary module - * @throws {Error} If no binary available for platform - * - * @example - * ```js - * // Load default (CPU) binary - * const binary = loadBinary(); - * - * // Load CUDA binary (falls back to CPU if unavailable) - * const binary = loadBinary('cuda'); - * ``` - */ - loadBinary, -}; diff --git a/liblloyal b/liblloyal index 388e255..4082e2e 160000 --- a/liblloyal +++ b/liblloyal @@ -1 +1 @@ -Subproject commit 388e255adad2eda1a4e18c8e25345404fee39573 +Subproject commit 4082e2eab6618b800753d462e8ad3773541fb5f3 diff --git a/package-lock.json b/package-lock.json index 09cdb1a..e618748 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,13 +10,17 @@ "license": "Apache-2.0", "dependencies": { "@lloyal-labs/tsampler": "^0.2.0", + "effection": "^4.0.2", "node-addon-api": "^8.5.0" }, "devDependencies": { + "@types/node": "^25.3.0", "cmake-js": "^8.0.0", "glob": "^11.0.0", + "tsx": "^4.21.0", "typedoc": "^0.28.16", - "typedoc-rhineai-theme": "^1.2.0" + "typedoc-rhineai-theme": "^1.2.0", + "typescript": "^5.9.3" }, "engines": { "node": ">=22.0.0" @@ -37,6 +41,448 @@ "@lloyal-labs/lloyal.node-win32-x64-vulkan": "1.6.0" } }, + "node_modules/@esbuild/aix-ppc64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.27.3.tgz", + "integrity": "sha512-9fJMTNFTWZMh5qwrBItuziu834eOCUcEqymSH7pY+zoMVEZg3gcPuBNxH1EvfVYe9h0x/Ptw8KBzv7qxb7l8dg==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.27.3.tgz", + "integrity": "sha512-i5D1hPY7GIQmXlXhs2w8AWHhenb00+GxjxRncS2ZM7YNVGNfaMxgzSGuO8o8SJzRc/oZwU2bcScvVERk03QhzA==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + 
"optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.27.3.tgz", + "integrity": "sha512-YdghPYUmj/FX2SYKJ0OZxf+iaKgMsKHVPF1MAq/P8WirnSpCStzKJFjOjzsW0QQ7oIAiccHdcqjbHmJxRb/dmg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.27.3.tgz", + "integrity": "sha512-IN/0BNTkHtk8lkOM8JWAYFg4ORxBkZQf9zXiEOfERX/CzxW3Vg1ewAhU7QSWQpVIzTW+b8Xy+lGzdYXV6UZObQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-arm64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.27.3.tgz", + "integrity": "sha512-Re491k7ByTVRy0t3EKWajdLIr0gz2kKKfzafkth4Q8A5n1xTHrkqZgLLjFEHVD+AXdUGgQMq+Godfq45mGpCKg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.27.3.tgz", + "integrity": "sha512-vHk/hA7/1AckjGzRqi6wbo+jaShzRowYip6rt6q7VYEDX4LEy1pZfDpdxCBnGtl+A5zq8iXDcyuxwtv3hNtHFg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.27.3.tgz", + "integrity": 
"sha512-ipTYM2fjt3kQAYOvo6vcxJx3nBYAzPjgTCk7QEgZG8AUO3ydUhvelmhrbOheMnGOlaSFUoHXB6un+A7q4ygY9w==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.27.3.tgz", + "integrity": "sha512-dDk0X87T7mI6U3K9VjWtHOXqwAMJBNN2r7bejDsc+j03SEjtD9HrOl8gVFByeM0aJksoUuUVU9TBaZa2rgj0oA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.27.3.tgz", + "integrity": "sha512-s6nPv2QkSupJwLYyfS+gwdirm0ukyTFNl3KTgZEAiJDd+iHZcbTPPcWCcRYH+WlNbwChgH2QkE9NSlNrMT8Gfw==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.27.3.tgz", + "integrity": "sha512-sZOuFz/xWnZ4KH3YfFrKCf1WyPZHakVzTiqji3WDc0BCl2kBwiJLCXpzLzUBLgmp4veFZdvN5ChW4Eq/8Fc2Fg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.27.3.tgz", + "integrity": "sha512-yGlQYjdxtLdh0a3jHjuwOrxQjOZYD/C9PfdbgJJF3TIZWnm/tMd/RcNiLngiu4iwcBAOezdnSLAwQDPqTmtTYg==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.27.3", + "resolved": 
"https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.27.3.tgz", + "integrity": "sha512-WO60Sn8ly3gtzhyjATDgieJNet/KqsDlX5nRC5Y3oTFcS1l0KWba+SEa9Ja1GfDqSF1z6hif/SkpQJbL63cgOA==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.27.3.tgz", + "integrity": "sha512-APsymYA6sGcZ4pD6k+UxbDjOFSvPWyZhjaiPyl/f79xKxwTnrn5QUnXR5prvetuaSMsb4jgeHewIDCIWljrSxw==", + "cpu": [ + "mips64el" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.27.3.tgz", + "integrity": "sha512-eizBnTeBefojtDb9nSh4vvVQ3V9Qf9Df01PfawPcRzJH4gFSgrObw+LveUyDoKU3kxi5+9RJTCWlj4FjYXVPEA==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.27.3.tgz", + "integrity": "sha512-3Emwh0r5wmfm3ssTWRQSyVhbOHvqegUDRd0WhmXKX2mkHJe1SFCMJhagUleMq+Uci34wLSipf8Lagt4LlpRFWQ==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.27.3.tgz", + "integrity": "sha512-pBHUx9LzXWBc7MFIEEL0yD/ZVtNgLytvx60gES28GcWMqil8ElCYR4kvbV2BDqsHOvVDRrOxGySBM9Fcv744hw==", + "cpu": [ + "s390x" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + 
"node": ">=18" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.27.3.tgz", + "integrity": "sha512-Czi8yzXUWIQYAtL/2y6vogER8pvcsOsk5cpwL4Gk5nJqH5UZiVByIY8Eorm5R13gq+DQKYg0+JyQoytLQas4dA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-arm64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.27.3.tgz", + "integrity": "sha512-sDpk0RgmTCR/5HguIZa9n9u+HVKf40fbEUt+iTzSnCaGvY9kFP0YKBWZtJaraonFnqef5SlJ8/TiPAxzyS+UoA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.27.3.tgz", + "integrity": "sha512-P14lFKJl/DdaE00LItAukUdZO5iqNH7+PjoBm+fLQjtxfcfFE20Xf5CrLsmZdq5LFFZzb5JMZ9grUwvtVYzjiA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-arm64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.27.3.tgz", + "integrity": "sha512-AIcMP77AvirGbRl/UZFTq5hjXK+2wC7qFRGoHSDrZ5v5b8DK/GYpXW3CPRL53NkvDqb9D+alBiC/dV0Fb7eJcw==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.27.3.tgz", + "integrity": "sha512-DnW2sRrBzA+YnE70LKqnM3P+z8vehfJWHXECbwBmH/CU51z6FiqTQTHFenPlHmo3a8UgpLyH3PT+87OViOh1AQ==", + "cpu": [ + "x64" + ], + "dev": true, + 
"license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openharmony-arm64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/openharmony-arm64/-/openharmony-arm64-0.27.3.tgz", + "integrity": "sha512-NinAEgr/etERPTsZJ7aEZQvvg/A6IsZG/LgZy+81wON2huV7SrK3e63dU0XhyZP4RKGyTm7aOgmQk0bGp0fy2g==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.27.3.tgz", + "integrity": "sha512-PanZ+nEz+eWoBJ8/f8HKxTTD172SKwdXebZ0ndd953gt1HRBbhMsaNqjTyYLGLPdoWHy4zLU7bDVJztF5f3BHA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.27.3.tgz", + "integrity": "sha512-B2t59lWWYrbRDw/tjiWOuzSsFh1Y/E95ofKz7rIVYSQkUYBjfSgf6oeYPNWHToFRr2zx52JKApIcAS/D5TUBnA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.27.3.tgz", + "integrity": "sha512-QLKSFeXNS8+tHW7tZpMtjlNb7HKau0QDpwm49u0vUp9y1WOF+PEzkU84y9GqYaAVW8aH8f3GcBck26jh54cX4Q==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.27.3.tgz", + "integrity": 
"sha512-4uJGhsxuptu3OcpVAzli+/gWusVGwZZHTlS63hh++ehExkVT8SgiEf7/uC/PclrPPkLhZqGgCTjd0VWLo6xMqA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, "node_modules/@gerrit0/mini-shiki": { "version": "3.22.0", "resolved": "https://registry.npmjs.org/@gerrit0/mini-shiki/-/mini-shiki-3.22.0.tgz", @@ -83,43 +529,173 @@ } }, "node_modules/@lloyal-labs/lloyal.node-darwin-arm64": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-darwin-arm64/-/lloyal.node-darwin-arm64-1.6.0.tgz", + "integrity": "sha512-T8Xt2ZSyY7yLQQgVLQZhR4Wb61LuEEnZdSF7+C0wu9BbB/DMyum2Ix6lDsufGf/oXOLiSrwVbNUvIplfE6u7YQ==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "darwin" + ] }, "node_modules/@lloyal-labs/lloyal.node-darwin-x64": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-darwin-x64/-/lloyal.node-darwin-x64-1.6.0.tgz", + "integrity": "sha512-AlHhmFFoU8J1BNsqGc0leok0R+Ot4jzm3d1O/atPAi8EMFmFoSI6/af9iJcKy1//+goiPGlA1sl3BH+d7o/syw==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "darwin" + ] }, "node_modules/@lloyal-labs/lloyal.node-linux-arm64": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-linux-arm64/-/lloyal.node-linux-arm64-1.6.0.tgz", + "integrity": "sha512-cQJpiy061atIRRYErbnP6UjFo6owxa2dGEIGQ4u/DcwDaGX/cQMnT/PQCxHMcjGNVb8M6eDwL0Qc07SwCHikMg==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ] }, "node_modules/@lloyal-labs/lloyal.node-linux-arm64-cuda": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-linux-arm64-cuda/-/lloyal.node-linux-arm64-cuda-1.6.0.tgz", + "integrity": 
"sha512-K0qIdYBOWctBOwrXmTmcgoSRhDLWdBsCZTMLiYTIp8wVDF6SEBmVBNG4DNTqCaM9482RiKoJcKLnNwleCCkiLw==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ] }, "node_modules/@lloyal-labs/lloyal.node-linux-arm64-vulkan": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-linux-arm64-vulkan/-/lloyal.node-linux-arm64-vulkan-1.6.0.tgz", + "integrity": "sha512-7+3f8gUa3e8j/DatoZqrllcUkJgTHcjWzIOlP1t5443ed9pKlZWonMx/1hPyF51rIMvozodR7QVCtY12hYSOfg==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ] }, "node_modules/@lloyal-labs/lloyal.node-linux-x64": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-linux-x64/-/lloyal.node-linux-x64-1.6.0.tgz", + "integrity": "sha512-fr7h/rcpDCehveoJMskohq3mGT4moU5NqFKlaXOZEADC1PXNjjmTktcznU4SRAFGYjBhlAF3uDr4JM7CQBNY8Q==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ] }, "node_modules/@lloyal-labs/lloyal.node-linux-x64-cuda": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-linux-x64-cuda/-/lloyal.node-linux-x64-cuda-1.6.0.tgz", + "integrity": "sha512-ddD4NtOGUzSFSYfBZLs42Y6mEplsMrQr6UXYbz1HBHUTQ867lJodF+nxsos03lcIAuAWxRCLFAUz2Nk8KgRrQA==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ] }, "node_modules/@lloyal-labs/lloyal.node-linux-x64-vulkan": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-linux-x64-vulkan/-/lloyal.node-linux-x64-vulkan-1.6.0.tgz", + "integrity": "sha512-F0t83fJnNJl9LZB+5kwXhvqLV1ZtXrFLNp8c/4JySA1lDvnZuGt8AyGbfzqMO938X5y86yiHafNsbGQgdT7OpA==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ] }, "node_modules/@lloyal-labs/lloyal.node-win32-arm64": { - 
"optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-win32-arm64/-/lloyal.node-win32-arm64-1.6.0.tgz", + "integrity": "sha512-BHrSLnMlYnJ1YloRziLL2VqiMswXPtLFaEKmsq4EyVRBYTcs8rGxxobptM3Kf5FrDnjL5PYSxPnwpSy90tQk0Q==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "win32" + ] }, "node_modules/@lloyal-labs/lloyal.node-win32-arm64-vulkan": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-win32-arm64-vulkan/-/lloyal.node-win32-arm64-vulkan-1.6.0.tgz", + "integrity": "sha512-o4gBVTXCYLF/gPWfsoTXwSL/uyu/b54sk+nMDyAue5TwBPos1vfYPjm02wNIEv9K17NbuG8Cz05cDU+xSn34OQ==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "win32" + ] }, "node_modules/@lloyal-labs/lloyal.node-win32-x64": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-win32-x64/-/lloyal.node-win32-x64-1.6.0.tgz", + "integrity": "sha512-PifF8Iy1IOfJIPa32ppI2ODaKk0x0Cmvkgg2nExLuG+DhxYc7KeTAJQ5es605ll9hTUOXrQVIPSZORBHYzUEbg==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "win32" + ] }, "node_modules/@lloyal-labs/lloyal.node-win32-x64-cuda": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-win32-x64-cuda/-/lloyal.node-win32-x64-cuda-1.6.0.tgz", + "integrity": "sha512-OUIc4G1tkxJp3N5VUhSanmZsPAcffg0JbUyhlNVV9v4EilQ7wxy3s0fN0gQfDl0XEaUXTc2zujuibdj3htLG9Q==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "win32" + ] }, "node_modules/@lloyal-labs/lloyal.node-win32-x64-vulkan": { - "optional": true + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@lloyal-labs/lloyal.node-win32-x64-vulkan/-/lloyal.node-win32-x64-vulkan-1.6.0.tgz", + "integrity": 
"sha512-q39Kh+SYGna/8DDQ8ncEHAMRym57IerLrLQ876+gO79HjdgdGfayw1H1yclmIqWczYzQuY/qPFxWTPoF3Pmdvw==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "win32" + ] }, "node_modules/@lloyal-labs/tsampler": { "version": "0.2.0", @@ -186,6 +762,16 @@ "@types/unist": "*" } }, + "node_modules/@types/node": { + "version": "25.3.0", + "resolved": "https://registry.npmjs.org/@types/node/-/node-25.3.0.tgz", + "integrity": "sha512-4K3bqJpXpqfg2XKGK9bpDTc6xO/xoUP/RBWS7AtRMug6zZFaRekiLzjVtAoZMquxoAbzBvy5nxQ7veS5eYzf8A==", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~7.18.0" + } + }, "node_modules/@types/unist": { "version": "3.0.3", "resolved": "https://registry.npmjs.org/@types/unist/-/unist-3.0.3.tgz", @@ -464,6 +1050,15 @@ "dev": true, "license": "MIT" }, + "node_modules/effection": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/effection/-/effection-4.0.2.tgz", + "integrity": "sha512-O8WMGP10nPuJDwbNGILcaCNWS+CvDYjcdsUSD79nWZ+WtUQ8h1MEV7JJwCSZCSeKx8+TdEaZ/8r6qPTR2o/o8w==", + "license": "MIT", + "engines": { + "node": ">= 16" + } + }, "node_modules/emoji-regex": { "version": "9.2.2", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", @@ -484,6 +1079,48 @@ "url": "https://github.com/fb55/entities?sponsor=1" } }, + "node_modules/esbuild": { + "version": "0.27.3", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.27.3.tgz", + "integrity": "sha512-8VwMnyGCONIs6cWue2IdpHxHnAjzxnw2Zr7MkVxB2vjmQ2ivqGFb4LEG3SMnv0Gb2F/G/2yA8zUaiL1gywDCCg==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.27.3", + "@esbuild/android-arm": "0.27.3", + "@esbuild/android-arm64": "0.27.3", + "@esbuild/android-x64": "0.27.3", + "@esbuild/darwin-arm64": "0.27.3", + "@esbuild/darwin-x64": "0.27.3", + "@esbuild/freebsd-arm64": "0.27.3", + 
"@esbuild/freebsd-x64": "0.27.3", + "@esbuild/linux-arm": "0.27.3", + "@esbuild/linux-arm64": "0.27.3", + "@esbuild/linux-ia32": "0.27.3", + "@esbuild/linux-loong64": "0.27.3", + "@esbuild/linux-mips64el": "0.27.3", + "@esbuild/linux-ppc64": "0.27.3", + "@esbuild/linux-riscv64": "0.27.3", + "@esbuild/linux-s390x": "0.27.3", + "@esbuild/linux-x64": "0.27.3", + "@esbuild/netbsd-arm64": "0.27.3", + "@esbuild/netbsd-x64": "0.27.3", + "@esbuild/openbsd-arm64": "0.27.3", + "@esbuild/openbsd-x64": "0.27.3", + "@esbuild/openharmony-arm64": "0.27.3", + "@esbuild/sunos-x64": "0.27.3", + "@esbuild/win32-arm64": "0.27.3", + "@esbuild/win32-ia32": "0.27.3", + "@esbuild/win32-x64": "0.27.3" + } + }, "node_modules/escalade": { "version": "3.2.0", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", @@ -526,6 +1163,21 @@ "node": ">=14.14" } }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, "node_modules/get-caller-file": { "version": "2.0.5", "resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz", @@ -536,6 +1188,19 @@ "node": "6.* || 8.* || >= 10.*" } }, + "node_modules/get-tsconfig": { + "version": "4.13.6", + "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.13.6.tgz", + "integrity": "sha512-shZT/QMiSHc/YBLxxOkMtgSid5HFoauqCE3/exfsEcwg1WkeqjG+V40yBbBrsD+jW2HDXcs28xOfcbm2jI8Ddw==", + "dev": true, + "license": "MIT", + "dependencies": { + "resolve-pkg-maps": "^1.0.0" + }, + "funding": { + "url": "https://github.com/privatenumber/get-tsconfig?sponsor=1" + } + }, "node_modules/glob": { "version": "11.1.0", "resolved": 
"https://registry.npmjs.org/glob/-/glob-11.1.0.tgz", @@ -840,6 +1505,16 @@ "node": ">=0.10.0" } }, + "node_modules/resolve-pkg-maps": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz", + "integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1" + } + }, "node_modules/semver": { "version": "7.7.3", "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", @@ -1020,6 +1695,26 @@ "node": ">=18" } }, + "node_modules/tsx": { + "version": "4.21.0", + "resolved": "https://registry.npmjs.org/tsx/-/tsx-4.21.0.tgz", + "integrity": "sha512-5C1sg4USs1lfG0GFb2RLXsdpXqBSEhAaA/0kPL01wxzpMqLILNxIxIOKiILz+cdg/pLnOUxFYOR5yhHU666wbw==", + "dev": true, + "license": "MIT", + "dependencies": { + "esbuild": "~0.27.0", + "get-tsconfig": "^4.7.5" + }, + "bin": { + "tsx": "dist/cli.mjs" + }, + "engines": { + "node": ">=18.0.0" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + } + }, "node_modules/typedoc": { "version": "0.28.17", "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.28.17.tgz", @@ -1074,12 +1769,11 @@ } }, "node_modules/typescript": { - "version": "5.8.3", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.8.3.tgz", - "integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==", + "version": "5.9.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz", + "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -1095,6 +1789,13 @@ "dev": true, "license": "MIT" }, + "node_modules/undici-types": { + "version": "7.18.2", + "resolved": 
"https://registry.npmjs.org/undici-types/-/undici-types-7.18.2.tgz", + "integrity": "sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==", + "dev": true, + "license": "MIT" + }, "node_modules/universalify": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.1.tgz", diff --git a/package.json b/package.json index de6e841..8cd4138 100644 --- a/package.json +++ b/package.json @@ -2,25 +2,29 @@ "name": "@lloyal-labs/lloyal.node", "version": "1.6.0", "description": "Node.js client for liblloyal+llama.cpp", - "main": "lib/index.js", - "types": "lib/index.d.ts", + "main": "dist/index.js", + "types": "dist/index.d.ts", "gypfile": false, "publishConfig": { "access": "public" }, "scripts": { "download-models": "bash scripts/download-test-models.sh", - "build": "node scripts/build.js", + "build:native": "node scripts/build.js", + "build:ts": "tsc", + "build:test": "tsc -p tsconfig.test.json", + "build": "npm run build:ts && npm run build:native", "build:debug": "cmake-js compile --debug", "rebuild": "cmake-js rebuild", - "clean": "cmake-js clean && rm -rf build_test/", + "clean": "cmake-js clean && rm -rf build_test/ dist/", "version": "node scripts/sync-versions.js && git add -A", "docs": "npx typedoc", "test": "npm run test:integration", - "test:integration": "node test/integration.js", - "test:examples": "node test/examples.js", + "test:integration": "npx tsx test/integration.ts", + "test:agents": "npx tsx test/agents.ts", + "test:examples": "npx tsx test/examples.ts", "sync:llama-cpp": "node scripts/sync-llama-cpp.js", - "example": "node examples/chat/chat.mjs" + "example": "npx tsx examples/chat/chat.ts" }, "repository": { "type": "git", @@ -44,13 +48,17 @@ "homepage": "https://github.com/lloyal-ai/lloyal.node#readme", "dependencies": { "@lloyal-labs/tsampler": "^0.2.0", + "effection": "^4.0.2", "node-addon-api": "^8.5.0" }, "devDependencies": { + "@types/node": "^25.3.0", "cmake-js": 
"^8.0.0", "glob": "^11.0.0", + "tsx": "^4.21.0", "typedoc": "^0.28.16", - "typedoc-rhineai-theme": "^1.2.0" + "typedoc-rhineai-theme": "^1.2.0", + "typescript": "^5.9.3" }, "optionalDependencies": { "@lloyal-labs/lloyal.node-darwin-arm64": "1.6.0", @@ -71,7 +79,7 @@ "node": ">=22.0.0" }, "files": [ - "lib/", + "dist/", "scripts/" ] } diff --git a/src/Branch.ts b/src/Branch.ts new file mode 100644 index 0000000..e44300e --- /dev/null +++ b/src/Branch.ts @@ -0,0 +1,650 @@ +import type { SessionContext, SamplingParams, Produced, GrammarTrigger } from './types'; +import { GrammarTriggerType } from './types'; + +/** + * Forkable inference handle for covalent generation + * + * A Branch owns everything needed for independent generation: a KV cache + * sequence, sampler chain, logits snapshot, and perplexity tracker. + * + * Forking is cheap β€” the KV prefix is shared in memory (metadata-only operation under unified KV β€” + * no KV tensor buffers are copied), so sibling branches read from the same physical KV entries. + * Only tokens decoded after the fork point are exclusive to each branch. + * + * Branches form trees, not just flat lists. Fork from root for best-of-N, + * fork from children for tree search/beam search, fork from a draft for speculative + * decoding. + * + * The produce/commit protocol separates sampling from state advancement: + * produce() samples without writing to KV, letting you inspect the result + * before deciding to commit(). + * + * @example Best-of-N with perplexity selection + * ```typescript + * const root = Branch.create(ctx, tokens.length, { temperature: 0.8 }); + * await root.prefill(tokens); + * + * const results = []; + * for (let i = 0; i < 5; i++) { + * const branch = await root.fork(); + * branch.reseedSampler(1000 + i); + * const tokens = []; + * for await (const { token } of branch) tokens.push(token); + * results.push({ branch, tokens, ppl: branch.perplexity }); + * } + * + * const best = results.reduce((a, b) => a.ppl < b.ppl ? 
a : b); + * for (const r of results) { if (r !== best) await r.branch.prune(); } + * ``` + * + * @category Branching + */ +export class Branch { + private _ctx: SessionContext; + private _handle: number; + private _disposed: boolean; + + constructor(ctx: SessionContext, handle: number) { + this._ctx = ctx; + this._handle = handle; + this._disposed = false; + } + + /** + * Create a root branch at the given position + * + * The branch takes ownership of the sequence and creates its own sampler + * chain from the provided params. Call prefill() to decode prompt tokens + * and capture the logit distribution before forking. + * + * @param ctx - SessionContext to create branch on + * @param position - Starting position (typically prompt token count) + * @param params - Sampling parameters (temperature, topP, etc.) + * @param nBatch - Per-branch batch size override (defaults to context nBatch). + * Controls chunk size for prefill(). Has no effect on + * single-token commit() which uses a zero-allocation fast path. + * @param grammar - GBNF grammar string for constrained generation. + * When provided, sample() returns only grammar-valid tokens. The grammar state + * is cloned on fork(), so sibling branches can diverge independently. + * @returns New Branch instance + */ + static create( + ctx: SessionContext, + position: number, + params?: SamplingParams, + nBatch?: number, + grammar?: string + ): Branch { + const handle = ctx._branchCreate(position, params, nBatch, grammar); + return new Branch(ctx, handle); + } + + /** + * Fork this branch to a new sequence (async) + * + * Async contract: local branches resolve immediately; cloud branches + * may perform an HTTP round-trip. Use {@link forkSync} when you know + * the branch is local and want zero-overhead forking. 
+ * + * @returns New forked Branch + */ + async fork(): Promise { + return this.forkSync(); + } + + /** + * Fork this branch to a new sequence (sync) + * + * The child shares the parent's KV prefix in memory (metadata-only under unified KV, no KV buffer copy). + * Logits, sampler state, and perplexity tracker are cloned so the child + * can diverge independently. Fork from any branch β€” root or intermediate β€” + * to build arbitrarily deep trees. + * + * Call reseedSampler() on each child for stochastic diversity. + * + * @returns New forked Branch + */ + forkSync(): Branch { + this._ensureNotDisposed(); + const newHandle = this._ctx._branchFork(this._handle); + return new Branch(this._ctx, newHandle); + } + + /** + * Get a copy of this branch's captured logits snapshot. + * + * Returns n_vocab floats β€” the raw logit distribution from the last + * prefill() or commit() call. + * + * Returns an independent copy of the branch's internal snapshot. + * The returned Float32Array is safe to hold across async boundaries + * and is not affected by subsequent decode operations. + * + * @returns Independent copy of the logits snapshot (n_vocab elements) + * @throws If no logits have been captured yet + */ + getLogits(): Float32Array { + this._ensureNotDisposed(); + return this._ctx._branchGetLogits(this._handle); + } + + /** + * Bulk-decode tokens into the branch's KV cache and capture logits. + * + * `tokens.length` is the total count to process; the branch's `nBatch` + * (set at `Branch.create`) controls how many are sent per `llama_decode` + * call. E.g. 500 tokens with `nBatch=64` β†’ 8 calls (7Γ—64 + 1Γ—52). + * + * Advances `position` by `tokens.length`. Stores final logits into the + * branch's internal snapshot β€” the next `produce()`/`sample()` reads + * from it. + * + * Does NOT accept tokens into the repeat-penalty window β€” for external + * tokens (user input between turns), not model-generated tokens. 
+ * For model output, use `commit()` which does accept + decode. + * + * The primary way to feed tokens into a branch's KV cache. + * + * @param tokens - Token IDs to decode + */ + async prefill(tokens: number[]): Promise { + this._ensureNotDisposed(); + await this._ctx._branchPrefill(this._handle, tokens); + } + + /** + * Sample next token from branch's logits snapshot + * + * Applies the branch's full sampler chain (top-k, top-p, temperature, + * repeat/presence penalties) to the captured logits. + * + * @returns Sampled token ID + */ + sample(): number { + this._ensureNotDisposed(); + return this._ctx._branchSample(this._handle); + } + + /** + * Record token in the sampler's repeat/presence penalty window + * + * @param token - Token to accept + */ + accept(token: number): void { + this._ensureNotDisposed(); + this._ctx._branchAccept(this._handle, token); + } + + /** + * Discard this branch (async) + * + * Async contract: local branches resolve immediately; cloud branches + * may perform an HTTP round-trip. Use {@link pruneSync} when you know + * the branch is local. + * + * RESTRICT mode: throws if children exist. Use {@link pruneSubtree} to + * cascade-delete an entire subtree. + */ + async prune(): Promise { + this.pruneSync(); + } + + /** + * Discard this branch β€” remove its divergent KV entries and free the handle (sync) + * + * Only removes KV entries divergent from the shared prefix; sibling branches + * are unaffected. The disposed flag is set synchronously β€” any call to + * produce(), commit(), etc. after prune() will throw immediately. + * + * RESTRICT mode: throws if children exist. Use {@link pruneSubtreeSync} to + * cascade-delete an entire subtree. + */ + pruneSync(): void { + if (this._disposed) return; + const kids = this.children; + if (kids.length > 0) { + throw new Error( + `Branch.prune(): branch ${this._handle} has ${kids.length} active child(ren) ` + + `[${kids.join(', ')}]. 
Prune children first or use pruneSubtree().`, + ); + } + this._ctx._branchPrune(this._handle); + this._disposed = true; + } + + /** + * Discard this branch and all its descendants (async) + * + * Async contract: local branches resolve immediately; cloud branches + * may perform an HTTP round-trip. Use {@link pruneSubtreeSync} when you know + * the branch is local. + */ + async pruneSubtree(): Promise { + this.pruneSubtreeSync(); + } + + /** + * Discard this branch and all its descendants β€” CASCADE delete (sync) + * + * Iterative post-order traversal: prunes children first, then this branch. + * Use when tearing down an entire subtree (e.g. abandoned search path). + * Sets disposed synchronously. + */ + pruneSubtreeSync(): void { + if (this._disposed) return; + this._ctx._branchPruneSubtree(this._handle); + this._disposed = true; + } + + /** + * Reseed the sampler's PRNG for diversity after fork() + * + * CRITICAL for parallel generation: Without reseeding, all forked branches + * produce identical outputs because they share the same PRNG state. + * + * Only affects stochastic samplers (temperature > 0). Greedy samplers are unchanged. + * + * @param seed - New seed for the PRNG + */ + reseedSampler(seed: number): void { + this._ensureNotDisposed(); + this._ctx._branchSamplerChainReseed(this._handle, seed); + } + + /** + * Apply dynamic logit adjustments for this branch only + * + * Unlike `logit_bias` in sampling params (which is cloned on fork), steer biases + * are NOT inherited by child branches. Each branch manages its own steer state + * independently. This makes steer ideal for path-dependent constraints. 
+ * + * **Use cases:** + * - **tsampler**: Block tokens that would create repeated N-grams based on + * this branch's specific generation history + * - **Diverse beam search**: Penalize tokens already chosen by sibling beams + * to encourage output diversity across the beam + * - **Dynamic constraints**: Apply token restrictions that change per-step + * + * **Sampling order:** Grammar → Logit Bias → Steer → Sampler Chain + * + * @param biases - Array of token adjustments. Use `-Infinity` to completely + * block a token, positive values to boost probability, negative to reduce. + * + * @example Block tokens for N-gram deduplication (tsampler pattern) + * ```ts + * // Compute which tokens would create repeated 4-grams + * const blocked = computeNgramBlocks(generatedTokens, n=4); + * + * // Block those tokens for this sample only + * branch.steer(blocked.map(t => ({ token: t, bias: -Infinity }))); + * + * const { token } = await branch.produce(); // Blocked tokens won't be sampled + * await branch.commit(token); + * + * // Clear for next iteration (recompute based on new history) + * branch.clearSteer(); + * ``` + * + * @example Diverse beam search + * ```ts + * // Each beam penalizes tokens chosen by siblings this step + * for (const beam of beams) { + * // Collect tokens chosen by other beams + * const siblingTokens = beams + * .filter(b => b !== beam && b.lastToken !== undefined) + * .map(b => b.lastToken); + * + * // Penalize sibling choices to encourage diversity + * beam.branch.steer(siblingTokens.map(t => ({ token: t, bias: -2.0 }))); + * + * const { token } = await beam.branch.produce(); + * await beam.branch.commit(token); + * beam.lastToken = token; + * beam.branch.clearSteer(); + * } + * ``` + * + * @example Boost specific tokens + * ```ts + * // Boost "yes" and "no" tokens for a yes/no question + * branch.steer([ + * { token: yesTokenId, bias: 5.0 }, + * { token: noTokenId, bias: 5.0 } + * ]); + * ``` + */ + steer(biases: Array<{ token: number; bias:
number }>): void { + this._ensureNotDisposed(); + this._ctx._branchSteer(this._handle, biases); + } + + /** + * Clear all steer biases from this branch + * + * Removes any dynamic logit adjustments set by `steer()`. Call this after + * each generation step if your steer constraints are computed per-step + * (e.g., N-gram blocking where the blocked set changes as text grows). + * + * @example Per-step steer pattern + * ```ts + * for (let i = 0; i < maxTokens; i++) { + * // Compute constraints based on current state + * const blocked = computeConstraints(generatedTokens); + * branch.steer(blocked.map(t => ({ token: t, bias: -Infinity }))); + * + * const { token, isStop } = await branch.produce(); + * if (isStop) break; + * + * await branch.commit(token); + * branch.clearSteer(); // Reset for next iteration + * generatedTokens.push(token); + * } + * ``` + */ + clearSteer(): void { + this._ensureNotDisposed(); + this._ctx._branchClearSteer(this._handle); + } + + /** + * Replace the sampler chain with new parameters (memoized) + * + * If the new params match the current chain's params, this is a no-op. + * Otherwise the old chain is freed and a new one is created. Use for + * Entropy-Driven Temperature (EDT) and other adaptive sampling strategies + * that adjust parameters per-step. + * + * @param params - New sampling parameters + * + * @example Entropy-Driven Temperature + * ```typescript + * const entropy = branch.modelEntropy('nats'); + * branch.setSamplerParams({ temperature: edtTemperature(entropy) }); + * const { token } = await branch.produce(); + * await branch.commit(token); + * ``` + */ + setSamplerParams(params: SamplingParams): void { + this._ensureNotDisposed(); + this._ctx._branchSetSamplerParams(this._handle, params); + } + + /** + * Replace or remove the grammar constraint + * + * Pass a GBNF grammar string to constrain generation. Pass empty string + * or undefined to remove the constraint. 
The grammar state is cloned on + * fork(), so sibling branches can diverge independently after hot-swap. + * + * @param grammarStr - GBNF grammar string, or empty/undefined to remove + * + * @example Hot-swap grammar mid-generation + * ```typescript + * // Start unconstrained, then switch to JSON after detecting tool call + * branch.setGrammar(jsonGrammar); + * const { token } = await branch.produce(); + * ``` + */ + setGrammar(grammarStr?: string): void { + this._ensureNotDisposed(); + this._ctx._branchSetGrammar(this._handle, grammarStr || ''); + } + + /** + * Set lazy grammar β€” unconstrained until trigger, then grammar-constrained + * + * Generation runs freely until a trigger pattern or token fires, at which + * point the grammar activates and constrains subsequent tokens. Used for + * tool-call generation: model writes freely until ``, then + * grammar forces valid XML structure. + * + * The grammar state is cloned on fork(), so sibling branches can diverge + * independently. Call again after a tool result prefill to reset. + * + * @param grammar - GBNF grammar string + * @param triggers - Trigger conditions from formatChat().grammarTriggers + */ + setGrammarLazy(grammar: string, triggers: GrammarTrigger[]): void { + this._ensureNotDisposed(); + const escapeRegex = (s: string) => s.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); + const patterns: string[] = []; + const tokens: number[] = []; + for (const t of triggers) { + switch (t.type) { + case GrammarTriggerType.WORD: + patterns.push(escapeRegex(t.value)); + break; + case GrammarTriggerType.PATTERN: + patterns.push(t.value); + break; + case GrammarTriggerType.PATTERN_FULL: { + const p = t.value; + patterns.push((p[0] !== '^' ? '^' : '') + p + (p[p.length - 1] !== '$' ? 
'$' : '')); + break; + } + case GrammarTriggerType.TOKEN: + tokens.push(t.token); + break; + } + } + this._ctx._branchSetGrammarLazy(this._handle, grammar, patterns, tokens); + } + + /** + * Sample next token without advancing state (async) + * + * Async contract: local branches resolve immediately; cloud branches + * may perform an HTTP round-trip. Use {@link produceSync} when you know + * the branch is local and want zero-overhead sampling. + */ + async produce(): Promise { + return this.produceSync(); + } + + /** + * Sample next token without advancing state (sync) + * + * Same as {@link produce} but synchronous. Use when you know the branch + * is local and want to avoid the microtick overhead of a promise. + */ + produceSync(): Produced { + this._ensureNotDisposed(); + const token = this.sample(); + return { + token, + text: this._ctx.tokenToText(token), + isStop: this._ctx.isStopToken(token), + }; + } + + /** + * Accept and decode β€” update branch state, then write token to KV + * + * Accepts the token into the sampler penalty window (for correct PPL + * measurement), then decodes (writing to KV cache via AsyncWorker on + * the libuv thread pool) and captures the resulting logits for the next + * produce() call. Accept-first ordering with rollback: if decode throws, + * sampler/grammar/metrics are restored from clones. + * + * @param token Token to commit (from produce()) + */ + async commit(token: number): Promise { + this._ensureNotDisposed(); + await this._ctx._storeCommit([this._handle], [token]); + } + + // ===== METRICS ===== + + /** + * Compute entropy of the branch's logits distribution + * + * Measures model uncertainty from the branch's captured logits snapshot: + * - Low entropy: Model is confident (peaked distribution) + * - High entropy: Model is uncertain (flat distribution) + * + * Operates directly on `state->logits_snapshot` β€” no JS round-trip. 
+ * + * @param base - Logarithm base: "nats" (default) or "bits" + * @returns Entropy value in specified base + * + * COST: O(n_vocab) - must sum over all token probabilities + */ + modelEntropy(base: 'nats' | 'bits' = 'nats'): number { + this._ensureNotDisposed(); + return this._ctx._branchModelEntropy(this._handle, base); + } + + /** + * Compute surprisal (negative log-likelihood) for a specific token + * + * Measures how "surprising" the model finds the given token from + * the branch's captured logits snapshot: + * - Low surprisal: Model expected this token (high probability) + * - High surprisal: Model didn't expect this token (low probability) + * + * Operates directly on `state->logits_snapshot` β€” no JS round-trip. + * + * @param token - Token ID to compute surprisal for + * @param base - Logarithm base: "nats" (default) or "bits" + * @returns Surprisal value in specified base + * + * COST: O(n_vocab) - softmax normalization required + */ + modelSurprisal(token: number, base: 'nats' | 'bits' = 'nats'): number { + this._ensureNotDisposed(); + return this._ctx._branchModelSurprisal(this._handle, token, base); + } + + /** + * Sampling-level perplexity (from filtered distribution) + * + * Returns perplexity from the distribution actually sampled from + * (after top-k/p/temp/penalties). Useful for policy priors and + * monitoring sampler chain impact. + * + * Compare with {@link perplexity} which is model-level (raw logits). + */ + get samplingPerplexity(): number { + this._ensureNotDisposed(); + return this._ctx._branchGetSamplingPerplexity(this._handle); + } + + /** + * Set static logit biases on this branch + * + * Unlike {@link steer} (which is NOT inherited on fork), logit biases + * ARE cloned when forking. Use for persistent constraints that should + * propagate to child branches. + * + * Applied during sample() in order: Grammar -> Logit Bias -> Steer -> Sampler Chain + * + * @param biases - Array of token adjustments. 
Use `-Infinity` to block, + * positive to boost, negative to reduce. + */ + setLogitBias(biases: Array<{ token: number; bias: number }>): void { + this._ensureNotDisposed(); + this._ctx._branchSetLogitBias(this._handle, biases); + } + + /** + * Clear all static logit biases from this branch + */ + clearLogitBias(): void { + this._ensureNotDisposed(); + this._ctx._branchClearLogitBias(this._handle); + } + + // ===== ACCESSORS ===== + + /** Branch's current position (number of tokens decoded) */ + get position(): number { + this._ensureNotDisposed(); + return this._ctx._branchGetPosition(this._handle); + } + + /** Branch's perplexity (exp of mean surprisal) */ + get perplexity(): number { + this._ensureNotDisposed(); + return this._ctx._branchGetPerplexity(this._handle); + } + + /** Internal handle (for debugging) */ + get handle(): number { + return this._handle; + } + + /** Whether this branch has been disposed */ + get disposed(): boolean { + return this._disposed; + } + + /** Parent branch handle, or null if root */ + get parent(): number | null { + this._ensureNotDisposed(); + const h = this._ctx._branchParent(this._handle); + return h === 0 ? null : h; + } + + /** Child branch handles */ + get children(): number[] { + this._ensureNotDisposed(); + return this._ctx._branchChildren(this._handle); + } + + /** True if this branch has no children */ + get isLeaf(): boolean { + this._ensureNotDisposed(); + return this._ctx._branchIsLeaf(this._handle); + } + + /** True if this branch holds a KV lease */ + get isActive(): boolean { + this._ensureNotDisposed(); + return this._ctx._branchIsActive(this._handle); + } + + // ===== ASYNC ITERATION ===== + + /** + * Async iterator β€” generate tokens until EOG + * + * Commit-before-yield semantics: every yielded token is already written + * to KV and accepted into the sampler. Breaking out of the loop is clean β€” + * no orphaned uncommitted tokens, perplexity reflects all yielded tokens. 
+ * + * For inspect-before-commit (speculative decoding, tree search), use + * the {@link produce}/{@link commit} protocol directly. + * + * @example Generate to completion + * ```typescript + * for await (const { token, text } of branch) { + * process.stdout.write(text); + * } + * ``` + * + * @example Generate with consumer-side bound + * ```typescript + * const tokens = []; + * for await (const { token } of branch) { + * tokens.push(token); + * if (tokens.length >= limit) break; + * } + * ``` + */ + async *[Symbol.asyncIterator](): AsyncIterableIterator<{ token: number; text: string }> { + while (!this._disposed) { + const { token, text, isStop } = await this.produce(); + if (isStop) return; + await this.commit(token); + yield { token, text }; + } + } + + // ===== INTERNAL ===== + + private _ensureNotDisposed(): void { + if (this._disposed) { + throw new Error('Branch has been disposed'); + } + } +} diff --git a/src/BranchStore.ts b/src/BranchStore.ts new file mode 100644 index 0000000..c4813b9 --- /dev/null +++ b/src/BranchStore.ts @@ -0,0 +1,155 @@ +import type { Branch } from './Branch'; +import type { SessionContext } from './types'; + +/** + * High-throughput multi-branch decode operations + * + * The naive approach to N-branch generation is N sequential llama_decode() + * calls β€” each paying full GPU kernel launch overhead, memory barrier, and + * PCIe round-trip. BranchStore eliminates this by packing all branches into + * a single llama_batch and dispatching once: O(1) GPU round-trips regardless + * of branch count. The GPU parallelizes across sequences within the batch, + * so N branches approach the wall-time cost of 1. + * + * Two operations, two packing strategies: + * + * **commit()** β€” Generation step. Each branch contributes exactly 1 token. + * Packs N tokens into a single batch via `decode_each` (one row per sequence, + * all at their respective positions). Single `llama_decode()` call. Logits + * captured per-branch at batch index `i`. 
O(N) total work, O(1) GPU + * dispatches, O(1) amortized dispatch overhead per branch. Accept-first + * ordering with rollback: accepts each token into its branch's repeat-penalty + * window before decode, restores from clones if decode throws. + * + * **prefill()** — Bulk token injection. Each branch contributes a + * variable-length token array. Uses a two-pass bin-packing algorithm: + * + * - *Pass 1 (planning)*: Greedy first-fit packs items into chunks ≤ nBatch. + * Items larger than nBatch get a dedicated chunk and fall through to + * decode_many's internal auto-chunking (ceil(nTokens / nBatch) calls). + * - *Pass 2 (dispatch)*: Normal chunks dispatch via `decode_scatter` (one + * `llama_decode` per chunk). Logits are indexed by flattened cursor + * position: for item k in a chunk, logits live at `cursor + nTokens[k] - 1`. + * + * For T total tokens across N branches with batch capacity B: + * - Best case (T ≤ B): 1 GPU dispatch, all branches in one batch. + * - Worst case: ceil(T / B) dispatches. Each dispatch is fully packed. + * - Amortized per-token GPU overhead: O(1/B) — vanishes as batch fills. + * + * Does NOT accept tokens into the sampler penalty window — use for + * external/replayed tokens where repeat-penalty tracking is unwanted. + * For model-generated tokens, use {@link commit} instead. + * + * Both methods take `[branch, token(s)]` tuples — the branch-to-token + * binding is structural, not positional. After either call, each branch's + * logits snapshot is updated with the output distribution from its decoded + * token(s), ready for the next `produce()`/`sample()` call.
+ * + * @example 32-branch generation step β€” one GPU dispatch + * ```typescript + * const store = new BranchStore(ctx); + * const entries = await Promise.all(branches.map(async b => [b, (await b.produce()).token] as [Branch, number])); + * await store.commit(entries); // 32 tokens, 1 llama_decode() + * ``` + * + * @example Best-of-N with batched commit + * ```typescript + * const store = new BranchStore(ctx); + * const branches = []; + * for (const _ of [1, 2, 3]) branches.push(await root.fork()); + * + * for (let step = 0; step < 50; step++) { + * const produced = await Promise.all(branches.map(async b => [b, await b.produce()] as const)); + * const live = produced.filter(([, p]) => !p.isStop); + * if (!live.length) break; + * await store.commit(live.map(([b, p]) => [b, p.token])); + * } + * ``` + * + * @example Asymmetric prefill β€” variable-length injections, auto-chunked + * ```typescript + * await store.prefill([ + * [branchA, systemPromptTokens], // 200 tokens + * [branchB, shortQueryTokens], // 12 tokens + * [branchC, longDocumentTokens], // 800 tokens + * ]); + * // Bin-packed into ceil(1012 / nBatch) GPU dispatches + * ``` + * + * @category Branching + */ +export class BranchStore { + private _ctx: SessionContext; + + constructor(ctx: SessionContext) { + this._ctx = ctx; + } + + /** + * Batched single-token commit for model-generated tokens + * + * Each tuple `[branch, token]` binds one token to one branch. + * Accepts each token into its branch's repeat-penalty window (for correct + * PPL measurement), then decodes all N tokens in a single llama_decode() + * call via decode_each and captures logits per-branch. Accept-first + * ordering with rollback: if decode throws, sampler/grammar/metrics are + * restored from clones taken before the accept. 
+ * + * @param entries - Array of `[branch, token]` tuples (branches must not be disposed) + * @throws If any branch is disposed + */ + async commit(entries: [Branch, number][]): Promise { + const handles: number[] = []; + const tokens: number[] = []; + for (const [branch, token] of entries) { + if (branch.disposed) throw new Error('BranchStore.commit: branch is disposed'); + handles.push(branch.handle); + tokens.push(token); + } + await this._ctx._storeCommit(handles, tokens); + } + + /** + * Batched variable-length prefill for external tokens + * + * Each tuple `[branch, tokens]` binds a token array to one branch. + * Each branch can receive a different number of tokens β€” decode_scatter + * handles variable-length runs and auto-chunks to fit nBatch. + * + * Does NOT call accept_token β€” use for external/replayed tokens where + * repeat-penalty tracking is unwanted. For model-generated tokens, + * use {@link commit} instead. + * + * @param entries - Array of `[branch, tokens]` tuples (branches must not be disposed) + * @throws If any branch is disposed + */ + async prefill(entries: [Branch, number[]][]): Promise { + const handles: number[] = []; + const tokenArrays: number[][] = []; + for (const [branch, tokens] of entries) { + if (branch.disposed) throw new Error('BranchStore.prefill: branch is disposed'); + handles.push(branch.handle); + tokenArrays.push(tokens); + } + await this._ctx._storePrefill(handles, tokenArrays); + } + + /** + * Retain only the winner branch β€” evict all other leases and free their slots. + * + * Nuclear operation: calls `kv::seq_keep` on the winner's seq_id (stripping all + * other sequences from KV cache in a single pass), then frees all loser slots + * and rebuilds the vacancy list. The winner's topology is reset (no parent, no children). 
+ * + * @param winner - The branch to keep (must not be disposed, must hold a lease) + * @throws If winner is disposed or has no lease + */ + async retainOnly(winner: Branch): Promise { + if (winner.disposed) throw new Error('BranchStore.retainOnly: winner is disposed'); + this._ctx._storeRetainOnly(winner.handle); + } + + get available(): number { + return this._ctx._storeAvailable(); + } +} diff --git a/src/Rerank.ts b/src/Rerank.ts new file mode 100644 index 0000000..0771fef --- /dev/null +++ b/src/Rerank.ts @@ -0,0 +1,268 @@ +import { createContext } from './index.js'; +import type { SessionContext, RerankOptions, RerankResult, RerankProgress } from './types'; + +const SYSTEM_PROMPT = + 'Judge whether the Document meets the requirements based on the Query ' + + 'and the Instruct provided. Note that the answer can only be "yes" or "no".'; + +const USER_PREFIX = + ': Given a web search query, retrieve relevant passages that answer the query\n\n' + + ': '; + +interface ScoringRequest { + tokenArrays: number[][]; + cursor: number; + scores: number[]; + filled: number; + topK: number | undefined; + total: number; + push: (progress: RerankProgress) => void; + finish: () => void; + error: (err: Error) => void; +} + +/** Simple async channel β€” _drain pushes, consumer pulls via for-await */ +function channel(): { + push: (value: T) => void; + finish: () => void; + error: (err: Error) => void; + iterable: AsyncIterable; +} { + const buffer: T[] = []; + let done = false; + let err: Error | null = null; + let notify: (() => void) | null = null; + + const wait = () => new Promise((r) => { notify = r; }); + + return { + push(value: T) { + buffer.push(value); + notify?.(); + notify = null; + }, + finish() { + done = true; + notify?.(); + notify = null; + }, + error(e: Error) { + err = e; + notify?.(); + notify = null; + }, + iterable: { + [Symbol.asyncIterator](): AsyncIterator { + return { + async next(): Promise> { + while (buffer.length === 0 && !done && !err) await 
wait(); + if (err) throw err; + if (buffer.length > 0) return { value: buffer.shift()!, done: false }; + return { value: undefined as unknown as T, done: true }; + }, + }; + }, + }, + }; +} + +export class Rerank { + private _ctx: SessionContext; + private _nSeqMax: number; + private _nCtx: number; + private _yesId: number; + private _noId: number; + private _prefixTokens: number[]; + private _midTokens: number[]; + private _suffixTokens: number[]; + private _pending: ScoringRequest[] = []; + private _draining = false; + private _disposed = false; + + private constructor( + ctx: SessionContext, + nSeqMax: number, + nCtx: number, + yesId: number, + noId: number, + prefixTokens: number[], + midTokens: number[], + suffixTokens: number[], + ) { + this._ctx = ctx; + this._nSeqMax = nSeqMax; + this._nCtx = nCtx; + this._yesId = yesId; + this._noId = noId; + this._prefixTokens = prefixTokens; + this._midTokens = midTokens; + this._suffixTokens = suffixTokens; + } + + static async create(options: RerankOptions): Promise { + const nSeqMax = options.nSeqMax ?? 8; + const nCtx = options.nCtx ?? 4096; + const ctx = await createContext({ + modelPath: options.modelPath, + nCtx, + nSeqMax, + typeK: options.typeK ?? 'q4_0', + typeV: options.typeV ?? 
'q4_0', + }); + + const [yesId] = await ctx.tokenize('yes', false); + const [noId] = await ctx.tokenize('no', false); + + // Render the chat template once around two sentinel strings, then cut the + // rendered prompt at the sentinels so the prefix / middle / suffix token runs + // can be reused for every (query, document) pair. + const SENTINEL_Q = '\x00QUERY\x00'; + const SENTINEL_D = '\x00DOC\x00'; + const probe = await ctx.formatChat(JSON.stringify([ + { role: 'system', content: SYSTEM_PROMPT }, + { role: 'user', content: `${USER_PREFIX}${SENTINEL_Q}\n\n: ${SENTINEL_D}` }, + ]), { addGenerationPrompt: true, enableThinking: false }); + + const p = probe.prompt; + const qi = p.indexOf(SENTINEL_Q); + const di = p.indexOf(SENTINEL_D); + // indexOf returns -1 when the template dropped or rewrote a sentinel; slicing + // with -1 would silently yield wrong token runs, so release the context and fail. + if (qi < 0 || di < qi + SENTINEL_Q.length) { + ctx.dispose(); + throw new Error('Rerank.create: chat template did not preserve the query/document placeholders'); + } + const prefixTokens = await ctx.tokenize(p.slice(0, qi), true); + const midTokens = await ctx.tokenize(p.slice(qi + SENTINEL_Q.length, di), false); + const suffixTokens = await ctx.tokenize(p.slice(di + SENTINEL_D.length), false); + + return new Rerank(ctx, nSeqMax, nCtx, yesId, noId, prefixTokens, midTokens, suffixTokens); + } + + /** + * Score pre-tokenized documents against a query. + * + * Builds one prompt per document (template prefix + query + middle + document + * + suffix), truncating each document to the per-sequence token budget + * floor(nCtx / nSeqMax) minus the template and query tokens, then queues the + * prompts for batched scoring. + * + * @param query - Query text (tokenized here without special tokens) + * @param documents - One token array per document + * @param topK - If given, each progress update carries only the top K results + * @returns Async iterable of progress updates; failures (including a query that + * leaves no room for document tokens) are raised from the iterator + * @throws Synchronously if this Rerank has been disposed + */ + score(query: string, documents: number[][], topK?: number): AsyncIterable<RerankProgress> { + if (this._disposed) throw new Error('Rerank disposed'); + + const self = this; + const ch = channel<RerankProgress>(); + + (async () => { + try { + const queryTokens = await self._ctx.tokenize(query, false); + const shared = [...self._prefixTokens, ...queryTokens, ...self._midTokens]; + const maxDoc = Math.floor(self._nCtx / self._nSeqMax) - shared.length - self._suffixTokens.length; + + // A non-positive budget means template + query already fill the sequence's + // share of the context. slice(0, negative) would keep the head of the + // document and drop only its tail, producing an oversized prompt, so fail + // loudly; the surrounding catch routes this to the iterator. + if (maxDoc < 1) { + throw new Error(`Rerank.score: query leaves no room for document tokens (budget ${maxDoc})`); + } + + const tokenArrays = documents.map((doc) => { + const trimmed = doc.length > maxDoc ? doc.slice(0, maxDoc) : doc; + return [...shared, ...trimmed, ...self._suffixTokens]; + }); + + self._enqueue(tokenArrays, topK, ch.push, ch.finish, ch.error); + } catch (err) { + ch.error(err instanceof Error ? 
err : new Error(String(err))); + } + })(); + + return ch.iterable; + } + + async tokenize(text: string): Promise { + return this._ctx.tokenize(text, false); + } + + dispose(): void { + this._disposed = true; + const err = new Error('Rerank disposed'); + for (const req of this._pending) req.error(err); + this._pending.length = 0; + this._ctx.dispose(); + } + + // ── Queue internals ────────────────────────────────────────── + + private _sortResults(scores: number[], topK: number | undefined): RerankResult[] { + const sorted = scores + .map((score, index) => ({ score: Math.round(score * 1000) / 1000, index })) + .sort((a, b) => b.score - a.score); + return topK != null ? sorted.slice(0, topK) : sorted; + } + + private _enqueue( + tokenArrays: number[][], + topK: number | undefined, + push: (progress: RerankProgress) => void, + finish: () => void, + error: (err: Error) => void, + ): void { + this._pending.push({ + tokenArrays, cursor: 0, + scores: new Array(tokenArrays.length), + filled: 0, + topK, + total: tokenArrays.length, + push, finish, error, + }); + this._drain(); + } + + private _fillGroup(): { reqIdx: number; promptIdx: number; tokens: number[] }[] { + const group: { reqIdx: number; promptIdx: number; tokens: number[] }[] = []; + let added = true; + while (group.length < this._nSeqMax && added) { + added = false; + for (let r = 0; r < this._pending.length && group.length < this._nSeqMax; r++) { + const req = this._pending[r]; + if (req.cursor < req.tokenArrays.length) { + group.push({ reqIdx: r, promptIdx: req.cursor, tokens: req.tokenArrays[req.cursor] }); + req.cursor++; + added = true; + } + } + } + return group; + } + + private async _drain(): Promise { + if (this._draining) return; + this._draining = true; + + try { + while (this._pending.length > 0) { + const group = this._fillGroup(); + if (group.length === 0) break; + + let logits: Float32Array[]; + try { + logits = await this._ctx._scoreGroup(group.map((g) => g.tokens)); + } catch (err) { + const 
error = err instanceof Error ? err : new Error(String(err)); + for (const req of this._pending) req.error(error); + this._pending.length = 0; + return; + } + + // dispose() may have run while the native call was in flight. It already + // rejected and cleared every pending request, so the reqIdx values captured + // by _fillGroup() no longer point at anything: stop here instead of + // indexing into the emptied queue (the finally block resets _draining). + if (this._disposed) return; + + // Track which requests got new scores this group + const touched = new Set<number>(); + for (let i = 0; i < group.length; i++) { + const req = this._pending[group[i].reqIdx]; + req.scores[group[i].promptIdx] = this._rerankScore(logits[i]); + req.filled++; + touched.add(group[i].reqIdx); + } + + // Push progress for each request that advanced, finish completed ones. + // Iterate backwards so splice() does not shift indices still to be visited. + for (let r = this._pending.length - 1; r >= 0; r--) { + const req = this._pending[r]; + if (!touched.has(r)) continue; + + const results = this._sortResults(req.scores, req.topK); + req.push({ filled: req.filled, total: req.total, results }); + + if (req.filled === req.total) { + req.finish(); + this._pending.splice(r, 1); + } + } + } + } finally { + this._draining = false; + } + } + + /** + * Relevance score for one prompt: probability of the "yes" token under a + * softmax restricted to the yes/no logits. The larger logit is subtracted + * first so neither exponential can overflow. + */ + private _rerankScore(logits: Float32Array): number { + const max = Math.max(logits[this._yesId], logits[this._noId]); + const yesExp = Math.exp(logits[this._yesId] - max); + const noExp = Math.exp(logits[this._noId] - max); + return yesExp / (yesExp + noExp); + } +} diff --git a/src/Session.ts b/src/Session.ts new file mode 100644 index 0000000..4ce87fb --- /dev/null +++ b/src/Session.ts @@ -0,0 +1,99 @@ +import type { Branch } from './Branch'; +import type { BranchStore } from './BranchStore'; +import type { SessionContext } from './types'; +import { buildUserDelta, buildToolResultDelta } from './agents/deltas'; + +/** + * Session - Trunk lifecycle + conversation delta helpers + * + * Owns the current "trunk" branch and provides promote() to crown a winner, + * plus delta helpers that centralize the sep + formatChat + tokenize + prefill + * pattern for injecting new turns into an ongoing conversation. + * + * Session does NOT own the SessionContext or BranchStore — the consumer + * creates those and passes them in. dispose() prunes trunk only. 
+ * + * @example + * ```typescript + * const session = new Session({ ctx, store }); + * session.trunk = initialBranch; + * + * // After verification, promote the best attempt + * await session.promote(bestAttempt.branch); + * + * // Inject a user turn and generate + * await session.prefillUser('What about X?'); + * for await (const { text } of session.trunk) { + * process.stdout.write(text); + * } + * + * // Cleanup + * await session.dispose(); + * ctx.dispose(); + * ``` + * + * @category Branching + */ +export class Session { + private _ctx: SessionContext; + private _store: BranchStore; + private _trunk: Branch | null; + + constructor({ ctx, store }: { ctx: SessionContext; store: BranchStore }) { + this._ctx = ctx; + this._store = store; + this._trunk = null; + } + + /** Current trunk branch */ + get trunk(): Branch | null { + return this._trunk; + } + + /** Assign initial trunk (no promote) */ + set trunk(branch: Branch | null) { + this._trunk = branch; + } + + /** + * Promote a winner to trunk β€” retainOnly + reassign + * + * Safe even if winner is the only branch (resets topology, no-op on KV). 
+ */ + async promote(winner: Branch): Promise { + await this._store.retainOnly(winner); + this._trunk = winner; + } + + /** + * Dispose trunk only β€” consumer owns ctx and other resources + */ + async dispose(): Promise { + if (this._trunk && !this._trunk.disposed) { + await this._trunk.prune(); + } + this._trunk = null; + } + + /** + * Prefill a user turn into trunk + * + * @param content - User message content + * @param opts - Optional tools JSON string + */ + async prefillUser(content: string, opts: { tools?: string } = {}): Promise { + const tokens = buildUserDelta(this._ctx, content, opts); + await this._trunk!.prefill(tokens); + } + + /** + * Prefill a tool result turn into trunk + * + * @param resultStr - JSON-stringified tool result + * @param callId - Tool call ID + */ + async prefillToolResult(resultStr: string, callId: string): Promise { + const tokens = buildToolResultDelta(this._ctx, resultStr, callId); + await this._trunk!.prefill(tokens); + } +} diff --git a/src/SessionContext.cpp b/src/SessionContext.cpp index afc74bc..bd6095e 100644 --- a/src/SessionContext.cpp +++ b/src/SessionContext.cpp @@ -9,9 +9,11 @@ #include #include #include +#include #include #include #include +#include #include namespace liblloyal_node { @@ -471,6 +473,9 @@ class DetokenizeWorker : public Napi::AsyncWorker { std::string _result; }; +// Forward declaration β€” defined after parseFormatChatArgs +static Napi::Object marshalFormatResult(Napi::Env env, const lloyal::chat_in::FormatResult& r); + /** * AsyncWorker for formatChat operation */ @@ -498,45 +503,7 @@ class FormatChatWorker : public Napi::AsyncWorker { } void OnOK() override { - Napi::Env env = Env(); - - Napi::Object result = Napi::Object::New(env); - result.Set("prompt", Napi::String::New(env, _result.prompt)); - - // stopTokens (backward compat) - Napi::Array stopTokens = Napi::Array::New(env, _result.additional_stops.size()); - for (size_t i = 0; i < _result.additional_stops.size(); i++) { - stopTokens[i] = 
Napi::String::New(env, _result.additional_stops[i]); - } - result.Set("stopTokens", stopTokens); - - // Format awareness fields - result.Set("format", Napi::Number::New(env, static_cast(_result.format))); - result.Set("grammar", Napi::String::New(env, _result.grammar)); - result.Set("grammarLazy", Napi::Boolean::New(env, _result.grammar_lazy)); - result.Set("thinkingForcedOpen", Napi::Boolean::New(env, _result.thinking_forced_open)); - result.Set("reasoningFormat", Napi::Number::New(env, static_cast(_result.reasoning_format))); - result.Set("parser", Napi::String::New(env, _result.parser)); - - // grammarTriggers: Array<{ type: number, value: string, token: number }> - Napi::Array triggers = Napi::Array::New(env, _result.grammar_triggers.size()); - for (size_t i = 0; i < _result.grammar_triggers.size(); i++) { - Napi::Object trigger = Napi::Object::New(env); - trigger.Set("type", Napi::Number::New(env, static_cast(_result.grammar_triggers[i].type))); - trigger.Set("value", Napi::String::New(env, _result.grammar_triggers[i].value)); - trigger.Set("token", Napi::Number::New(env, static_cast(_result.grammar_triggers[i].token))); - triggers[i] = trigger; - } - result.Set("grammarTriggers", triggers); - - // preservedTokens: string[] - Napi::Array preserved = Napi::Array::New(env, _result.preserved_tokens.size()); - for (size_t i = 0; i < _result.preserved_tokens.size(); i++) { - preserved[i] = Napi::String::New(env, _result.preserved_tokens[i]); - } - result.Set("preservedTokens", preserved); - - _deferred.Resolve(result); + _deferred.Resolve(marshalFormatResult(Env(), _result)); } void OnError(const Napi::Error& err) override { @@ -707,6 +674,68 @@ class StorePrefillWorker : public Napi::AsyncWorker { std::vector> _tokenStorage; }; +/** + * AsyncWorker for batch logit scoring (process_chunks) + * Owns token storage and logit output buffers + */ +class ScoreGroupWorker : public Napi::AsyncWorker { +public: + ScoreGroupWorker(Napi::Env env, + llama_context* ctx, + 
llama_model* model, + int32_t nSeqMax, + std::vector> tokenStorage) + : AsyncWorker(env), _deferred(env), _ctx(ctx), _model(model), + _nSeqMax(nSeqMax), _tokenStorage(std::move(tokenStorage)) {} + + void Execute() override { + try { + if (static_cast(_tokenStorage.size()) > _nSeqMax) { + SetError("_scoreGroup: input size " + std::to_string(_tokenStorage.size()) + + " exceeds n_seq_max " + std::to_string(_nSeqMax)); + return; + } + + int32_t n_vocab = lloyal::tokenizer::vocab_size(_model); + size_t n = _tokenStorage.size(); + + _logitsStorage.resize(n); + std::vector> spans(n); + std::vector outputs(n); + for (size_t i = 0; i < n; ++i) { + _logitsStorage[i].resize(n_vocab); + spans[i] = _tokenStorage[i]; + outputs[i] = _logitsStorage[i].data(); + } + + lloyal::logits::process_chunks(_ctx, spans, outputs, n_vocab); + } catch (const std::exception& e) { SetError(e.what()); } + } + + void OnOK() override { + Napi::Env env = Env(); + Napi::Array result = Napi::Array::New(env, _logitsStorage.size()); + for (size_t i = 0; i < _logitsStorage.size(); ++i) { + auto buf = Napi::Float32Array::New(env, _logitsStorage[i].size()); + std::memcpy(buf.Data(), _logitsStorage[i].data(), + _logitsStorage[i].size() * sizeof(float)); + result.Set(static_cast(i), buf); + } + _deferred.Resolve(result); + } + + void OnError(const Napi::Error& err) override { _deferred.Reject(err.Value()); } + Napi::Promise GetPromise() { return _deferred.Promise(); } + +private: + Napi::Promise::Deferred _deferred; + llama_context* _ctx; + llama_model* _model; + int32_t _nSeqMax; + std::vector> _tokenStorage; + std::vector> _logitsStorage; +}; + /** * AsyncWorker for JSON schema β†’ GBNF grammar conversion * Pure CPU, no shared state β€” cleanest worker @@ -744,6 +773,7 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) { // ===== PROMPT PREPARATION ===== InstanceMethod("tokenize", &SessionContext::tokenize), + InstanceMethod("tokenizeSync", &SessionContext::tokenizeSync), 
InstanceMethod("detokenize", &SessionContext::detokenize), // ===== KV CACHE MANAGEMENT ===== @@ -763,8 +793,10 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) { // ===== HELPERS ===== InstanceMethod("formatChat", &SessionContext::formatChat), + InstanceMethod("formatChatSync", &SessionContext::formatChatSync), InstanceMethod("parseChatOutput", &SessionContext::parseChatOutput), InstanceMethod("jsonSchemaToGrammar", &SessionContext::jsonSchemaToGrammar), + InstanceMethod("jsonSchemaToGrammarSync", &SessionContext::jsonSchemaToGrammarSync), InstanceMethod("validateChatTemplate", &SessionContext::validateChatTemplate), // ===== EMBEDDING EXTRACTION ===== @@ -796,6 +828,7 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) { InstanceMethod("_branchClearSteer", &SessionContext::_branchClearSteer), InstanceMethod("_branchSetSamplerParams", &SessionContext::_branchSetSamplerParams), InstanceMethod("_branchSetGrammar", &SessionContext::_branchSetGrammar), + InstanceMethod("_branchSetGrammarLazy", &SessionContext::_branchSetGrammarLazy), InstanceMethod("_branchModelEntropy", &SessionContext::_branchModelEntropy), InstanceMethod("_branchModelSurprisal", &SessionContext::_branchModelSurprisal), InstanceMethod("_branchGetSamplingPerplexity", &SessionContext::_branchGetSamplingPerplexity), @@ -807,6 +840,10 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) { InstanceMethod("_storePrefill", &SessionContext::_storePrefill), InstanceMethod("_storeRetainOnly", &SessionContext::_storeRetainOnly), InstanceMethod("_storeAvailable", &SessionContext::_storeAvailable), + InstanceMethod("_storeKvPressure", &SessionContext::_storeKvPressure), + + // ===== SCORING API ===== + InstanceMethod("_scoreGroup", &SessionContext::_scoreGroup), // ===== PROPERTIES ===== InstanceAccessor("vocabSize", &SessionContext::getVocabSize, nullptr), @@ -880,6 +917,38 @@ Napi::Value SessionContext::tokenize(const Napi::CallbackInfo& info) 
{ return worker->GetPromise(); } +Napi::Value SessionContext::tokenizeSync(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 1 || !info[0].IsString()) { + throw Napi::TypeError::New(env, "Expected (text: string[, addSpecial: boolean])"); + } + + std::string text = info[0].As().Utf8Value(); + + bool addSpecial = true; + bool addSpecialOverridden = false; + if (info.Length() >= 2 && info[1].IsBoolean()) { + addSpecial = info[1].As().Value(); + addSpecialOverridden = true; + } + + std::vector result; + if (addSpecialOverridden) { + const llama_vocab* vocab = llama_model_get_vocab(_model.get()); + result = lloyal::tokenizer::tokenize(vocab, text, addSpecial, true); + } else { + result = lloyal::tokenizer::tokenize(_model.get(), text); + } + + Napi::Array jsTokens = Napi::Array::New(env, result.size()); + for (size_t i = 0; i < result.size(); i++) { + jsTokens[i] = Napi::Number::New(env, static_cast(result[i])); + } + return jsTokens; +} + Napi::Value SessionContext::detokenize(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); ensureNotDisposed(); @@ -1051,21 +1120,13 @@ Napi::Value SessionContext::getTurnSeparator(const Napi::CallbackInfo& info) { return result; } -Napi::Value SessionContext::formatChat(const Napi::CallbackInfo& info) { - Napi::Env env = info.Env(); - ensureNotDisposed(); - - if (info.Length() < 1 || !info[0].IsString()) { - throw Napi::TypeError::New(env, "Expected (messagesJson: string[, options: object])"); - } - +// Shared helper: parse JS args into FormatInputs +static lloyal::chat_in::FormatInputs parseFormatChatArgs(const Napi::CallbackInfo& info) { lloyal::chat_in::FormatInputs inputs; inputs.messages_json = info[0].As().Utf8Value(); - // Second argument: options object (or string for backward compat) if (info.Length() >= 2) { if (info[1].IsString()) { - // Backward compat: formatChat(messagesJson, templateOverride) inputs.template_override = info[1].As().Utf8Value(); } 
else if (info[1].IsObject()) { Napi::Object opts = info[1].As(); @@ -1099,12 +1160,76 @@ Napi::Value SessionContext::formatChat(const Napi::CallbackInfo& info) { } } } + return inputs; +} + +// Shared helper: marshal FormatResult β†’ Napi::Object +static Napi::Object marshalFormatResult(Napi::Env env, const lloyal::chat_in::FormatResult& r) { + Napi::Object result = Napi::Object::New(env); + result.Set("prompt", Napi::String::New(env, r.prompt)); + + Napi::Array stopTokens = Napi::Array::New(env, r.additional_stops.size()); + for (size_t i = 0; i < r.additional_stops.size(); i++) { + stopTokens[i] = Napi::String::New(env, r.additional_stops[i]); + } + result.Set("stopTokens", stopTokens); + + result.Set("format", Napi::Number::New(env, static_cast(r.format))); + result.Set("grammar", Napi::String::New(env, r.grammar)); + result.Set("grammarLazy", Napi::Boolean::New(env, r.grammar_lazy)); + result.Set("thinkingForcedOpen", Napi::Boolean::New(env, r.thinking_forced_open)); + result.Set("reasoningFormat", Napi::Number::New(env, static_cast(r.reasoning_format))); + result.Set("parser", Napi::String::New(env, r.parser)); + + Napi::Array triggers = Napi::Array::New(env, r.grammar_triggers.size()); + for (size_t i = 0; i < r.grammar_triggers.size(); i++) { + Napi::Object trigger = Napi::Object::New(env); + trigger.Set("type", Napi::Number::New(env, static_cast(r.grammar_triggers[i].type))); + trigger.Set("value", Napi::String::New(env, r.grammar_triggers[i].value)); + trigger.Set("token", Napi::Number::New(env, static_cast(r.grammar_triggers[i].token))); + triggers[i] = trigger; + } + result.Set("grammarTriggers", triggers); + + Napi::Array preserved = Napi::Array::New(env, r.preserved_tokens.size()); + for (size_t i = 0; i < r.preserved_tokens.size(); i++) { + preserved[i] = Napi::String::New(env, r.preserved_tokens[i]); + } + result.Set("preservedTokens", preserved); + + return result; +} + +Napi::Value SessionContext::formatChat(const Napi::CallbackInfo& info) { + 
Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 1 || !info[0].IsString()) { + throw Napi::TypeError::New(env, "Expected (messagesJson: string[, options: object])"); + } + auto inputs = parseFormatChatArgs(info); auto* worker = new FormatChatWorker(env, _model, inputs); worker->Queue(); return worker->GetPromise(); } +Napi::Value SessionContext::formatChatSync(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 1 || !info[0].IsString()) { + throw Napi::TypeError::New(env, "Expected (messagesJson: string[, options: object])"); + } + + auto inputs = parseFormatChatArgs(info); + lloyal::chat_in::FormatResult result = lloyal::chat_in::format(_model.get(), inputs); + if (result.prompt.empty()) { + throw Napi::Error::New(env, "Chat template formatting failed"); + } + return marshalFormatResult(env, result); +} + Napi::Value SessionContext::kvCacheSize(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); ensureNotDisposed(); @@ -1235,6 +1360,19 @@ Napi::Value SessionContext::jsonSchemaToGrammar(const Napi::CallbackInfo& info) return worker->GetPromise(); } +Napi::Value SessionContext::jsonSchemaToGrammarSync(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 1 || !info[0].IsString()) { + throw Napi::TypeError::New(env, "Expected (schemaJson: string)"); + } + + std::string schemaJson = info[0].As().Utf8Value(); + std::string result = lloyal::grammar::from_json_schema(schemaJson); + return Napi::String::New(env, result); +} + Napi::Value SessionContext::validateChatTemplate(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); ensureNotDisposed(); @@ -1479,6 +1617,42 @@ Napi::Value SessionContext::kvCacheReadFile(const Napi::CallbackInfo& info) { return worker->GetPromise(); } +// ===== SCORING API ===== + +Napi::Value SessionContext::_scoreGroup(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + 
ensureNotDisposed(); + + if (info.Length() < 1 || !info[0].IsArray()) { + throw Napi::Error::New(env, "_scoreGroup requires (tokenArrays: number[][])"); + } + + Napi::Array jsTokenArrays = info[0].As(); + uint32_t n = jsTokenArrays.Length(); + + if (n == 0) { + auto deferred = Napi::Promise::Deferred::New(env); + deferred.Resolve(Napi::Array::New(env, 0)); + return deferred.Promise(); + } + + std::vector> tokenStorage(n); + for (uint32_t i = 0; i < n; i++) { + Napi::Array jsArr = jsTokenArrays.Get(i).As(); + uint32_t len = jsArr.Length(); + tokenStorage[i].resize(len); + for (uint32_t j = 0; j < len; j++) { + tokenStorage[i][j] = static_cast( + jsArr.Get(j).As().Int32Value()); + } + } + + int32_t nSeqMax = static_cast(llama_n_seq_max(_context)); + auto* worker = new ScoreGroupWorker(env, _context, _model.get(), nSeqMax, std::move(tokenStorage)); + worker->Queue(); + return worker->GetPromise(); +} + // ===== FACTORY FUNCTION ===== Napi::Value CreateContext(const Napi::CallbackInfo& info) { @@ -1961,6 +2135,39 @@ Napi::Value SessionContext::_branchSetGrammar(const Napi::CallbackInfo& info) { return env.Undefined(); } +Napi::Value SessionContext::_branchSetGrammarLazy(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 4) { + throw Napi::Error::New(env, "_branchSetGrammarLazy requires (handle, grammarStr, triggerPatterns, triggerTokens)"); + } + + auto handle = static_cast(info[0].As().Uint32Value()); + std::string grammar_str = info[1].As().Utf8Value(); + + // Extract trigger patterns (string[]) + std::vector trigger_patterns; + Napi::Array pArr = info[2].As(); + for (uint32_t i = 0; i < pArr.Length(); i++) { + trigger_patterns.push_back(pArr.Get(i).As().Utf8Value()); + } + + // Extract trigger tokens (number[]) + std::vector trigger_tokens; + Napi::Array tArr = info[3].As(); + for (uint32_t i = 0; i < tArr.Length(); i++) { + trigger_tokens.push_back(static_cast(tArr.Get(i).As().Int32Value())); + } + + 
lloyal::branch::set_grammar_lazy( + handle, _model.get(), grammar_str.c_str(), + trigger_patterns, trigger_tokens, _branchStore + ); + + return env.Undefined(); +} + // ===== BRANCH METRICS & LOGIT BIAS ===== Napi::Value SessionContext::_branchModelEntropy(const Napi::CallbackInfo& info) { @@ -2267,4 +2474,16 @@ Napi::Value SessionContext::_storeAvailable(const Napi::CallbackInfo& info) { return Napi::Number::New(env, static_cast(_branchStore.available())); } +Napi::Value SessionContext::_storeKvPressure(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + auto p = _branchStore.kv_pressure(); + auto obj = Napi::Object::New(env); + obj.Set("nCtx", Napi::Number::New(env, static_cast(p.n_ctx))); + obj.Set("cellsUsed", Napi::Number::New(env, static_cast(p.cells_used))); + obj.Set("remaining", Napi::Number::New(env, static_cast(p.remaining))); + return obj; +} + } // namespace liblloyal_node diff --git a/src/SessionContext.hpp b/src/SessionContext.hpp index 55eb35a..9039406 100644 --- a/src/SessionContext.hpp +++ b/src/SessionContext.hpp @@ -79,12 +79,19 @@ class SessionContext : public Napi::ObjectWrap { // ===== CORE PRIMITIVES ===== /** - * Tokenize text to token IDs + * Tokenize text to token IDs (async β€” dispatches to libuv thread pool) * Args: text (string) * Returns: Promise */ Napi::Value tokenize(const Napi::CallbackInfo& info); + /** + * Tokenize text to token IDs (sync β€” inline on main thread) + * Args: text (string[, addSpecial: boolean]) + * Returns: number[] + */ + Napi::Value tokenizeSync(const Napi::CallbackInfo& info); + /** * Detokenize tokens to text * Args: tokens (number[]) @@ -119,11 +126,19 @@ class SessionContext : public Napi::ObjectWrap { Napi::Value getTurnSeparator(const Napi::CallbackInfo& info); /** - * Format messages using model's chat template + * Format messages using model's chat template (async β€” dispatches to libuv thread pool) * Args: messagesJson (string), templateOverride (optional 
string) * Returns: Promise<{ prompt: string, stopTokens: string[] }> */ Napi::Value formatChat(const Napi::CallbackInfo& info); + + /** + * Format messages using model's chat template (sync β€” inline on main thread) + * Args: messagesJson (string), options? (object) + * Returns: { prompt: string, stopTokens: string[], ... } + */ + Napi::Value formatChatSync(const Napi::CallbackInfo& info); + Napi::Value parseChatOutput(const Napi::CallbackInfo& info); /** @@ -194,9 +209,9 @@ class SessionContext : public Napi::ObjectWrap { Napi::Value kvCacheReadFile(const Napi::CallbackInfo& info); // ===== HELPERS ===== - // Utility functions (not yet implemented) Napi::Value jsonSchemaToGrammar(const Napi::CallbackInfo& info); + Napi::Value jsonSchemaToGrammarSync(const Napi::CallbackInfo& info); Napi::Value validateChatTemplate(const Napi::CallbackInfo& info); // ===== EMBEDDING EXTRACTION ===== @@ -249,6 +264,7 @@ class SessionContext : public Napi::ObjectWrap { Napi::Value _branchClearSteer(const Napi::CallbackInfo& info); Napi::Value _branchSetSamplerParams(const Napi::CallbackInfo& info); Napi::Value _branchSetGrammar(const Napi::CallbackInfo& info); + Napi::Value _branchSetGrammarLazy(const Napi::CallbackInfo& info); Napi::Value _branchModelEntropy(const Napi::CallbackInfo& info); Napi::Value _branchModelSurprisal(const Napi::CallbackInfo& info); Napi::Value _branchGetSamplingPerplexity(const Napi::CallbackInfo& info); @@ -261,6 +277,11 @@ class SessionContext : public Napi::ObjectWrap { Napi::Value _storePrefill(const Napi::CallbackInfo& info); Napi::Value _storeRetainOnly(const Napi::CallbackInfo& info); Napi::Value _storeAvailable(const Napi::CallbackInfo& info); + Napi::Value _storeKvPressure(const Napi::CallbackInfo& info); + + // ===== SCORING API ===== + + Napi::Value _scoreGroup(const Napi::CallbackInfo& info); private: // ===== INTERNAL STATE ===== diff --git a/src/Util.cpp b/src/Util.cpp new file mode 100644 index 0000000..da5539c --- /dev/null +++ 
b/src/Util.cpp @@ -0,0 +1,193 @@ +#include "Util.hpp" +#include +#include +#include +#include +#include + +namespace liblloyal_node { + +struct Section { + std::string heading; + unsigned level = 0; + int start_line = 0; + int end_line = 0; +}; + +struct ParseState { + const char* input; + size_t input_size; + std::vector line_starts; + + int depth = 0; + bool in_heading = false; + unsigned heading_level = 0; + std::string heading_text; + + // Byte offset of the first text seen in the current top-level block + size_t block_first_offset = SIZE_MAX; + + std::vector
sections; + + void build_line_table() { + line_starts.push_back(0); + for (size_t i = 0; i < input_size; i++) { + if (input[i] == '\n') { + line_starts.push_back(i + 1); + } + } + } + + // Binary search: find the 1-indexed line number containing the given byte offset + int line_at(size_t offset) const { + auto it = std::upper_bound(line_starts.begin(), line_starts.end(), offset); + return static_cast(it - line_starts.begin()); + } + + // Line number of the last content line in the input + int last_line() const { + if (input_size == 0) return 0; + size_t last = input_size - 1; + if (input[last] == '\n' && last > 0) last--; + return line_at(last); + } +}; + +// md4c callbacks β€” static functions with C-compatible signatures + +static int on_enter_block(MD_BLOCKTYPE type, void* detail, void* userdata) { + auto* s = static_cast(userdata); + s->depth++; + + // depth==2 means direct child of MD_BLOCK_DOC (top-level block) + if (s->depth == 2) { + s->block_first_offset = SIZE_MAX; + + if (type == MD_BLOCK_H) { + s->in_heading = true; + s->heading_level = static_cast(detail)->level; + s->heading_text.clear(); + } + } + return 0; +} + +static int on_leave_block(MD_BLOCKTYPE type, void* /* detail */, void* userdata) { + auto* s = static_cast(userdata); + + if (s->depth == 2 && type == MD_BLOCK_H && s->block_first_offset != SIZE_MAX) { + int heading_line = s->line_at(s->block_first_offset); + + // Close previous section + if (!s->sections.empty()) { + s->sections.back().end_line = heading_line - 1; + } + + // Start new section at this heading + Section sec; + sec.heading = s->heading_text; + sec.level = s->heading_level; + sec.start_line = heading_line; + s->sections.push_back(sec); + + s->in_heading = false; + } + + s->depth--; + return 0; +} + +static int on_enter_span(MD_SPANTYPE /* type */, void* /* detail */, void* /* userdata */) { + return 0; +} + +static int on_leave_span(MD_SPANTYPE /* type */, void* /* detail */, void* /* userdata */) { + return 0; +} + +static int 
on_text(MD_TEXTTYPE /* type */, const MD_CHAR* text, MD_SIZE size, void* userdata) { + auto* s = static_cast(userdata); + + // Track first text offset for the current top-level block + if (s->depth >= 2 && s->block_first_offset == SIZE_MAX) { + s->block_first_offset = static_cast(text - s->input); + } + + // Accumulate heading text + if (s->in_heading) { + s->heading_text.append(text, size); + } + + return 0; +} + +// N-API entry points + +void Util::Init(Napi::Env env, Napi::Object exports) { + exports.Set("parseMarkdown", Napi::Function::New(env, ParseMarkdown)); +} + +Napi::Value Util::ParseMarkdown(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + + if (info.Length() < 1 || !info[0].IsString()) { + Napi::TypeError::New(env, "parseMarkdown expects a string argument") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + + std::string input = info[0].As().Utf8Value(); + + // Empty input β†’ empty result + if (input.empty()) { + return Napi::Array::New(env, 0); + } + + ParseState state; + state.input = input.c_str(); + state.input_size = input.size(); + state.build_line_table(); + + // Preamble: content before the first heading + Section preamble; + preamble.start_line = 1; + state.sections.push_back(preamble); + + MD_PARSER parser = {}; + parser.abi_version = 0; + parser.flags = MD_FLAG_TABLES | MD_FLAG_STRIKETHROUGH; + parser.enter_block = on_enter_block; + parser.leave_block = on_leave_block; + parser.enter_span = on_enter_span; + parser.leave_span = on_leave_span; + parser.text = on_text; + + md_parse(input.c_str(), static_cast(input.size()), &parser, &state); + + // Close last section + if (!state.sections.empty()) { + state.sections.back().end_line = state.last_line(); + } + + // Remove empty sections (startLine > endLine) + state.sections.erase( + std::remove_if(state.sections.begin(), state.sections.end(), + [](const Section& sec) { return sec.start_line > sec.end_line; }), + state.sections.end()); + + // Build N-API result + 
Napi::Array result = Napi::Array::New(env, state.sections.size()); + for (size_t i = 0; i < state.sections.size(); i++) { + const auto& sec = state.sections[i]; + Napi::Object obj = Napi::Object::New(env); + obj.Set("heading", sec.heading); + obj.Set("level", static_cast(sec.level)); + obj.Set("startLine", static_cast(sec.start_line)); + obj.Set("endLine", static_cast(sec.end_line)); + result.Set(static_cast(i), obj); + } + + return result; +} + +} // namespace liblloyal_node diff --git a/src/Util.hpp b/src/Util.hpp new file mode 100644 index 0000000..2ebbc5c --- /dev/null +++ b/src/Util.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace liblloyal_node { + +class Util { +public: + static void Init(Napi::Env env, Napi::Object exports); + +private: + static Napi::Value ParseMarkdown(const Napi::CallbackInfo& info); +}; + +} // namespace liblloyal_node diff --git a/src/agents/Tool.ts b/src/agents/Tool.ts new file mode 100644 index 0000000..20c34d8 --- /dev/null +++ b/src/agents/Tool.ts @@ -0,0 +1,76 @@ +import type { JsonSchema, ToolSchema, ToolContext } from './types'; + +/** + * Abstract base class for tools usable by agents in the runtime + * + * Subclass to define tools that agents can invoke during generation. + * Implement `name`, `description`, `parameters`, and `execute()`. The + * {@link schema} getter auto-generates the OpenAI-compatible function + * schema expected by `formatChat()`. + * + * Pass tool instances to {@link createToolkit} to build the `toolMap` + * and `toolsJson` pair consumed by {@link useAgentPool} and + * {@link runAgents}. 
+ * + * @example Search tool + * ```typescript + * class SearchTool extends Tool<{ query: string; topK?: number }> { + * readonly name = 'search'; + * readonly description = 'Search the corpus for relevant passages'; + * readonly parameters = { + * type: 'object', + * properties: { + * query: { type: 'string', description: 'Search query' }, + * topK: { type: 'number', description: 'Number of results' }, + * }, + * required: ['query'], + * }; + * + * async execute(args: { query: string; topK?: number }, ctx?: ToolContext) { + * const results = await this.reranker.rank(args.query, args.topK ?? 5); + * return { results }; + * } + * } + * ``` + * + * @category Agents + */ +export abstract class Tool> { + /** Tool name β€” used as the function identifier in tool calls */ + abstract readonly name: string; + /** Human-readable description shown to the model */ + abstract readonly description: string; + /** JSON Schema describing the tool's expected arguments */ + abstract readonly parameters: JsonSchema; + + /** + * Execute the tool with parsed arguments + * + * Called by the agent pool when the model emits a tool call matching + * this tool's name. The return value is JSON-serialized and prefilled + * back into the agent's context as a tool result. + * + * @param args - Parsed arguments from the model's tool call + * @param context - Execution context with progress reporting callback + * @returns Tool result (will be JSON-serialized) + */ + abstract execute(args: TArgs, context?: ToolContext): Promise; + + /** + * OpenAI-compatible function tool schema + * + * Auto-generated from `name`, `description`, and `parameters`. + * Used by {@link createToolkit} to build the JSON string passed + * to `formatChat()`. 
+ */ + get schema(): ToolSchema { + return { + type: 'function', + function: { + name: this.name, + description: this.description, + parameters: this.parameters, + }, + }; + } +} diff --git a/src/agents/agent-pool.ts b/src/agents/agent-pool.ts new file mode 100644 index 0000000..7ff9316 --- /dev/null +++ b/src/agents/agent-pool.ts @@ -0,0 +1,586 @@ +import { resource, call, action, ensure, useScope, createSignal, spawn, each } from 'effection'; +import type { Operation, Scope, Channel } from 'effection'; +import type { Branch } from '../Branch'; +import { CHAT_FORMAT_CONTENT_ONLY, CHAT_FORMAT_GENERIC, GrammarTriggerType, type GrammarTrigger, type ParsedToolCall, type SessionContext } from '../types'; +import type { BranchStore } from '../BranchStore'; +import { Ctx, Store, Events } from './context'; +import { buildToolResultDelta } from './deltas'; +import type { + TraceToken, + PressureThresholds, + AgentTaskSpec, + AgentPoolOptions, + AgentPoolResult, + AgentEvent, +} from './types'; + +// ── Internal agent state machine ─────────────────────────────── +// generating β†’ awaiting_tool β†’ generating (tool result prefilled) +// generating β†’ done (stop + no tool call, or report) +// awaiting_tool β†’ done (tool error) + +type AgentInternalState = 'generating' | 'awaiting_tool' | 'done'; + +interface AgentInternal { + id: number; // = branch.handle + parentId: number; // = parent.handle + branch: Branch; + state: AgentInternalState; + fmt: { + format: number; + reasoningFormat: number; + thinkingForcedOpen: boolean; + parser: string; + grammar: string; + grammarLazy: boolean; + grammarTriggers: GrammarTrigger[]; + }; + rawOutput: string; + tokenCount: number; + toolCallCount: number; + turns: number; + findings: string | null; + traceBuffer: TraceToken[]; +} + +interface SettledTool { + agentId: number; + prefillTokens: number[]; + toolName: string; +} + +/** + * Immutable KV budget snapshot for one tick of the agent loop + * + * Created from 
`SessionContext._storeKvPressure()` which returns + * `{ nCtx, cellsUsed, remaining }` where `remaining = nCtx - cellsUsed`. + * `cellsUsed` is a monotonic counter in `BranchStore` β€” it increments on + * every `decode_each` / `decode_scatter` but does **not** decrement on + * individual branch prune (only resets on bulk ops like `retainOnly` and + * `drain`). This means `remaining` is a conservative lower bound that + * becomes increasingly pessimistic as branches are pruned mid-run. + * + * Two thresholds partition `remaining` into three zones: + * + * ``` + * β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + * β”‚ nCtx β”‚ + * β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ + * β”‚ β”‚cellsUsed β”‚ headroom > 0 β”‚ softLimit β”‚ β”‚ + * β”‚ β”‚ (in use) β”‚ (new work OK) β”‚ (reserved) β”‚ β”‚ + * β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ + * β”‚ ◄── remaining ──► β”‚ β”‚ + * β”‚ β”‚ β”‚ + * β”‚ headroom = remaining - softLimit β”‚ + * β”‚ critical = remaining < hardLimit β”‚ + * β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + * ``` + * + * - **headroom > 0** β€” room for new work (tool results, generation) + * - **headroom ≀ 0** β€” over budget. SETTLE rejects tool results, PRODUCE + * hard-cuts non-terminal tool calls. Terminal tools still pass. + * - **critical** β€” remaining below hardLimit. Agents killed before + * `produceSync()` to prevent llama_decode crashes. 
+ * + * @category Agents + */ +export class ContextPressure { + /** Default softLimit: 1024 tokens reserved for downstream work */ + static readonly DEFAULT_SOFT_LIMIT = 1024; + /** Default hardLimit: 128 tokens crash-prevention floor */ + static readonly DEFAULT_HARD_LIMIT = 128; + + /** + * KV slots remaining (`nCtx - cellsUsed`). + * Infinity when nCtx ≀ 0 (no context limit). + * Conservative: may undercount actual free space when branches have been + * pruned, since `cellsUsed` is monotonic. + */ + readonly remaining: number; + /** Remaining KV floor β€” tokens reserved for downstream work */ + readonly softLimit: number; + /** Crash-prevention floor β€” agents killed when remaining drops below */ + readonly hardLimit: number; + + constructor(ctx: SessionContext, opts?: PressureThresholds) { + const p = ctx._storeKvPressure(); + this.remaining = p.nCtx <= 0 ? Infinity : p.remaining; + this.softLimit = opts?.softLimit ?? ContextPressure.DEFAULT_SOFT_LIMIT; + this.hardLimit = opts?.hardLimit ?? ContextPressure.DEFAULT_HARD_LIMIT; + } + + /** + * Tokens available for new work: `remaining - softLimit`. + * Positive means room to accept tool results or continue generating. + * Negative means over budget β€” SETTLE rejects, PRODUCE hard-cuts. + */ + get headroom(): number { return this.remaining - this.softLimit; } + + /** `remaining < hardLimit` β€” agent must not call `produceSync()`. */ + get critical(): boolean { return this.remaining < this.hardLimit; } + + /** Can `tokenCount` tokens fit while staying above softLimit? */ + canFit(tokenCount: number): boolean { return tokenCount <= this.headroom; } +} + +/** + * Fork an agent from a parent branch with its own system prompt and task. + * + * Generator β€” uses sync native calls so Effection sees everything. + * On scope exit (error, cancellation), `ensure()` prunes the branch + * automatically β€” the orphaned-branch leak is structurally impossible. 
+ */ +function* setupAgent( + parent: Branch, + task: AgentTaskSpec, + ctx: SessionContext, +): Operation<{ agent: AgentInternal; suffixTokens: number[] }> { + const messages = [ + { role: 'system', content: task.systemPrompt }, + { role: 'user', content: task.content }, + ]; + const fmtOpts = task.tools ? { tools: task.tools } : {}; + const fmt = ctx.formatChatSync(JSON.stringify(messages), fmtOpts); + if (task.tools && (fmt.format === CHAT_FORMAT_CONTENT_ONLY || fmt.format === CHAT_FORMAT_GENERIC)) { + // Error before fork β€” no branch to clean up + throw new Error('Model does not support tool calling. Please use a model with native tool support (e.g. Qwen3, Llama 3.x, Mistral).'); + } + const branch = parent.forkSync(); + yield* ensure(() => { if (!branch.disposed) branch.pruneSync(); }); + const sep = ctx.getTurnSeparator(); + const suffixTokens = [...sep, ...ctx.tokenizeSync(fmt.prompt, false)]; + if (task.seed != null) branch.reseedSampler(task.seed); + + return { + agent: { + id: branch.handle, + parentId: parent.handle, + branch, + state: 'generating', + fmt: { + format: fmt.format, + reasoningFormat: fmt.reasoningFormat, + thinkingForcedOpen: fmt.thinkingForcedOpen, + parser: fmt.parser, + grammar: fmt.grammar, + grammarLazy: fmt.grammarLazy, + grammarTriggers: fmt.grammarTriggers, + }, + rawOutput: '', + tokenCount: 0, + toolCallCount: 0, + turns: 0, + findings: null, + traceBuffer: [], + }, + suffixTokens, + }; +} + +/** + * Concurrent agent generation loop as an Effection resource + * + * Runs N agents in parallel using a three-phase tick loop over shared + * {@link BranchStore} infrastructure. Each agent forks from a parent + * branch, generates tokens, invokes tools, and reports findings. + * + * **Three-phase tick loop:** + * 1. **PRODUCE** β€” sample all active agents via `produceSync()` (no async gap) + * 2. **COMMIT** β€” single GPU call via `store.commit()` for all produced tokens + * 3. 
**SETTLE** β€” drain settled tool results, batch prefill, reset grammars + * + * Tool dispatch uses `scope.run()` for eager start β€” tool executions run as + * children of the agent pool scope and are cancelled if the scope exits. + * + * **Resource semantics:** `provide()` suspends after all agents complete, + * keeping branches alive so the caller can fork from them (e.g. for + * verification). Branches are pruned when the scope exits β€” each branch's + * `ensure()` from `setupAgent` handles cleanup automatically. + * + * For automatic branch cleanup on return, use {@link runAgents} instead. + * + * @param opts - Pool configuration: tasks, tools, sampling params, max turns + * @returns Agent pool result with per-agent findings and aggregate statistics + * + * @example Shared root with agent pool + * ```typescript + * const pool = yield* withSharedRoot( + * { systemPrompt: RESEARCH_PROMPT, tools: toolsJson }, + * function*(root) { + * return yield* useAgentPool({ + * tasks: questions.map(q => ({ + * systemPrompt: RESEARCH_PROMPT, + * content: q, + * tools: toolsJson, + * parent: root, + * })), + * tools: toolMap, + * maxTurns: 6, + * }); + * }, + * ); + * ``` + * + * @category Agents + */ +export function useAgentPool(opts: AgentPoolOptions): Operation { + return resource(function*(provide) { + const ctx: SessionContext = yield* Ctx.expect(); + const store: BranchStore = yield* Store.expect(); + const events: Channel = yield* Events.expect(); + const scope: Scope = yield* useScope(); + + // Bridge for onProgress callbacks β€” Signal is correct here (external callback). + // A spawned forwarder drains the bridge into the Channel with proper scope context. 
+ const progressBridge = createSignal(); + yield* spawn(function*() { + for (const ev of yield* each(progressBridge)) { + yield* events.send(ev); + yield* each.next(); + } + }); + const { tasks, tools, maxTurns = 100, terminalTool, trace = false, pressure: pressureOpts } = opts; + + // Whether the pool's tool registry contains tools besides the terminal tool. + // When false, agents are allowed to call the terminal tool as their first + // action (e.g. reporter sub-agents that only have `report()`). When true, + // the first tool call must be a non-terminal tool to prevent agents from + // immediately reporting without doing any work. + // + // IMPORTANT: this checks the pool's `tools` registry, not individual task + // schemas (`task.tools`). A reporter pool must pass only the terminal tool + // in its registry β€” passing the full tool map makes this flag true and + // traps reporters in an infinite rejection loop. + const hasNonTerminalTools = terminalTool ? [...tools.keys()].some(k => k !== terminalTool) : tools.size > 0; + + // ── Setup: fork branches, collect suffix tokens ────────── + // setupAgent is now a generator β€” each branch registers its own ensure() + // for cleanup. No manual try/finally needed here. + const agents: AgentInternal[] = []; + const prefillSetup: [Branch, number[]][] = []; + + for (const task of tasks) { + const parent = task.parent; + if (!parent) throw new Error('useAgentPool: each task must have a parent branch'); + + const { agent, suffixTokens } = yield* setupAgent(parent, task, ctx); + agents.push(agent); + prefillSetup.push([agent.branch, suffixTokens]); + } + + // Batch prefill all agent suffixes β€” pressure-gated. + // Each suffix is the full formatted chat (system prompt + tools JSON + + // user message + generation prompt), tokenized via formatChatSync(). + // Suffix cost is model-dependent: ~250-400 tokens per agent depending + // on chat template verbosity and tool schema size. 
+ const initPressure = new ContextPressure(ctx, pressureOpts); + const totalSuffix = prefillSetup.reduce((s, [, t]) => s + t.length, 0); + if (!initPressure.canFit(totalSuffix)) { + // Not enough room β€” drop agents from the end until it fits + while (prefillSetup.length > 0) { + const needed = prefillSetup.reduce((s, [, t]) => s + t.length, 0); + if (initPressure.canFit(needed)) break; + prefillSetup.pop(); + const dropped = agents.pop()!; + dropped.state = 'done'; + } + } + if (prefillSetup.length > 0) { + yield* call(() => store.prefill(prefillSetup)); + } + + // Emit spawn events β€” TUI uses parentAgentId to detect sub-agents + for (const a of agents) { + yield* events.send({ type: 'agent:spawn', agentId: a.id, parentAgentId: a.parentId }); + } + + // ── Lazy grammar setup ─────────────────────────────────── + const applyLazyGrammar = (a: AgentInternal): void => { + if (a.fmt.grammar && a.fmt.grammarLazy && a.fmt.grammarTriggers.length > 0) { + const triggers = a.fmt.grammarTriggers.map(t => { + if (t.type === GrammarTriggerType.WORD) { + const nlIdx = t.value.indexOf('\n'); + if (nlIdx >= 0 && nlIdx < t.value.length - 1) { + return { ...t, value: t.value.slice(0, nlIdx + 1) }; + } + } + return t; + }); + a.branch.setGrammarLazy(a.fmt.grammar, triggers); + } + }; + for (const a of agents) applyLazyGrammar(a); + + // ── Tool dispatch coordination ─────────────────────────── + // Plain JS buffer: spawned tool tasks push synchronously on completion. + // SETTLE drains with splice(0). Safe because generators are synchronous + // between yields β€” spawns can only push at yield points (during COMMIT's + // yield* call()), and SETTLE runs after COMMIT in the same tick. 
+ const settledBuffer: SettledTool[] = []; + const agentById = new Map(agents.map(a => [a.id, a])); + + // Track pending tool count for idle detection + let pendingToolCount = 0; + + // Resolve function for idle wake β€” set when all agents stall + let wakeIdle: (() => void) | null = null; + + let steps = 0; + let totalToolCalls = 0; + const counters = { + warmPrefillCalls: 0, + warmPrefillBranches: 0, + stalledTicks: 0, + maxConcurrentTools: 0, + idleTicks: 0, + }; + + function* dispatchTool(agent: AgentInternal, tc: ParsedToolCall): Operation { + let toolArgs: Record; + try { toolArgs = JSON.parse(tc.arguments); } catch { toolArgs = {}; } + const callId = tc.id || `call_${agent.toolCallCount}`; + + agent.toolCallCount++; + totalToolCalls++; + agent.turns++; + agent.state = 'awaiting_tool'; + + yield* events.send({ type: 'agent:tool_call', agentId: agent.id, tool: tc.name, args: tc.arguments }); + + const tool = tools.get(tc.name); + pendingToolCount++; + counters.maxConcurrentTools = Math.max(counters.maxConcurrentTools, pendingToolCount); + + // scope.run() β€” eager start, child of agent pool scope, cancelled if scope exits. + // spawn() is lazy (Operation), but we're in a generator β€” scope.run() is eager. + scope.run(function*() { + try { + const toolContext = { + onProgress: (p: { filled: number; total: number }) => { + // Signal bridge β€” onProgress is an external callback, Signal.send() is correct here. + progressBridge.send({ type: 'agent:tool_progress', agentId: agent.id, tool: tc.name, filled: p.filled, total: p.total }); + }, + }; + + const result: unknown = yield* call(() => + tool ? 
tool.execute(toolArgs, toolContext) : Promise.resolve({ error: `Unknown tool: ${tc.name}` }) + ); + const resultStr = JSON.stringify(result); + yield* events.send({ type: 'agent:tool_result', agentId: agent.id, tool: tc.name, result: resultStr }); + + const prefillTokens = buildToolResultDelta(ctx, resultStr, callId); + settledBuffer.push({ agentId: agent.id, prefillTokens, toolName: tc.name }); + } catch (err) { + agent.state = 'done'; + agent.findings = `Tool error: ${(err as Error).message}`; + } finally { + pendingToolCount--; + if (wakeIdle) { wakeIdle(); wakeIdle = null; } + } + }); + } + + // ── Three-phase tick loop ──────────────────────────────── + for (;;) { + // -- Phase 1: PRODUCE -- sample from active agents + const pressure = new ContextPressure(ctx, pressureOpts); + + if (trace && (pressure.critical || pressure.headroom < 0)) { + const p = ctx._storeKvPressure(); + try { process.stderr.write(`[PRODUCE] ${pressure.critical ? 'CRITICAL' : 'SOFT_LIMIT'} remaining=${p.remaining} headroom=${pressure.headroom} cellsUsed=${p.cellsUsed} nCtx=${p.nCtx}\n`); } catch {} + } + + const entries: [Branch, number][] = []; + for (const a of agents) { + if (a.state !== 'generating') continue; + + if (pressure.critical) { + a.state = 'done'; + yield* events.send({ type: 'agent:done', agentId: a.id }); + continue; + } + + const { token, text, isStop } = a.branch.produceSync(); + if (isStop) { + const parsed = ctx.parseChatOutput(a.rawOutput, a.fmt.format, { + reasoningFormat: a.fmt.reasoningFormat, + thinkingForcedOpen: a.fmt.thinkingForcedOpen, + parser: a.fmt.parser, + }); + + const tc = parsed.toolCalls[0]; + if (!tc) { + a.state = 'done'; + if (!a.findings && a.toolCallCount > 0 && parsed.content) { + a.findings = parsed.content; + yield* events.send({ type: 'agent:report', agentId: a.id, findings: a.findings }); + } + yield* events.send({ type: 'agent:done', agentId: a.id }); + continue; + } + + // Over budget: deny non-terminal tool calls when the agent has + // 
exceeded maxTurns or KV headroom is negative. Terminal tools + // (e.g. `report()`) are always allowed through β€” an agent that has + // done research and wants to report should never be blocked by + // pressure, since the report call itself consumes minimal KV. + const overBudget = (a.turns >= maxTurns || pressure.headroom < 0) + && (!terminalTool || tc.name !== terminalTool); + + if (overBudget) { + a.state = 'done'; + yield* events.send({ type: 'agent:done', agentId: a.id }); + continue; + } + + // Terminal tool β€” intercept, extract findings, mark done. + if (terminalTool && tc.name === terminalTool) { + if (a.toolCallCount === 0 && hasNonTerminalTools) { + const callId = tc.id || `call_${a.toolCallCount}`; + const errorMsg = 'You must perform research before reporting. Call at least one tool first.'; + a.turns++; + a.state = 'awaiting_tool'; + pendingToolCount++; + scope.run(function*() { + try { + const prefillTokens = buildToolResultDelta(ctx, JSON.stringify({ error: errorMsg }), callId); + settledBuffer.push({ agentId: a.id, prefillTokens, toolName: tc.name }); + } finally { + pendingToolCount--; + if (wakeIdle) { wakeIdle(); wakeIdle = null; } + } + }); + a.rawOutput = ''; + continue; + } + try { a.findings = JSON.parse(tc.arguments).findings; } catch { a.findings = tc.arguments; } + a.state = 'done'; + a.toolCallCount++; + totalToolCalls++; + yield* events.send({ type: 'agent:tool_call', agentId: a.id, tool: tc.name, args: tc.arguments }); + yield* events.send({ type: 'agent:report', agentId: a.id, findings: a.findings! 
}); + yield* events.send({ type: 'agent:done', agentId: a.id }); + continue; + } + + // Fire-and-forget β€” dispatch tool without blocking the decode loop + yield* dispatchTool(a, tc); + a.rawOutput = ''; + continue; + } + + entries.push([a.branch, token]); + a.rawOutput += text; + a.tokenCount++; + if (trace) { + const entropy = a.branch.modelEntropy(); + const surprisal = a.branch.modelSurprisal(token); + a.traceBuffer.push({ text, entropy, surprisal }); + yield* events.send({ + type: 'agent:produce', agentId: a.id, text, tokenCount: a.tokenCount, + entropy, surprisal, + }); + } else { + yield* events.send({ type: 'agent:produce', agentId: a.id, text, tokenCount: a.tokenCount }); + } + } + + // -- Phase 2: COMMIT -- batch-decode produced tokens + if (entries.length > 0) { + yield* call(() => store.commit(entries)); + steps++; + } + + // -- Phase 3: SETTLE -- drain settled tool buffer, batch prefill + const settled = settledBuffer.splice(0); + if (settled.length > 0) { + // Fresh snapshot β€” Phase 2 commits may have advanced positions + const settlePressure = new ContextPressure(ctx, pressureOpts); + let headroom = settlePressure.headroom; + + if (trace) { + const p = ctx._storeKvPressure(); + const items = settled.map(s => `${s.toolName}:${s.prefillTokens.length}`).join(', '); + try { process.stderr.write(`[SETTLE] remaining=${p.remaining} headroom=${headroom} cellsUsed=${p.cellsUsed} nCtx=${p.nCtx} items=[${items}]\n`); } catch {} + } + + const prefillPairs: [Branch, number[]][] = []; + const settledAgents: AgentInternal[] = []; + + for (const item of settled) { + const a = agentById.get(item.agentId); + if (!a || a.state === 'done') continue; + + if (item.prefillTokens.length > headroom) { + if (trace) { + try { process.stderr.write(`[SETTLE] REJECT ${item.toolName}:${item.prefillTokens.length} > headroom=${headroom}\n`); } catch {} + } + a.state = 'done'; + yield* events.send({ type: 'agent:done', agentId: a.id }); + continue; + } + + 
prefillPairs.push([a.branch, item.prefillTokens]); + settledAgents.push(a); + headroom -= item.prefillTokens.length; + } + + if (prefillPairs.length > 0) { + if (trace) { + const totalPrefill = prefillPairs.reduce((s, [, t]) => s + t.length, 0); + try { process.stderr.write(`[SETTLE] PREFILL ${prefillPairs.length} branches, ${totalPrefill} tokens, headroom_after=${headroom}\n`); } catch {} + } + yield* call(() => store.prefill(prefillPairs)); + counters.warmPrefillCalls++; + counters.warmPrefillBranches += prefillPairs.length; + + // Only NOW transition state + reset grammar + for (const a of settledAgents) { + a.state = 'generating'; + a.rawOutput = ''; + applyLazyGrammar(a); + } + } + } + + // -- Termination + idle yield + const allDone = agents.every(a => a.state === 'done') && pendingToolCount === 0; + if (allDone) break; + + if (entries.length === 0 && pendingToolCount > 0) { + counters.stalledTicks++; + if (settled.length === 0) { + // Nothing produced, nothing settled β€” yield until a tool resolves + yield* action((resolve) => { + wakeIdle = resolve; + return () => { wakeIdle = null; }; + }); + counters.idleTicks++; + } + } + } + + // ── Provide result β€” suspends, branches stay alive ─────── + // Branch cleanup is handled by each branch's ensure() from setupAgent β€” + // when this resource's scope exits, all ensure() callbacks fire. + const result: AgentPoolResult = { + agents: agents.map(a => ({ + agentId: a.id, + parentAgentId: a.parentId, + branch: a.branch, + findings: a.findings, + toolCallCount: a.toolCallCount, + tokenCount: a.tokenCount, + ppl: a.branch.perplexity, + samplingPpl: a.branch.samplingPerplexity, + trace: trace ? 
a.traceBuffer : undefined, + })), + totalTokens: agents.reduce((s, a) => s + a.tokenCount, 0), + totalToolCalls, + steps, + counters, + }; + + yield* provide(result); + }); +} diff --git a/src/agents/context.ts b/src/agents/context.ts new file mode 100644 index 0000000..3fc593a --- /dev/null +++ b/src/agents/context.ts @@ -0,0 +1,36 @@ +import { createContext } from 'effection'; +import type { SessionContext } from '../types'; +import type { BranchStore } from '../BranchStore'; +import type { Channel } from 'effection'; +import type { AgentEvent } from './types'; + +/** + * Effection context holding the active {@link SessionContext} + * + * Set by {@link initAgents} in the caller's scope. All agent operations + * (`generate`, `diverge`, `useAgentPool`, `withSharedRoot`) read from this + * context via `yield* Ctx.expect()`. + * + * @category Agents + */ +export const Ctx = createContext('lloyal.ctx'); + +/** + * Effection context holding the active {@link BranchStore} + * + * Set by {@link initAgents}. Used by {@link diverge} and {@link useAgentPool} + * for batched commit/prefill across multiple branches. + * + * @category Agents + */ +export const Store = createContext('lloyal.store'); + +/** + * Effection context holding the agent event channel + * + * Set by {@link initAgents}. {@link useAgentPool} emits {@link AgentEvent} + * values through this channel via `yield* channel.send()`. + * + * @category Agents + */ +export const Events = createContext>('lloyal.events'); diff --git a/src/agents/deltas.ts b/src/agents/deltas.ts new file mode 100644 index 0000000..baf12d0 --- /dev/null +++ b/src/agents/deltas.ts @@ -0,0 +1,63 @@ +import type { SessionContext } from '../types'; + +/** + * Build a token delta for a user turn + * + * Composes `getTurnSeparator()` + `formatChatSync()` + `tokenizeSync()` into a + * single token array suitable for `branch.prefill()`. Usable with any + * branch β€” not tied to {@link Session}'s trunk. 
+ * + * This is the canonical way to build a user-turn delta for warm prefill + * in multi-turn conversations. + * + * @param ctx - Active session context + * @param content - User message content + * @param opts - Optional tools JSON for tool-aware formatting + * @returns Token array ready for `branch.prefill()` + * + * @category Agents + */ +export function buildUserDelta( + ctx: SessionContext, + content: string, + opts: { tools?: string } = {} +): number[] { + const sep = ctx.getTurnSeparator(); + const fmtOpts = opts.tools ? { tools: opts.tools } : {}; + const { prompt } = ctx.formatChatSync( + JSON.stringify([{ role: 'system', content: '' }, { role: 'user', content }]), + fmtOpts + ); + const delta = ctx.tokenizeSync(prompt, false); + return [...sep, ...delta]; +} + +/** + * Build a token delta for a tool result turn + * + * Composes `getTurnSeparator()` + `formatChatSync()` + `tokenizeSync()` into a + * single token array suitable for `branch.prefill()`. Used by + * {@link useAgentPool} to inject tool results back into agent context. 
/**
 * Build the token delta that appends one tool-result turn to a branch.
 *
 * The result is the turn separator from `ctx.getTurnSeparator()` followed by
 * the tokens of a chat-formatted message list holding an empty system
 * message and a `tool` message carrying `resultStr` and `callId`.
 * `formatChatSync` is called without an options argument, and the prompt is
 * tokenized with the second `tokenizeSync` argument set to `false`.
 *
 * @param ctx - Active session context
 * @param resultStr - Serialized tool result, used verbatim as message content
 * @param callId - Tool call identifier, sent as `tool_call_id`
 * @returns Separator tokens followed by the formatted turn's tokens, ready
 *   for `branch.prefill()`
 *
 * @category Agents
 */
export function buildToolResultDelta(
  ctx: SessionContext,
  resultStr: string,
  callId: string
): number[] {
  const separatorTokens = ctx.getTurnSeparator();

  const messages = [
    { role: 'system', content: '' },
    { role: 'tool', content: resultStr, tool_call_id: callId },
  ];
  const formatted = ctx.formatChatSync(JSON.stringify(messages));

  const turnTokens = ctx.tokenizeSync(formatted.prompt, false);
  return [...separatorTokens, ...turnTokens];
}
/**
 * Multi-branch perplexity selection as an Effection operation
 *
 * Forks `opts.attempts` branches from `opts.parent` (or from a fresh root
 * prefilled with `opts.prompt`), generates each one to a stop token using
 * batched {@link BranchStore.commit}, then selects the attempt with the
 * lowest perplexity. Loser branches are pruned before returning; the winning
 * branch is returned un-pruned.
 *
 * When `opts.parent` is provided it is never pruned here — only the forked
 * attempt branches are. When the root is created here it is pruned before
 * returning only if it has no remaining children.
 *
 * Cleanup: the owned root and every forked branch register an `ensure()`
 * callback that prunes the branch if it is still un-disposed when the
 * callback runs. The callbacks do not special-case the winner.
 * NOTE(review): this implies the winner is pruned at scope teardown unless
 * the caller has disposed or otherwise taken it over first (e.g. via
 * {@link Session.promote}) — confirm the intended winner lifetime.
 *
 * Generation stops early when {@link ContextPressure} reports `critical`;
 * attempts that had not reached a stop token keep `ppl = Infinity`. When
 * perplexities tie, the later attempt wins.
 *
 * @param opts - Diverge options specifying parent or prompt, attempt count,
 *   and sampling parameters
 * @returns Result containing the best branch, all attempt outputs, and
 *   aggregate statistics
 * @throws Error if `opts.attempts` is not at least 1, or if neither
 *   `opts.parent` nor `opts.prompt` is provided
 *
 * @example Verify with perplexity selection
 * ```typescript
 * const verified = yield* diverge({
 *   prompt: verifyPrompt,
 *   attempts: 3,
 *   params: { temperature: 0.7 },
 * });
 * // verified.best is the lowest-perplexity branch, still alive
 * yield* call(() => session.promote(verified.best));
 * ```
 *
 * @category Agents
 */
export function* diverge(opts: DivergeOptions): Operation<DivergeResult> {
  const ctx = yield* Ctx.expect();
  const store = yield* Store.expect();

  // With no attempts there is no "best" branch to return; fail up front with
  // a clear message instead of dereferencing an empty list later. The negated
  // comparison also rejects NaN.
  if (!(opts.attempts >= 1)) {
    throw new Error('diverge() requires opts.attempts >= 1');
  }

  // If parent provided, fork from it. Otherwise create a fresh root.
  let root: Branch;
  let ownRoot = false;
  let prefixLength: number;

  if (opts.parent) {
    root = opts.parent;
    prefixLength = root.position;
  } else {
    if (!opts.prompt) throw new Error('diverge() requires either opts.parent or opts.prompt');
    const tokens = ctx.tokenizeSync(opts.prompt);
    root = Branch.create(ctx, 0, opts.params ?? {});
    ownRoot = true;
    // Register cleanup BEFORE prefill so a failed or cancelled prefill does
    // not leak the freshly created root.
    yield* ensure(() => {
      if (ownRoot && !root.disposed) {
        try { root.pruneSync(); } catch { /* children may remain */ }
      }
    });
    yield* call(() => root.prefill(tokens));
    prefixLength = tokens.length;
  }

  const live: { branch: Branch; output: string; done: boolean; tokenCount: number; ppl: number }[] = [];

  for (let i = 0; i < opts.attempts; i++) {
    const branch = root.forkSync();
    // Each forked branch gets its own ensure() for structured cleanup
    yield* ensure(() => {
      if (!branch.disposed) {
        try { branch.pruneSync(); } catch { /* already gone */ }
      }
    });
    // Distinct seed per attempt so the forks do not sample identically.
    branch.reseedSampler(2000 + i);
    live.push({ branch, output: '', done: false, tokenCount: 0, ppl: Infinity });
  }

  // Batched generation — produceSync/commit loop
  let steps = 0;
  for (;;) {
    // Re-sampled every step: stop all attempts once pressure is critical.
    const pressure = new ContextPressure(ctx);
    if (pressure.critical) {
      for (const a of live) { if (!a.done) a.done = true; }
      break;
    }

    const entries: [Branch, number][] = [];
    for (const a of live) {
      if (a.done) continue;
      const { token, text, isStop } = a.branch.produceSync();
      if (isStop) {
        const p = a.branch.perplexity;
        a.ppl = Number.isFinite(p) ? p : Infinity;
        a.done = true;
        continue;
      }
      entries.push([a.branch, token]);
      a.output += text;
      a.tokenCount++;
    }
    if (entries.length === 0) break;
    yield* call(() => store.commit(entries));
    steps++;
  }

  // Select by lowest perplexity (most coherent); `<=` lets later attempts win ties.
  const bestIdx = live.reduce((bi, a, i) => a.ppl <= live[bi].ppl ? i : bi, 0);

  // Prune losers now — winner stays alive as caller's result.
  // ensure() will be a no-op for these since they're already disposed.
  for (let i = 0; i < live.length; i++) {
    if (i !== bestIdx && !live[i].branch.disposed) {
      live[i].branch.pruneSync();
    }
  }

  // If we created root and it's no longer needed, prune it now.
  // (ensure() will be a no-op since it checks disposed)
  if (ownRoot && !root.disposed && root.children.length === 0) {
    root.pruneSync();
  }

  const totalTokens = live.reduce((s, a) => s + a.tokenCount, 0);
  const attempts: DivergeAttempt[] = live.map(a => ({
    branch: a.branch,
    output: a.output,
    tokenCount: a.tokenCount,
    ppl: a.ppl,
  }));

  return {
    best: live[bestIdx].branch,
    bestOutput: live[bestIdx].output,
    attempts,
    totalTokens,
    steps,
    prefixLength,
  };
}
{}; + const branch = Branch.create(ctx, 0, samplerParams, undefined, opts.grammar); + + try { + const tokens = ctx.tokenizeSync(opts.prompt); + yield* call(() => branch.prefill(tokens)); + + // Consume async iterator inside call() β€” generators can't use for-await + const { output, tokenCount } = yield* call(async () => { + let output = ''; + let tokenCount = 0; + for await (const { text } of branch) { + output += text; + tokenCount++; + } + return { output, tokenCount }; + }); + + const parsed = opts.parse ? opts.parse(output) as T : undefined; + return { output, tokenCount, parsed }; + } finally { + if (!branch.disposed) branch.pruneSync(); + } +} diff --git a/src/agents/index.ts b/src/agents/index.ts new file mode 100644 index 0000000..6d5b889 --- /dev/null +++ b/src/agents/index.ts @@ -0,0 +1,32 @@ +export { Ctx, Store, Events } from './context'; +export { Tool } from './Tool'; +export { buildUserDelta, buildToolResultDelta } from './deltas'; +export { generate } from './generate'; +export { diverge } from './diverge'; +export { useAgentPool, ContextPressure } from './agent-pool'; +export { runAgents } from './run-agents'; +export { createToolkit } from './toolkit'; +export { initAgents } from './init'; +export { withSharedRoot } from './shared-root'; + +export type { Toolkit } from './toolkit'; +export type { AgentHandle } from './init'; +export type { SharedRootOptions } from './shared-root'; + +export type { + TraceToken, + JsonSchema, + ToolSchema, + ToolContext, + PressureThresholds, + AgentTaskSpec, + AgentPoolOptions, + AgentResult, + AgentPoolResult, + GenerateOptions, + GenerateResult, + DivergeOptions, + DivergeAttempt, + DivergeResult, + AgentEvent, +} from './types'; diff --git a/src/agents/init.ts b/src/agents/init.ts new file mode 100644 index 0000000..d7ebbd6 --- /dev/null +++ b/src/agents/init.ts @@ -0,0 +1,78 @@ +import { ensure, createChannel, call } from 'effection'; +import type { Operation, Channel } from 'effection'; +import { BranchStore 
/**
 * Handle returned by {@link initAgents} containing all agent resources
 *
 * @typeParam E - Event type carried by the `events` channel; defaults to
 *   {@link AgentEvent}
 *
 * @category Agents
 */
export interface AgentHandle<E = AgentEvent> {
  /** The session context (model, tokenizer, KV cache) */
  ctx: SessionContext;
  /** Branch store for batched commit/prefill across branches */
  store: BranchStore;
  /** Session managing conversation trunk and branch lifecycle */
  session: Session;
  /** Channel for subscribing to agent events */
  events: Channel<E, void>;
}

/**
 * Bootstrap the agent infrastructure and register structured cleanup
 *
 * Creates a {@link BranchStore}, a {@link Session}, and an event channel,
 * then sets all three Effection contexts ({@link Ctx}, {@link Store},
 * {@link Events}) in the caller's scope. An `ensure()` callback disposes the
 * session and then the context on scope exit; `ctx.dispose()` runs even when
 * `session.dispose()` throws.
 *
 * Context values are set in the caller's scope — visible to all subsequent
 * operations. This is why `initAgents` uses `ensure()` rather than
 * `resource()`: a resource creates a child scope where `Ctx.set()` would
 * be invisible to sibling operations.
 *
 * The caller creates the {@link SessionContext} (model path, nCtx, KV types
 * are harness-specific decisions) and passes it in.
 *
 * @typeParam E - Event type for the returned channel. The same channel is
 *   stored in {@link Events} typed as a channel of {@link AgentEvent}.
 * @param ctx - Session context created via `createContext()`
 * @returns Agent handle with context, store, session, and event channel
 *
 * @example Canonical bootstrap
 * ```typescript
 * main(function*() {
 *   const ctx = yield* call(() => createContext({
 *     modelPath, nCtx: 16384,
 *     nSeqMax: 4, typeK: 'q4_0', typeV: 'q4_0',
 *   }));
 *
 *   const { session, events } = yield* initAgents(ctx);
 *   // Ctx, Store, Events are now set — generate(), diverge(),
 *   // useAgentPool() will find them automatically.
 *   // Cleanup runs on scope exit.
 * });
 * ```
 *
 * @category Agents
 */
export function* initAgents<E = AgentEvent>(
  ctx: SessionContext,
): Operation<AgentHandle<E>> {
  const store = new BranchStore(ctx);
  const session = new Session({ ctx, store });
  const events: Channel<E, void> = createChannel<E, void>();

  yield* Ctx.set(ctx);
  yield* Store.set(store);
  yield* Events.set(events as unknown as Channel<AgentEvent, void>);

  yield* ensure(function*() {
    try {
      yield* call(() => session.dispose());
    } finally {
      // Always release the native context, even if session teardown failed.
      ctx.dispose();
    }
  });

  return { ctx, store, session, events };
}
/**
 * Run an agent pool inside its own Effection scope
 *
 * Wraps {@link useAgentPool} in `scoped()`, so everything the pool registered
 * in that scope is torn down before this operation returns. Use this when
 * you don't need to fork from agent branches after the pool completes.
 *
 * For multi-level tree topology (forking from agent branches for
 * verification or follow-up), use {@link useAgentPool} directly within
 * your own scope management.
 *
 * @param opts - Pool configuration: tasks, tools, sampling params, max turns
 * @returns The result produced by {@link useAgentPool}
 *
 * @example Research agents with shared root
 * ```typescript
 * const pool = yield* withSharedRoot(
 *   { systemPrompt: RESEARCH_PROMPT, tools: toolsJson },
 *   function*(root, prefixLen) {
 *     return yield* runAgents({
 *       tasks: questions.map(q => ({
 *         systemPrompt: RESEARCH_PROMPT,
 *         content: q,
 *         tools: toolsJson,
 *         parent: root,
 *       })),
 *       tools: toolMap,
 *       maxTurns: 6,
 *     });
 *   },
 * );
 * ```
 *
 * @category Agents
 */
export function* runAgents(opts: AgentPoolOptions): Operation<AgentPoolResult> {
  return yield* scoped(function*() {
    return yield* useAgentPool(opts);
  });
}
/**
 * Scoped shared root branch with guaranteed cleanup
 *
 * Formats `opts.systemPrompt` as a single system message (with
 * `addGenerationPrompt: false`, plus `tools` when given), tokenizes it,
 * creates a root branch at position 0, prefills it, and passes it to `body`.
 * The root's whole subtree is pruned in a `finally` block whether prefill or
 * the body returns or throws.
 *
 * Use this for the cold-path pattern where multiple agents share a
 * tokenized system prompt prefix. The `sharedPrefixLength` passed to
 * the body is the number of prefilled tokens.
 *
 * @typeParam T - Return type of `body`
 * @param opts - System prompt, tools, and sampling parameters. When
 *   `opts.params` is omitted the root is created with `{ temperature: 0.5 }`.
 * @param body - Operation that receives the root branch and prefix length.
 *   Typically calls {@link runAgents} or {@link useAgentPool} inside.
 * @returns The body's return value
 *
 * @example Cold-path research with shared prefix
 * ```typescript
 * const { result, prefixLen } = yield* withSharedRoot(
 *   { systemPrompt: RESEARCH_PROMPT, tools: toolsJson },
 *   function*(root, prefixLen) {
 *     const result = yield* runAgents({
 *       tasks: questions.map(q => ({
 *         systemPrompt: RESEARCH_PROMPT,
 *         content: q,
 *         tools: toolsJson,
 *         parent: root,
 *       })),
 *       tools: toolMap,
 *     });
 *     return { result, prefixLen };
 *   },
 * );
 * ```
 *
 * @category Agents
 */
export function* withSharedRoot<T>(
  opts: SharedRootOptions,
  body: (root: Branch, sharedPrefixLength: number) => Operation<T>,
): Operation<T> {
  const ctx: SessionContext = yield* Ctx.expect();

  const messages = [{ role: 'system', content: opts.systemPrompt }];
  const fmtOpts = opts.tools
    ? { tools: opts.tools, addGenerationPrompt: false }
    : { addGenerationPrompt: false };
  const fmt = ctx.formatChatSync(JSON.stringify(messages), fmtOpts);
  const sharedTokens = ctx.tokenizeSync(fmt.prompt);

  const root = Branch.create(ctx, 0, opts.params ?? { temperature: 0.5 });

  try {
    // Prefill sits inside the try so a failed or cancelled prefill still
    // prunes the freshly created root instead of leaking it.
    yield* call(() => root.prefill(sharedTokens));
    return yield* body(root, sharedTokens.length);
  } finally {
    if (!root.disposed) root.pruneSubtreeSync();
  }
}
/**
 * Aggregated tool registry for agent pool consumption
 *
 * Contains the `toolMap` for dispatch and `toolsJson` for prompt
 * formatting. Created by {@link createToolkit}.
 *
 * @category Agents
 */
export interface Toolkit {
  /** Name-to-instance map used by {@link useAgentPool} for tool dispatch */
  toolMap: Map<string, Tool>;
  /** JSON-serialized tool schemas passed to `formatChat()` via task specs */
  toolsJson: string;
}

/**
 * Aggregate an array of {@link Tool} instances into a toolkit
 *
 * Builds both the dispatch map (keyed by each tool's `name`) and the JSON
 * schema string (each tool's `schema`, in input order) from the tool array.
 * If two tools share a name, the later one replaces the earlier one in
 * `toolMap`; both schemas still appear in `toolsJson`.
 *
 * @param tools - Tool instances to aggregate
 * @returns Toolkit with `toolMap` and `toolsJson`
 *
 * @example
 * ```typescript
 * const { toolMap, toolsJson } = createToolkit([
 *   new SearchTool(chunks, reranker),
 *   new ReadFileTool(resources),
 *   new GrepTool(resources),
 * ]);
 * ```
 *
 * @category Agents
 */
export function createToolkit(tools: Tool[]): Toolkit {
  return {
    // Explicit tuple typing keeps the entries as [name, tool] pairs rather
    // than a widened (string | Tool)[] array.
    toolMap: new Map<string, Tool>(tools.map((t): [string, Tool] => [t.name, t])),
    toolsJson: JSON.stringify(tools.map(t => t.schema)),
  };
}
*/ + [key: string]: unknown; +} + +/** + * OpenAI-compatible function tool schema + * + * The wrapper format expected by `formatChat()` when passing tools to the + * model. {@link Tool.schema} generates this automatically from the tool's + * `name`, `description`, and `parameters`. + * + * @category Agents + */ +export interface ToolSchema { + /** Always `"function"` for function-calling tools */ + type: 'function'; + /** Function definition containing name, description, and parameter schema */ + function: { + /** Tool name β€” used as the function identifier in tool calls */ + name: string; + /** Human-readable description shown to the model */ + description: string; + /** JSON Schema describing the tool's arguments */ + parameters: JsonSchema; + }; +} + +/** + * Execution context passed to {@link Tool.execute} + * + * Provides callbacks for reporting progress during long-running tool + * operations (e.g. reranker scoring chunks). + * + * @category Agents + */ +export interface ToolContext { + /** Progress callback for long-running operations */ + onProgress?: (p: { filled: number; total: number }) => void; +} + +// ── Trace types ─────────────────────────────────────────────── + +/** + * Per-token trace entry captured when {@link AgentPoolOptions.trace} is true + * + * Each entry corresponds to one sampled token and the distribution state + * at the moment it was drawn. Available on {@link AgentResult.trace} after + * pool completion. 
+ * + * @category Agents + */ +export interface TraceToken { + /** Decoded text for this token */ + text: string; + /** Shannon entropy of the full vocabulary distribution (bits, base-2) */ + entropy: number; + /** Surprisal of the chosen token: -log2(p) */ + surprisal: number; +} + +// ── Agent pool types ─────────────────────────────────────────── + +/** + * Task specification for a single agent in {@link useAgentPool} + * + * Each task defines the agent's system prompt, user content, available + * tools, and parent branch to fork from. The parent branch determines + * the agent's KV prefix β€” fork from a shared root to amortize system + * prompt tokenization across agents. + * + * @category Agents + */ +export interface AgentTaskSpec { + /** System prompt defining the agent's role and behavior */ + systemPrompt: string; + /** User message content β€” the agent's specific sub-question or task */ + content: string; + /** JSON-serialized tool schemas (from {@link createToolkit}) */ + tools?: string; + /** PRNG seed for sampler diversity β€” pass different seeds per agent */ + seed?: number; + /** Parent branch to fork from (required by {@link useAgentPool}) */ + parent?: Branch; +} + +/** + * Sampling parameters for generation + * + * Controls the sampler chain applied during token generation. Passed to + * {@link Branch.create}, {@link generate}, {@link diverge}, and agent + * pool tasks. 
+ * + * @category Agents + */ +export interface SamplingParams { + /** Temperature for softmax scaling (0 = greedy, higher = more random) */ + temperature?: number; + /** Nucleus sampling threshold β€” cumulative probability cutoff */ + topP?: number; + /** Top-K sampling β€” keep only the K most likely tokens */ + topK?: number; + /** Minimum probability threshold relative to the most likely token */ + minP?: number; + /** Additional sampler-specific parameters */ + [key: string]: unknown; +} + +/** + * KV pressure thresholds controlling agent shutdown under context exhaustion + * + * Two thresholds govern what happens as remaining KV shrinks: + * + * **softLimit** (default 1024) β€” remaining KV floor for new work. + * Enforced at three points: + * - **SETTLE**: tool results that would cross this floor are rejected and + * the agent is marked done. This is the primary enforcement point β€” tool + * results (search results, etc.) are the largest KV consumers. + * - **PRODUCE (stop-token boundary)**: agents that want a non-terminal tool + * call are hard-cut. Terminal tools (e.g. `report()`) still pass. + * - **INIT prefill**: agents that don't fit above this floor are dropped. + * + * Set to account for downstream pool needs (reporters, verification). + * + * **hardLimit** (default 128) β€” crash-prevention floor. + * When remaining drops below this, agents are killed immediately before + * `produceSync()`. Prevents `llama_decode` "no memory slot" failures. + * Pure safety net β€” should never be the primary budget control. + * + * @category Agents + */ +export interface PressureThresholds { + /** + * Remaining KV floor for new work (tokens). When remaining drops below + * this, SETTLE rejects tool results, PRODUCE hard-cuts non-terminal tool + * calls, and INIT drops agents that don't fit. + * + * Set to account for downstream pool needs (reporters, verification). + * Default: 1024 + */ + softLimit?: number; + /** + * Crash-prevention floor (tokens). 
When remaining drops below this, + * agents are killed immediately before `produceSync()`. Prevents + * `llama_decode` "no memory slot for batch" failures. + * Default: 128 + */ + hardLimit?: number; +} + +/** + * Configuration for {@link useAgentPool} and {@link runAgents} + * + * @category Agents + */ +export interface AgentPoolOptions { + /** Agent task specifications β€” one per concurrent agent */ + tasks: AgentTaskSpec[]; + /** + * Tool registry mapping tool names to {@link Tool} instances. + * + * This is the **execution registry** β€” it determines which tools can be + * dispatched at runtime. It is distinct from the per-task `task.tools` + * JSON schema that tells the model which tools are available. + * + * The registry also controls {@link AgentPoolOptions.terminalTool | terminalTool} + * gating: if the registry contains only the terminal tool, agents are + * allowed to call it as their first action (e.g. reporter sub-agents). + * If the registry contains other tools, the first call must be + * non-terminal to prevent agents from reporting without doing work. + */ + tools: Map; + /** Sampling parameters applied to all agents */ + params?: SamplingParams; + /** Maximum tool-call turns per agent before forced termination */ + maxTurns?: number; + /** Tool name that signals agent completion. When the model calls this tool, + * findings are extracted from arguments and the agent is marked done. + * The tool is intercepted β€” never dispatched to execute(). If omitted, + * agents complete only via stop token or hard-cut. */ + terminalTool?: string; + /** Enable per-token entropy/surprisal on `agent:produce` events */ + trace?: boolean; + /** KV pressure thresholds β€” tune per pool. Reporter pools typically use + * lower thresholds than research pools since they complete in a single + * terminal tool call. See {@link PressureThresholds} for tuning guidance. 
*/ + pressure?: PressureThresholds; +} + +/** + * Result for a single completed agent + * + * @category Agents + */ +export interface AgentResult { + /** Stable agent identifier (branch handle at creation time) */ + agentId: number; + /** Parent branch handle β€” shared root for top-level agents, parent agentId for sub-agents */ + parentAgentId: number; + /** The agent's branch β€” still alive when returned from {@link useAgentPool} */ + branch: Branch; + /** Agent's research findings (from terminal tool or final output), or null */ + findings: string | null; + /** Number of tool calls the agent made */ + toolCallCount: number; + /** Total tokens generated by this agent */ + tokenCount: number; + /** Model-level perplexity at completion (exp of mean NLL from raw logits) */ + ppl: number; + /** Sampling-level perplexity at completion (from filtered distribution) */ + samplingPpl: number; + /** Per-token trace data (present only when {@link AgentPoolOptions.trace} is true) */ + trace?: TraceToken[]; +} + +/** + * Aggregate result from a completed agent pool run + * + * Returned by both {@link useAgentPool} and {@link runAgents}. Contains + * per-agent results plus aggregate statistics for display and telemetry. 
/**
 * Aggregate result from a completed agent pool run
 *
 * Returned by both {@link useAgentPool} and {@link runAgents}. Contains
 * per-agent results plus aggregate statistics for display and telemetry.
 *
 * @category Agents
 */
export interface AgentPoolResult {
  /** Per-agent results in task order */
  agents: AgentResult[];
  /** Sum of all agent token counts */
  totalTokens: number;
  /** Sum of all agent tool calls */
  totalToolCalls: number;
  /** Number of batched commit steps in the tick loop */
  steps: number;
  /** Internal performance counters for telemetry */
  counters: {
    /** Number of batch prefill calls for tool result injection */
    warmPrefillCalls: number;
    /** Total branches across all warm prefill batches */
    warmPrefillBranches: number;
    /** Ticks where no agent was generating (all awaiting tools) */
    stalledTicks: number;
    /** Peak concurrent tool executions */
    maxConcurrentTools: number;
    /** Ticks spent idle-waiting via action() */
    idleTicks: number;
  };
}

// ── Generate types ─────────────────────────────────────────────

/**
 * Options for single-branch {@link generate}
 *
 * @category Agents
 */
export interface GenerateOptions {
  /** Pre-formatted prompt string (e.g. from `formatChat()`); tokenized by {@link generate} */
  prompt: string;
  /** GBNF grammar string for constrained generation */
  grammar?: string;
  /** Sampling parameters */
  params?: SamplingParams;
  /** Optional parser applied to the raw output string */
  parse?: (output: string) => unknown;
}

/**
 * Result from single-branch {@link generate}
 *
 * @typeParam T - Type of the parsed output; defaults to `unknown`
 *
 * @category Agents
 */
export interface GenerateResult<T = unknown> {
  /** Raw generated text */
  output: string;
  /** Number of tokens generated */
  tokenCount: number;
  /** Parsed output (present only when `parse` was provided in options) */
  parsed?: T;
}
+ * + * @category Agents + */ +export interface DivergeOptions { + /** Pre-formatted prompt for creating a fresh root (mutually exclusive with parent) */ + prompt?: string; + /** Number of parallel generation attempts */ + attempts: number; + /** Parent branch to fork from (mutually exclusive with prompt) */ + parent?: Branch; + /** Sampling parameters for all attempts */ + params?: SamplingParams; +} + +/** + * Single attempt result from {@link diverge} + * + * @category Agents + */ +export interface DivergeAttempt { + /** The attempt's branch (only the best branch survives after diverge) */ + branch: Branch; + /** Generated text for this attempt */ + output: string; + /** Number of tokens generated */ + tokenCount: number; + /** Model perplexity β€” lower indicates more coherent generation */ + ppl: number; +} + +/** + * Aggregate result from {@link diverge} + * + * The `best` branch is still alive; all other attempt branches have been + * pruned. The caller owns cleanup β€” typically via {@link Session.promote} + * to make the best branch the new conversation trunk. + * + * @category Agents + */ +export interface DivergeResult { + /** Lowest-perplexity branch β€” still alive, caller owns cleanup */ + best: Branch; + /** Text output from the best attempt */ + bestOutput: string; + /** All attempts (losers already pruned, branches disposed) */ + attempts: DivergeAttempt[]; + /** Sum of all attempt token counts */ + totalTokens: number; + /** Number of batched commit steps */ + steps: number; + /** Shared prefix length in tokens (for KV savings calculation) */ + prefixLength: number; +} + +// ── Runtime events ───────────────────────────────────────────── + +/** + * Events emitted by the runtime during agent pool execution + * + * Subscribe to these via the `events` channel from {@link initAgents}. + * Harnesses can extend this union with phase-level events for display. 
+ * + * @category Agents + */ +export type AgentEvent = + | { type: 'agent:spawn'; agentId: number; parentAgentId: number } + | { type: 'agent:produce'; agentId: number; text: string; tokenCount: number; entropy?: number; surprisal?: number } + | { type: 'agent:tool_call'; agentId: number; tool: string; args: string } + | { type: 'agent:tool_result'; agentId: number; tool: string; result: string } + | { type: 'agent:tool_progress'; agentId: number; tool: string; filled: number; total: number } + | { type: 'agent:report'; agentId: number; findings: string } + | { type: 'agent:done'; agentId: number }; diff --git a/src/binding.cpp b/src/binding.cpp index 447a07a..02a32da 100644 --- a/src/binding.cpp +++ b/src/binding.cpp @@ -1,5 +1,6 @@ #include #include "SessionContext.hpp" +#include "Util.hpp" #include /** @@ -24,6 +25,9 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) { // Export factory function exports.Set("createContext", Napi::Function::New(env, CreateContext)); + // Export utility functions (parseMarkdown, etc.) + Util::Init(env, exports); + return exports; } diff --git a/src/index.ts b/src/index.ts new file mode 100644 index 0000000..1ee3f6d --- /dev/null +++ b/src/index.ts @@ -0,0 +1,315 @@ +/** + * liblloyal-node - Thin N-API wrapper over liblloyal + * + * Exposes raw llama.cpp inference primitives for Node.js. 
/**
 * Resolve the platform package name:
 * `@lloyal-labs/lloyal.node-{platform}-{arch}[-{gpu}]`
 *
 * `platform` and `arch` come from `process.platform` / `process.arch`.
 * No GPU suffix is appended when `variant` is missing or empty, or is one of
 * `'default'`, `'cpu'` or `'metal'`; any other value is appended as
 * `-{variant}`.
 */
const getPlatformPackageName = (variant?: string): string => {
  const baseName = `@lloyal-labs/lloyal.node-${process.platform}-${process.arch}`;
  if (!variant) {
    return baseName;
  }
  switch (variant) {
    case 'default':
    case 'cpu':
    case 'metal':
      return baseName;
    default:
      return `${baseName}-${variant}`;
  }
};
+ */ +const tryLoadPackage = (packageName: string, verbose = false): NativeBinding | null => { + try { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const mod = require(packageName) as NativeBinding; + if (mod && typeof mod.createContext === 'function') { + return mod; + } + if (verbose) { + console.warn(`[lloyal.node] ${packageName} loaded but missing createContext export`); + } + return null; + } catch (e) { + if (verbose) { + console.warn(`[lloyal.node] Failed to load ${packageName}: ${(e as Error).message}`); + } + return null; + } +}; + +/** + * Load native binary for a specific GPU variant + * + * lloyal.node ships as a family of platform-specific npm packages, each + * containing a prebuilt native addon: + * `@lloyal-labs/lloyal.node-{platform}-{arch}[-{gpu}]` + * (e.g., `darwin-arm64`, `linux-x64-cuda`, `win32-x64-vulkan`). + * + * `loadBinary()` resolves the correct package at runtime with a prioritized + * fallback chain: + * + * 1. Requested GPU variant package (if `variant` or `LLOYAL_GPU` env var set) + * 2. Local development build (`build/Release/lloyal.node`) + * 3. Default CPU platform package + * + * Most callers should use {@link createContext} directly β€” it calls + * `loadBinary()` internally. 
Use this function when you need to: + * - Pre-check whether a GPU variant is available before creating contexts + * - Share one loaded binary across multiple context creations + * - Inspect or test the binary loading logic in isolation + * + * **Environment variables:** + * - `LLOYAL_LOCAL=1` β€” Force local build only; throws if not found + * (use during development to test local C++ changes) + * - `LLOYAL_GPU=cuda|vulkan` β€” Request GPU variant (equivalent to `variant` param) + * - `LLOYAL_NO_FALLBACK=1` β€” Disable silent CPU fallback; throws if GPU + * variant fails (use in CI to catch missing runtime libraries) + * + * @param variant GPU variant: 'cuda', 'vulkan', or undefined for CPU + * @returns Native binary module with createContext method + * @throws Error if no binary available for the current platform + * + * @example + * ```typescript + * // Load default (CPU) binary + * const binary = loadBinary(); + * + * // Load CUDA binary (falls back to CPU if unavailable) + * const binary = loadBinary('cuda'); + * + * // Create context from loaded binary + * const ctx = await binary.createContext({ modelPath: './model.gguf' }); + * ``` + * + * @category Core + */ +export const loadBinary = (variant?: GpuVariant): NativeBinding => { + const resolvedVariant = variant ?? process.env.LLOYAL_GPU; + const noFallback = process.env.LLOYAL_NO_FALLBACK === '1'; + const useLocal = process.env.LLOYAL_LOCAL === '1'; + + // 0. Use local build if explicitly requested (no fallback) + if (useLocal) { + try { + return require('../build/Release/lloyal.node') as NativeBinding; + } catch { + throw new Error( + '[lloyal.node] LLOYAL_LOCAL=1 but local build not found. ' + + 'Run `npm run build` first.' + ); + } + } + + // 1. 
Try requested variant (if specified) + if (resolvedVariant && resolvedVariant !== 'default') { + const pkgName = getPlatformPackageName(resolvedVariant); + const binary = tryLoadPackage(pkgName, true); + if (binary) return binary; + + if (noFallback) { + throw new Error( + `[lloyal.node] GPU variant "${resolvedVariant}" failed to load. ` + + `Package: ${pkgName}. Check that runtime libraries are available.` + ); + } + console.warn(`[lloyal.node] GPU variant "${resolvedVariant}" unavailable, falling back to CPU`); + } + + // 2. Try local build (always fresher than installed packages during development) + try { + return require('../build/Release/lloyal.node') as NativeBinding; + } catch { + // ignore β€” no local build + } + + // 3. Try default platform package (CPU) + const defaultPkg = getPlatformPackageName(); + const binary = tryLoadPackage(defaultPkg, true); + if (binary) return binary; + + throw new Error( + `No lloyal.node binary found for ${process.platform}-${process.arch}. ` + + `Tried: ${resolvedVariant ? getPlatformPackageName(resolvedVariant) + ', ' : ''}${defaultPkg}` + ); +}; + +// Default binary (loaded lazily on first use) +let _binary: NativeBinding | null = null; +const getBinary = (): NativeBinding => { + if (!_binary) { + _binary = loadBinary(process.env.LLOYAL_GPU as GpuVariant | undefined); + } + return _binary; +}; + +/** + * Create a new inference context + * + * Entry point for all inference. Resolves the correct native binary (see + * {@link loadBinary} for the platform/GPU fallback chain), loads the model + * via a reference-counted registry (multiple contexts can share one model's + * weight tensors in memory), and allocates a `llama_context` with its own + * KV cache and compute scratch buffers. + * + * **What gets allocated:** + * - KV cache: `nCtx * 2 * nLayers * dHead` bytes per KV type (fp16 default). + * For a 7B model with `nCtx: 4096`, expect ~1-2 GB of KV memory. 
+ * - Compute scratch: temporary buffers for the forward pass, sized to `nBatch`. + * + * **Model sharing:** If two contexts use the same `modelPath`, the model + * weights are loaded once and shared. Only the KV cache and compute buffers + * are per-context. This makes multi-context setups (e.g., one context per + * conversation) memory-efficient. + * + * @param options Context creation options + * @param loadOptions Optional binary loading options (GPU variant selection) + * @returns Promise resolving to SessionContext instance + * + * @example Basic usage + * ```typescript + * const ctx = await createContext({ + * modelPath: './model.gguf', + * nCtx: 2048, + * nThreads: 4 + * }); + * + * try { + * const tokens = await ctx.tokenize("Hello"); + * const branch = Branch.create(ctx, 0, { temperature: 0.7 }); + * await branch.prefill(tokens); + * for await (const { text } of branch) process.stdout.write(text); + * } finally { + * ctx.dispose(); + * } + * ``` + * + * @example Multi-branch context (tree search, best-of-N) + * ```typescript + * const ctx = await createContext({ + * modelPath: './model.gguf', + * nCtx: 8192, + * nBatch: 512, // Bin-packing capacity for BranchStore.prefill + * nSeqMax: 33, // 32 branches + 1 root sequence + * }); + * ``` + * + * @example With GPU variant selection + * ```typescript + * const ctx = await createContext( + * { modelPath: './model.gguf', nCtx: 4096 }, + * { gpuVariant: 'cuda' } + * ); + * ``` + * + * @category Core + */ +export const createContext = async ( + options: ContextOptions, + loadOptions?: LoadOptions +): Promise => { + const variant = loadOptions?.gpuVariant || process.env.LLOYAL_GPU; + const binary = variant ? 
loadBinary(variant as GpuVariant) : getBinary(); + return binary.createContext(options); +}; + +// ── Layer 1: Substrate (unchanged) ────────────────────────────── +export { Branch, BranchStore, Session, buildUserDelta, buildToolResultDelta, Rerank }; + +// ── Layer 2: Agents (structured concurrency) ──────────────────── +export { + Ctx, Store, Events, + Tool, + useAgentPool, + runAgents, + generate, + diverge, + createToolkit, + initAgents, + withSharedRoot, +} from './agents/index'; + +export type { + Toolkit, + AgentHandle, + SharedRootOptions, + JsonSchema, + ToolSchema, + ToolContext, + AgentTaskSpec, + AgentPoolOptions, + AgentResult, + AgentPoolResult, + GenerateOptions, + GenerateResult, + DivergeOptions, + DivergeAttempt, + DivergeResult, + AgentEvent, +} from './agents/index'; + +// ── Enums + types from types.ts ───────────────────────────────── +export { PoolingType, CHAT_FORMAT_CONTENT_ONLY, CHAT_FORMAT_GENERIC, ReasoningFormat, GrammarTriggerType } from './types'; +export type { ChatFormat } from './types'; +export type { + GpuVariant, + KvCacheType, + LoadOptions, + ContextOptions, + FormatChatOptions, + GrammarTrigger, + FormattedChatResult, + ParseChatOutputOptions, + ParsedToolCall, + ParseChatOutputResult, + PenaltyParams, + MirostatParams, + DryParams, + XtcParams, + AdvancedSamplingParams, + SamplingParams, + SessionContext, + Produced, + RerankOptions, + RerankResult, + RerankProgress, + NativeBinding, +} from './types'; diff --git a/lib/index.d.ts b/src/types.ts similarity index 61% rename from lib/index.d.ts rename to src/types.ts index 6c79a1c..aa97a19 100644 --- a/lib/index.d.ts +++ b/src/types.ts @@ -16,6 +16,9 @@ * Parallel and tree-structured generation with batched GPU dispatch. 
*/ +import type { Branch } from './Branch'; +import type { BranchStore } from './BranchStore'; + /** * GPU variant for binary loading * @@ -90,6 +93,69 @@ export enum PoolingType { RANK = 4, } +/** + * Chat format detected by the template engine + * + * Identifies how the model formats tool calls, reasoning blocks, and content. + * Opaque chat format identifier returned by + * {@link SessionContext.formatChat | formatChat()} and consumed by + * {@link SessionContext.parseChatOutput | parseChatOutput()}. + * + * Maps 1:1 to llama.cpp's `common_chat_format` enum (30+ values). + * Treat as an opaque number β€” pass through, don't switch on it. + * + * @category Chat + */ +export type ChatFormat = number; + +/** Model template has no tool/structured-output support. */ +export const CHAT_FORMAT_CONTENT_ONLY: ChatFormat = 0; + +/** llama.cpp's generic JSON fallback β€” imposes format the model wasn't trained on. */ +export const CHAT_FORMAT_GENERIC: ChatFormat = 1; + +/** + * Reasoning/thinking block format + * + * Controls how `` blocks are handled during formatting and parsing. + * + * @see {@link FormatChatOptions.reasoningFormat} for input-side usage + * @see {@link ParseChatOutputOptions.reasoningFormat} for output-side usage + * + * @category Chat + */ +export enum ReasoningFormat { + /** No reasoning extraction (default) */ + NONE = 0, + /** Auto-detect reasoning format from model template */ + AUTO = 1, + /** DeepSeek legacy format (`...` in content) */ + DEEPSEEK_LEGACY = 2, + /** DeepSeek format (structured reasoning extraction) */ + DEEPSEEK = 3, +} + +/** + * Grammar trigger type + * + * Determines how lazy grammar activation is triggered during generation. 
+ * + * @see {@link GrammarTrigger} + * @see {@link FormattedChatResult.grammarTriggers} + * + * @category Chat + */ +export enum GrammarTriggerType { + /** Trigger on a specific token ID */ + TOKEN = 0, + /** Trigger on a word boundary match */ + WORD = 1, + /** Trigger on a regex pattern match */ + PATTERN = 2, + /** Trigger on a full-string regex pattern match */ + PATTERN_FULL = 3, +} + /** * Configuration for context creation * @@ -177,71 +243,6 @@ export interface ContextOptions { typeV?: KvCacheType; } -/** - * Chat format detected by the template engine - * - * Identifies how the model formats tool calls, reasoning blocks, and content. - * Returned by {@link SessionContext.formatChat | formatChat()} in - * {@link FormattedChatResult.format} and consumed by - * {@link SessionContext.parseChatOutput | parseChatOutput()}. - * - * You generally don't need to inspect these values directly -- - * just pass them through from the formatChat result to parseChatOutput. - * - * Only commonly-used values are listed. The full set matches llama.cpp's - * `common_chat_format` enum (30+ formats). - * - * @category Chat - */ -export enum ChatFormat { - /** Plain content, no special formatting */ - CONTENT_ONLY = 0, - /** Generic tool call format */ - GENERIC = 1, -} - -/** - * Reasoning/thinking block format - * - * Controls how `` blocks are handled during formatting and parsing. - * - * @see {@link FormatChatOptions.reasoningFormat} for input-side usage - * @see {@link ParseChatOutputOptions.reasoningFormat} for output-side usage - * - * @category Chat - */ -export enum ReasoningFormat { - /** No reasoning extraction (default) */ - NONE = 0, - /** Auto-detect reasoning format from model template */ - AUTO = 1, - /** DeepSeek legacy format (`...` in content) */ - DEEPSEEK_LEGACY = 2, - /** DeepSeek format (structured reasoning extraction) */ - DEEPSEEK = 3, -} - -/** - * Grammar trigger type - * - * Determines how lazy grammar activation is triggered during generation. 
- * - * @see {@link GrammarTrigger} - * @see {@link FormattedChatResult.grammarTriggers} - * - * @category Chat - */ -export enum GrammarTriggerType { - /** Trigger on a specific token ID */ - TOKEN = 0, - /** Trigger on a word boundary match */ - WORD = 1, - /** Trigger on a regex pattern match */ - PATTERN = 2, - /** Trigger on a full-string regex pattern match */ - PATTERN_FULL = 3, -} - /** * Options for chat template formatting * @@ -775,6 +776,20 @@ export interface SessionContext { */ tokenize(text: string, addSpecial?: boolean): Promise; + /** + * Tokenize text into model's vocabulary (sync β€” inline on main thread) + * + * Same as {@link tokenize} but synchronous. Use from Effection generators + * to avoid `yield* call()` overhead for CPU-only work. + * + * @param text Text to tokenize + * @param addSpecial Whether to add special tokens (BOS/EOS). Defaults to + * model metadata setting (typically true). Pass false for mid-sequence + * tokenization. + * @returns Array of token IDs + */ + tokenizeSync(text: string, addSpecial?: boolean): number[]; + /** * Detokenize array of tokens back to text * @@ -824,7 +839,7 @@ export interface SessionContext { * - Forgetting specific messages * - Preparing for injection of new context * - * ⚠️ CRITICAL: Call BEFORE next decode(), not after! + * CRITICAL: Call BEFORE next decode(), not after! * The model needs to know about the removal before processing new tokens. * * Cost: ~1-5ms depending on range @@ -884,7 +899,6 @@ export interface SessionContext { * Use when starting a completely new conversation. * * Cost: ~1ms - * */ kvCacheClear(): Promise; @@ -902,8 +916,8 @@ export interface SessionContext { * **Why not naive eviction?** Selective eviction (`kvCacheRemove`) preserves * original position IDs, which grow without bound. Across 5 architectures, * naive eviction produces PPL spanning 3 orders of magnitude β€” ranging from - * 1.15Γ— baseline (Llama, lucky config) to 198Γ— (Phi, sinks present). 
- * Under Blink KV reconstruction, all 5 converge to 3–16% of baseline. + * 1.15x baseline (Llama, lucky config) to 198x (Phi, sinks present). + * Under Blink KV reconstruction, all 5 converge to 3-16% of baseline. * * **Sinks are optional.** Under reconstruction, the 0+N (sinkless) config * matches 4+N (with sinks) within <2% across all tested architectures. @@ -921,7 +935,7 @@ export interface SessionContext { * @param sinks First N tokens from conversation start (typically 4, or empty). * Must be the same tokens every reseed β€” reusing different tokens degrades * any attention-sink patterns the model may have learned for early positions. - * @param tail Recent M tokens to preserve (typically 252–1020) + * @param tail Recent M tokens to preserve (typically 252-1020) * @returns Promise that resolves when reconstruction completes. * Next decode continues at position `sinks.length + tail.length`. * @@ -960,7 +974,7 @@ export interface SessionContext { * physical KV entries for the shared prefix; only tokens decoded after * the fork point allocate new storage. This is what makes tree-structured * generation (best-of-N, beam search, speculative decoding) memory-efficient: - * N branches sharing a 1000-token prefix cost ~1000 KV entries, not NΓ—1000. + * N branches sharing a 1000-token prefix cost ~1000 KV entries, not N*1000. * * The higher-level {@link Branch.fork} wraps this and additionally clones * the sampler chain, grammar state, logits snapshot, and perplexity tracker. @@ -1050,7 +1064,7 @@ export interface SessionContext { /** * Format messages using model's chat template * - * Converts [{role, content}] β†’ formatted prompt string with full format awareness. + * Converts [{role, content}] -> formatted prompt string with full format awareness. * Uses model's built-in template (ChatML, Llama, Mistral, etc.). 
* * The returned `format` and `reasoningFormat` fields should be passed to @@ -1081,6 +1095,21 @@ export interface SessionContext { options?: FormatChatOptions | string ): Promise; + /** + * Format messages using model's chat template (sync β€” inline on main thread) + * + * Same as {@link formatChat} but synchronous. Use from Effection generators + * to avoid `yield* call()` overhead for CPU-only work. + * + * @param messagesJson JSON string containing array of messages + * @param options Formatting options (tools, reasoning, grammar, etc.) + * @returns Formatted prompt with format-awareness metadata + */ + formatChatSync( + messagesJson: string, + options?: FormatChatOptions | string + ): FormattedChatResult; + /** * Parse model output into structured content * @@ -1201,6 +1230,17 @@ export interface SessionContext { */ jsonSchemaToGrammar(schemaJson: string): Promise; + /** + * Convert JSON schema to GBNF grammar (sync β€” inline on main thread) + * + * Same as {@link jsonSchemaToGrammar} but synchronous. Use from Effection + * generators to avoid `yield* call()` overhead for CPU-only work. 
+ * + * @param schemaJson JSON schema string + * @returns GBNF grammar string + */ + jsonSchemaToGrammarSync(schemaJson: string): string; + /** * Validate chat template syntax * @@ -1333,208 +1373,104 @@ export interface SessionContext { // ===== BRANCH API (internal, wrapped by Branch class) ===== - /** @internal Create a new branch for parallel generation */ - _branchCreate(position: number, params?: SamplingParams, nBatch?: number): number; + /** @internal */ + _branchCreate(position: number, params?: SamplingParams, nBatch?: number, grammar?: string): number; - /** @internal Fork a branch to a new sequence */ + /** @internal */ _branchFork(handle: number): number; - /** @internal Decode multiple tokens in n_batch-sized chunks and capture logits */ + /** @internal */ _branchPrefill(handle: number, tokens: number[]): Promise; - /** @internal Sample next token from branch's logits snapshot */ + /** @internal */ _branchSample(handle: number): number; - /** @internal Accept token (update sampler state for penalties) */ + /** @internal */ _branchAccept(handle: number, token: number): void; - /** @internal Get branch's current position */ + /** @internal */ _branchGetPosition(handle: number): number; - /** @internal Get branch's perplexity */ + /** @internal */ _branchGetPerplexity(handle: number): number; - /** @internal Get copy of branch's logits snapshot */ + /** @internal */ _branchGetLogits(handle: number): Float32Array; - /** @internal Prune branch (remove KV cache entries and free handle) β€” RESTRICT: throws if children */ + /** @internal */ _branchPrune(handle: number): void; - /** @internal Prune branch and all descendants β€” CASCADE */ + /** @internal */ _branchPruneSubtree(handle: number): void; - /** @internal Get parent branch handle (0 = INVALID_HANDLE if root) */ + /** @internal */ _branchParent(handle: number): number; - /** @internal Get child branch handles */ + /** @internal */ _branchChildren(handle: number): number[]; - /** @internal Check if 
branch has no children */ + /** @internal */ _branchIsLeaf(handle: number): boolean; - /** @internal Check if branch holds a KV lease */ + /** @internal */ _branchIsActive(handle: number): boolean; - /** @internal Reseed branch sampler PRNG for diversity after fork */ + /** @internal */ _branchSamplerChainReseed(handle: number, seed: number): void; - /** @internal Set dynamic logit biases for a branch */ + /** @internal */ _branchSteer(handle: number, biases: Array<{ token: number; bias: number }>): void; - /** @internal Clear all dynamic logit biases from a branch */ + /** @internal */ _branchClearSteer(handle: number): void; - /** @internal Replace sampler chain with new parameters (memoized) */ + /** @internal */ _branchSetSamplerParams(handle: number, params: SamplingParams): void; - /** @internal Replace or remove grammar constraint */ + /** @internal */ _branchSetGrammar(handle: number, grammarStr: string): void; - /** @internal Compute entropy from branch's logits snapshot */ + /** @internal */ + _branchSetGrammarLazy(handle: number, grammar: string, patterns: string[], tokens: number[]): void; + + /** @internal */ _branchModelEntropy(handle: number, base?: string): number; - /** @internal Compute surprisal from branch's logits snapshot */ + /** @internal */ _branchModelSurprisal(handle: number, token: number, base?: string): number; - /** @internal Get sampling-level perplexity */ + /** @internal */ _branchGetSamplingPerplexity(handle: number): number; - /** @internal Set static logit biases on a branch */ + /** @internal */ _branchSetLogitBias(handle: number, biases: Array<{ token: number; bias: number }>): void; - /** @internal Clear all static logit biases from a branch */ + /** @internal */ _branchClearLogitBias(handle: number): void; // ===== STORE API (internal, wrapped by BranchStore) ===== - /** @internal Batched accept + decode_each + capture for N branches */ + /** @internal */ _storeCommit(handles: number[], tokens: number[]): Promise; - /** 
@internal Batched decode_scatter + capture for N branches with variable token counts */ + /** @internal */ _storePrefill(handles: number[], tokenArrays: number[][]): Promise; - /** @internal Retain winner branch, evict all others */ + /** @internal */ _storeRetainOnly(handle: number): void; - /** @internal Get number of available seq_id leases */ + /** @internal */ _storeAvailable(): number; -} -/** - * Create a new inference context - * - * Entry point for all inference. Resolves the correct native binary (see - * {@link loadBinary} for the platform/GPU fallback chain), loads the model - * via a reference-counted registry (multiple contexts can share one model's - * weight tensors in memory), and allocates a `llama_context` with its own - * KV cache and compute scratch buffers. - * - * **What gets allocated:** - * - KV cache: `nCtx * 2 * nLayers * dHead` bytes per KV type (fp16 default). - * For a 7B model with `nCtx: 4096`, expect ~1-2 GB of KV memory. - * - Compute scratch: temporary buffers for the forward pass, sized to `nBatch`. - * - * **Model sharing:** If two contexts use the same `modelPath`, the model - * weights are loaded once and shared. Only the KV cache and compute buffers - * are per-context. This makes multi-context setups (e.g., one context per - * conversation) memory-efficient. 
- * - * @param options Context creation options - * @param loadOptions Optional binary loading options (GPU variant selection) - * @returns Promise resolving to SessionContext instance - * - * @example Basic usage - * ```typescript - * const ctx = await createContext({ - * modelPath: './model.gguf', - * nCtx: 2048, - * nThreads: 4 - * }); - * - * try { - * const tokens = await ctx.tokenize("Hello"); - * const branch = Branch.create(ctx, 0, { temperature: 0.7 }); - * await branch.prefill(tokens); - * for await (const { text } of branch) process.stdout.write(text); - * } finally { - * ctx.dispose(); - * } - * ``` - * - * @example Multi-branch context (tree search, best-of-N) - * ```typescript - * const ctx = await createContext({ - * modelPath: './model.gguf', - * nCtx: 8192, - * nBatch: 512, // Bin-packing capacity for BranchStore.prefill - * nSeqMax: 33, // 32 branches + 1 root sequence - * }); - * ``` - * - * @example With GPU variant selection - * ```typescript - * const ctx = await createContext( - * { modelPath: './model.gguf', nCtx: 4096 }, - * { gpuVariant: 'cuda' } - * ); - * ``` - * - * @category Core - */ -export function createContext( - options: ContextOptions, - loadOptions?: LoadOptions -): Promise; + /** KV cache pressure snapshot from native BranchStore. + * cells_used is a monotonic counter reset on drain/retainOnly. */ + _storeKvPressure(): { nCtx: number; cellsUsed: number; remaining: number }; -/** - * Load native binary for a specific GPU variant - * - * lloyal.node ships as a family of platform-specific npm packages, each - * containing a prebuilt native addon: - * `@lloyal-labs/lloyal.node-{platform}-{arch}[-{gpu}]` - * (e.g., `darwin-arm64`, `linux-x64-cuda`, `win32-x64-vulkan`). - * - * `loadBinary()` resolves the correct package at runtime with a prioritized - * fallback chain: - * - * 1. Requested GPU variant package (if `variant` or `LLOYAL_GPU` env var set) - * 2. Local development build (`build/Release/lloyal.node`) - * 3. 
Default CPU platform package - * - * Most callers should use {@link createContext} directly β€” it calls - * `loadBinary()` internally. Use this function when you need to: - * - Pre-check whether a GPU variant is available before creating contexts - * - Share one loaded binary across multiple context creations - * - Inspect or test the binary loading logic in isolation - * - * **Environment variables:** - * - `LLOYAL_LOCAL=1` β€” Force local build only; throws if not found - * (use during development to test local C++ changes) - * - `LLOYAL_GPU=cuda|vulkan` β€” Request GPU variant (equivalent to `variant` param) - * - `LLOYAL_NO_FALLBACK=1` β€” Disable silent CPU fallback; throws if GPU - * variant fails (use in CI to catch missing runtime libraries) - * - * @param variant GPU variant: 'cuda', 'vulkan', or undefined for CPU - * @returns Native binary module with createContext method - * @throws Error if no binary available for the current platform - * - * @example - * ```typescript - * // Load default (CPU) binary - * const binary = loadBinary(); - * - * // Load CUDA binary (falls back to CPU if unavailable) - * const binary = loadBinary('cuda'); - * - * // Create context from loaded binary - * const ctx = await binary.createContext({ modelPath: './model.gguf' }); - * ``` - * - * @category Core - */ -export function loadBinary(variant?: GpuVariant): { - createContext(options: ContextOptions): Promise; -}; + // ===== SCORING API ===== + + /** @internal β€” processes ≀ n_seq_max prompts in a single group */ + _scoreGroup(tokenArrays: number[][]): Promise; +} /** * Result from Branch.produce() @@ -1550,552 +1486,55 @@ export interface Produced { isStop: boolean; } +// AgentTask, AgentState, RunAgentsOptions, RunAgentsResult removed β€” +// superseded by src/runtime/ (useAgentPool, AgentTaskSpec, AgentPoolResult) + /** - * Forkable inference handle for covalent generation - * - * A Branch owns everything needed for independent generation: a KV cache - * sequence, 
sampler chain, logits snapshot, and perplexity tracker. - * - * Forking is cheap β€” the KV prefix is shared in memory (metadata-only operation under unified KV β€” - * no KV tensor buffers are copied), so sibling branches read from the same physical KV entries. - * Only tokens decoded after the fork point are exclusive to each branch. - * - * Branches form trees, not just flat lists. Fork from root for best-of-N, - * fork from children for tree search/beam search, fork from a draft for speculative - * decoding. - * - * The produce/commit protocol separates sampling from state advancement: - * produce() samples without writing to KV, letting you inspect the result - * before deciding to commit(). - * - * @example Best-of-N with perplexity selection - * ```typescript - * const root = Branch.create(ctx, tokens.length, { temperature: 0.8 }); - * await root.prefill(tokens); - * - * const results = []; - * for (let i = 0; i < 5; i++) { - * const branch = await root.fork(); - * branch.reseedSampler(1000 + i); - * const tokens = []; - * for await (const { token } of branch) tokens.push(token); - * results.push({ branch, tokens, ppl: branch.perplexity }); - * } - * - * const best = results.reduce((a, b) => a.ppl < b.ppl ? a : b); - * for (const r of results) { if (r !== best) await r.branch.prune(); } - * ``` - * - * @category Branching + * Options for Rerank context creation + * @category Core */ -export class Branch { - /** - * Create a root branch at the given position - * - * The branch takes ownership of the sequence and creates its own sampler - * chain from the provided params. Call prefill() to decode prompt tokens - * and capture the logit distribution before forking. - * - * @param ctx SessionContext to create branch on - * @param position Starting position (typically prompt token count) - * @param params Sampling parameters (temperature, topP, etc.) 
- * @param nBatch Per-branch batch size override (defaults to context nBatch) - * @param grammar GBNF grammar string for constrained generation. When provided, - * sample() returns only grammar-valid tokens. The grammar state is cloned on - * fork(), so sibling branches can diverge independently. - */ - static create( - ctx: SessionContext, - position: number, - params?: SamplingParams, - nBatch?: number, - grammar?: string - ): Branch; - - /** - * Fork this branch to a new sequence - * - * The child shares the parent's KV prefix in memory (metadata-only under unified KV, no KV buffer copy). - * Logits, sampler state, and perplexity tracker are cloned so the child - * can diverge independently. Fork from any branch β€” root or intermediate β€” - * to build arbitrarily deep trees. - * - */ - fork(): Promise; - - /** - * Get a copy of this branch's captured logits snapshot. - * - * Returns n_vocab floats β€” the raw logit distribution from the last - * prefill() or commit() call. - * - * Returns an independent copy of the branch's internal snapshot. - * The returned Float32Array is safe to hold across async boundaries - * and is not affected by subsequent decode operations. - * - * @returns Independent copy of the logits snapshot (n_vocab elements) - * @throws If no logits have been captured yet - */ - getLogits(): Float32Array; - - /** - * Bulk-decode tokens into the branch's KV cache and capture logits. - * - * `tokens.length` is the total count to process; the branch's `nBatch` - * (set at `Branch.create`) controls how many are sent per `llama_decode` - * call. E.g. 500 tokens with `nBatch=64` β†’ 8 calls (7Γ—64 + 1Γ—52). - * - * Advances `position` by `tokens.length`. Stores final logits into the - * branch's internal snapshot β€” the next `produce()`/`sample()` reads - * from it. - * - * Does NOT accept tokens into the repeat-penalty window β€” for external - * tokens (user input between turns), not model-generated tokens. 
- * For model output, use `commit()` which does accept + decode. - * - * The primary way to feed tokens into a branch's KV cache. - * - * @param tokens - Token IDs to decode - */ - prefill(tokens: number[]): Promise; - - /** Sample next token from branch's frozen logits snapshot */ - sample(): number; - - /** Accept token for repeat-penalty tracking */ - accept(token: number): void; - - /** - * Discard this branch β€” remove its divergent KV entries and free the handle - * - * Only removes KV entries divergent from the shared prefix; sibling branches - * are unaffected. The disposed flag is set synchronously β€” any call to - * produce(), commit(), etc. after prune() will throw immediately, even - * before the returned promise resolves. - * - * RESTRICT mode: throws if children exist. Use {@link pruneSubtree} to - * cascade-delete an entire subtree. - */ - prune(): Promise; - - /** - * Discard this branch and all its descendants β€” CASCADE delete - * - * Iterative post-order traversal: prunes children first, then this branch. - * Use when tearing down an entire subtree (e.g. abandoned search path). - * Sets disposed synchronously, like {@link prune}. - */ - pruneSubtree(): Promise; - - /** - * Reseed the sampler's PRNG for diversity after fork() - * - * CRITICAL for parallel generation: Without reseeding, all forked branches - * produce identical outputs because they share the same PRNG state. - * - * Only affects stochastic samplers (temperature > 0). Greedy samplers are unchanged. - * - * @param seed - New seed for the PRNG - */ - reseedSampler(seed: number): void; - - /** - * Apply dynamic logit adjustments for this branch only - * - * Unlike `logit_bias` in sampling params (which is cloned on fork), steer biases - * are NOT inherited by child branches. Each branch manages its own steer state - * independently. This makes steer ideal for path-dependent constraints. 
- * - * **Use cases:** - * - **tsampler**: Block tokens that would create repeated N-grams based on - * this branch's specific generation history - * - **Diverse beam search**: Penalize tokens already chosen by sibling beams - * to encourage output diversity across the beam - * - **Dynamic constraints**: Apply token restrictions that change per-step - * - * **Sampling order:** Grammar β†’ Logit Bias β†’ Steer β†’ Sampler Chain - * - * @param biases - Array of token adjustments. Use `-Infinity` to completely - * block a token, positive values to boost probability, negative to reduce. - * - * @example Block tokens for N-gram deduplication (tsampler pattern) - * ```ts - * // Compute which tokens would create repeated 4-grams - * const blocked = computeNgramBlocks(generatedTokens, n=4); - * - * // Block those tokens for this sample only - * branch.steer(blocked.map(t => ({ token: t, bias: -Infinity }))); - * - * const { token } = await branch.produce(); // Blocked tokens won't be sampled - * await branch.commit(token); - * - * // Clear for next iteration (recompute based on new history) - * branch.clearSteer(); - * ``` - * - * @example Diverse beam search - * ```ts - * // Each beam penalizes tokens chosen by siblings this step - * for (const beam of beams) { - * // Collect tokens chosen by other beams - * const siblingTokens = beams - * .filter(b => b !== beam && b.lastToken !== undefined) - * .map(b => b.lastToken); - * - * // Penalize sibling choices to encourage diversity - * beam.branch.steer(siblingTokens.map(t => ({ token: t, bias: -2.0 }))); - * - * const { token } = await beam.branch.produce(); - * await beam.branch.commit(token); - * beam.lastToken = token; - * beam.branch.clearSteer(); - * } - * ``` - * - * @example Boost specific tokens - * ```ts - * // Boost "yes" and "no" tokens for a yes/no question - * branch.steer([ - * { token: yesTokenId, bias: 5.0 }, - * { token: noTokenId, bias: 5.0 } - * ]); - * ``` - */ - steer(biases: Array<{ token: number; bias: 
number }>): void; - - /** - * Clear all steer biases from this branch - * - * Removes any dynamic logit adjustments set by `steer()`. Call this after - * each generation step if your steer constraints are computed per-step - * (e.g., N-gram blocking where the blocked set changes as text grows). - * - * @example Per-step steer pattern - * ```ts - * for (let i = 0; i < maxTokens; i++) { - * // Compute constraints based on current state - * const blocked = computeConstraints(generatedTokens); - * branch.steer(blocked.map(t => ({ token: t, bias: -Infinity }))); - * - * const { token, isStop } = await branch.produce(); - * if (isStop) break; - * - * await branch.commit(token); - * branch.clearSteer(); // Reset for next iteration - * generatedTokens.push(token); - * } - * ``` - */ - clearSteer(): void; - - /** - * Compute entropy of the branch's logits distribution - * - * Measures model uncertainty from the branch's captured logits snapshot: - * - Low entropy: Model is confident (peaked distribution) - * - High entropy: Model is uncertain (flat distribution) - * - * Operates directly on `state->logits_snapshot` β€” no JS round-trip. - * - * @param base - Logarithm base: "nats" (default) or "bits" - * @returns Entropy value in specified base - * - * COST: O(n_vocab) - must sum over all token probabilities - */ - modelEntropy(base?: 'nats' | 'bits'): number; - - /** - * Compute surprisal (negative log-likelihood) for a specific token - * - * Measures how "surprising" the model finds the given token from - * the branch's captured logits snapshot: - * - Low surprisal: Model expected this token (high probability) - * - High surprisal: Model didn't expect this token (low probability) - * - * Operates directly on `state->logits_snapshot` β€” no JS round-trip. 
- * - * @param token - Token ID to compute surprisal for - * @param base - Logarithm base: "nats" (default) or "bits" - * @returns Surprisal value in specified base - * - * COST: O(n_vocab) - softmax normalization required - */ - modelSurprisal(token: number, base?: 'nats' | 'bits'): number; - - /** - * Sampling-level perplexity (from filtered distribution) - * - * Returns perplexity from the distribution actually sampled from - * (after top-k/p/temp/penalties). Useful for policy priors and - * monitoring sampler chain impact. - * - * Compare with {@link perplexity} which is model-level (raw logits). - */ - readonly samplingPerplexity: number; - - /** - * Set static logit biases on this branch - * - * Unlike {@link steer} (which is NOT inherited on fork), logit biases - * ARE cloned when forking. Use for persistent constraints that should - * propagate to child branches. - * - * Applied during sample() in order: Grammar -> Logit Bias -> Steer -> Sampler Chain - * - * @param biases - Array of token adjustments. Use `-Infinity` to block, - * positive to boost, negative to reduce. - */ - setLogitBias(biases: Array<{ token: number; bias: number }>): void; - - /** - * Clear all static logit biases from this branch - */ - clearLogitBias(): void; - - /** - * Replace the sampler chain with new parameters (memoized) - * - * If the new params match the current chain's params, this is a no-op. - * Otherwise the old chain is freed and a new one is created. Use for - * Entropy-Driven Temperature (EDT) and other adaptive sampling strategies - * that adjust parameters per-step. 
- * - * @param params - New sampling parameters - * - * @example Entropy-Driven Temperature - * ```typescript - * const entropy = branch.modelEntropy('nats'); - * branch.setSamplerParams({ temperature: edtTemperature(entropy) }); - * const { token } = await branch.produce(); - * await branch.commit(token); - * ``` - */ - setSamplerParams(params: SamplingParams): void; - - /** - * Replace or remove the grammar constraint - * - * Pass a GBNF grammar string to constrain generation. Pass empty string - * or undefined to remove the constraint. The grammar state is cloned on - * fork(), so sibling branches can diverge independently after hot-swap. - * - * @param grammarStr - GBNF grammar string, or empty/undefined to remove - * - * @example Hot-swap grammar mid-generation - * ```typescript - * // Start unconstrained, then switch to JSON after detecting tool call - * branch.setGrammar(jsonGrammar); - * const { token } = await branch.produce(); - * ``` - */ - setGrammar(grammarStr?: string): void; - - /** - * Sample next token without advancing state (async) - * - * Async contract: local branches resolve immediately; cloud branches - * may perform an HTTP round-trip. Use {@link produceSync} when you know - * the branch is local and want zero-overhead sampling. - */ - produce(): Promise; - - /** - * Sample next token without advancing state (sync) - * - * Same as {@link produce} but synchronous. Use when you know the branch - * is local and want to avoid the microtick overhead of a promise. - */ - produceSync(): Produced; - - /** - * Accept and decode β€” update branch state, then write token to KV - * - * Accepts the token into the sampler penalty window (for correct PPL - * measurement), then decodes (writing to KV cache via AsyncWorker on - * the libuv thread pool) and captures the resulting logits for the next - * produce() call. Accept-first ordering with rollback: if decode throws, - * sampler/grammar/metrics are restored from clones. 
- * - * @param token Token to commit (from produce()) - */ - commit(token: number): Promise; - - /** Branch's current position */ - readonly position: number; - - /** Branch's perplexity */ - readonly perplexity: number; - - /** Internal handle (for debugging) */ - readonly handle: number; - - /** Whether this branch has been disposed */ - readonly disposed: boolean; - - /** Parent branch handle, or null if root */ - readonly parent: number | null; - - /** Child branch handles */ - readonly children: number[]; - - /** True if this branch has no children */ - readonly isLeaf: boolean; +export interface RerankOptions { + /** Path to reranker .gguf model */ + modelPath: string; + /** Max prompts per GPU dispatch (default: 8) */ + nSeqMax?: number; + /** Context window size (default: 4096) */ + nCtx?: number; + /** KV cache key quantization (default: 'q4_0') */ + typeK?: KvCacheType; + /** KV cache value quantization (default: 'q4_0') */ + typeV?: KvCacheType; +} - /** True if this branch holds a KV lease */ - readonly isActive: boolean; +/** + * A single rerank result β€” score for one document + * @category Core + */ +export interface RerankResult { + /** Relevance probability (0–1) */ + score: number; + /** Original index in the input array */ + index: number; +} - /** - * Async iterator β€” generate tokens until EOG - * - * Commit-before-yield semantics: every yielded token is already written - * to KV and accepted into the sampler. Breaking out of the loop is clean β€” - * no orphaned uncommitted tokens, perplexity reflects all yielded tokens. - * - * For inspect-before-commit (speculative decoding, tree search), use - * the {@link produce}/{@link commit} protocol directly. 
- * - * @example Generate to completion - * ```typescript - * for await (const { token, text } of branch) { - * process.stdout.write(text); - * } - * ``` - * - * @example Generate with consumer-side bound - * ```typescript - * const tokens = []; - * for await (const { token } of branch) { - * tokens.push(token); - * if (tokens.length >= limit) break; - * } - * ``` - */ - [Symbol.asyncIterator](): AsyncIterableIterator<{ token: number; text: string }>; +/** + * Progress yielded by Rerank.score() after each scoring group completes + * @category Core + */ +export interface RerankProgress { + /** Number of documents scored so far */ + filled: number; + /** Total documents to score */ + total: number; + /** Sorted results β€” partial until filled === total */ + results: RerankResult[]; } /** - * High-throughput multi-branch decode operations - * - * The naive approach to N-branch generation is N sequential llama_decode() - * calls β€” each paying full GPU kernel launch overhead, memory barrier, and - * PCIe round-trip. BranchStore eliminates this by packing all branches into - * a single llama_batch and dispatching once: O(1) GPU round-trips regardless - * of branch count. The GPU parallelizes across sequences within the batch, - * so N branches approach the wall-time cost of 1. - * - * Two operations, two packing strategies: - * - * **commit()** β€” Generation step. Each branch contributes exactly 1 token. - * Packs N tokens into a single batch via `decode_each` (one row per sequence, - * all at their respective positions). Single `llama_decode()` call. Logits - * captured per-branch at batch index `i`. O(N) total work, O(1) GPU - * dispatches, O(1) amortized dispatch overhead per branch. Accept-first - * ordering with rollback: accepts each token into its branch's repeat-penalty - * window before decode, restores from clones if decode throws. - * - * **prefill()** β€” Bulk token injection. Each branch contributes a - * variable-length token array. 
Uses a two-pass bin-packing algorithm: - * - * - *Pass 1 (planning)*: Greedy first-fit packs items into chunks ≀ nBatch. - * Items larger than nBatch get a dedicated chunk and fall through to - * decode_many's internal auto-chunking (ceil(nTokens / nBatch) calls). - * - *Pass 2 (dispatch)*: Normal chunks dispatch via `decode_scatter` (one - * `llama_decode` per chunk). Logits are indexed by flattened cursor - * position: for item k in a chunk, logits live at `cursor + nTokens[k] - 1`. - * - * For T total tokens across N branches with batch capacity B: - * - Best case (T ≀ B): 1 GPU dispatch, all branches in one batch. - * - Worst case: ceil(T / B) dispatches. Each dispatch is fully packed. - * - Amortized per-token GPU overhead: O(1/B) β€” vanishes as batch fills. - * - * Does NOT accept tokens into the sampler penalty window β€” use for - * external/replayed tokens where repeat-penalty tracking is unwanted. - * For model-generated tokens, use {@link commit} instead. - * - * Both methods take `[branch, token(s)]` tuples β€” the branch-to-token - * binding is structural, not positional. After either call, each branch's - * logits snapshot is updated with the output distribution from its decoded - * token(s), ready for the next `produce()`/`sample()` call. 
- * - * @example 32-branch generation step β€” one GPU dispatch - * ```typescript - * const store = new BranchStore(ctx); - * const entries = await Promise.all(branches.map(async b => [b, (await b.produce()).token] as [Branch, number])); - * await store.commit(entries); // 32 tokens, 1 llama_decode() - * ``` - * - * @example Best-of-N with batched commit - * ```typescript - * const store = new BranchStore(ctx); - * const branches = []; - * for (const _ of [1, 2, 3]) branches.push(await root.fork()); - * - * for (let step = 0; step < 50; step++) { - * const produced = await Promise.all(branches.map(async b => [b, await b.produce()] as const)); - * const live = produced.filter(([, p]) => !p.isStop); - * if (!live.length) break; - * await store.commit(live.map(([b, p]) => [b, p.token])); - * } - * ``` + * Native binding interface β€” what loadBinary() returns * - * @example Asymmetric prefill β€” variable-length injections, auto-chunked - * ```typescript - * await store.prefill([ - * [branchA, systemPromptTokens], // 200 tokens - * [branchB, shortQueryTokens], // 12 tokens - * [branchC, longDocumentTokens], // 800 tokens - * ]); - * // Bin-packed into ceil(1012 / nBatch) GPU dispatches - * ``` - * - * @category Branching + * @category Core */ -export class BranchStore { - constructor(ctx: SessionContext); - - /** - * Batched single-token commit for model-generated tokens - * - * Each tuple `[branch, token]` binds one token to one branch. - * Accepts each token into its branch's repeat-penalty window (for correct - * PPL measurement), then decodes all N tokens in a single llama_decode() - * call via decode_each and captures logits per-branch. Accept-first - * ordering with rollback: if decode throws, sampler/grammar/metrics are - * restored from clones taken before the accept. 
- * - * @param entries - Array of `[branch, token]` tuples (branches must not be disposed) - * @throws If any branch is disposed - */ - commit(entries: [Branch, number][]): Promise; - - /** - * Batched variable-length prefill for external tokens - * - * Each tuple `[branch, tokens]` binds a token array to one branch. - * Each branch can receive a different number of tokens β€” decode_scatter - * handles variable-length runs and auto-chunks to fit nBatch. - * - * Does NOT call accept_token β€” use for external/replayed tokens where - * repeat-penalty tracking is unwanted. For model-generated tokens, - * use {@link commit} instead. - * - * @param entries - Array of `[branch, tokens]` tuples (branches must not be disposed) - * @throws If any branch is disposed - */ - prefill(entries: [Branch, number[]][]): Promise; - - /** - * Retain only the winner branch β€” evict all other leases and free their slots. - * - * Nuclear operation: calls `kv::seq_keep` on the winner's seq_id (stripping all - * other sequences from KV cache in a single pass), then frees all loser slots - * and rebuilds the vacancy list. The winner's topology is reset (no parent, no children). - * - * @param winner - The branch to keep (must not be disposed, must hold a lease) - * @throws If winner is disposed or has no lease - */ - retainOnly(winner: Branch): Promise; - - /** Number of available seq_id leases */ - readonly available: number; +export interface NativeBinding { + createContext(options: ContextOptions): Promise; } diff --git a/test/agents.ts b/test/agents.ts new file mode 100644 index 0000000..5e5dfd8 --- /dev/null +++ b/test/agents.ts @@ -0,0 +1,272 @@ +/** + * Structured concurrency tests for the agent system + * + * Verifies Effection v4 SC guarantees: branch cleanup on all exit paths, + * scope teardown ordering, ensure() lifecycle. 
+ * + * Usage: + * npm run test:agents + * LLAMA_TEST_MODEL=models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf npm run test:agents + */ + +import * as path from 'node:path'; +import * as fs from 'node:fs'; +import { run, call, spawn, ensure, each } from 'effection'; +import { loadBinary, Branch } from '../dist/index.js'; +import type { SessionContext, NativeBinding } from '../dist/index.js'; +import { + initAgents, runAgents, withSharedRoot, Tool, +} from '../dist/agents/index.js'; +import type { AgentPoolResult, JsonSchema } from '../dist/agents/index.js'; + +const MODEL_PATH: string = process.env.LLAMA_TEST_MODEL + ? path.resolve(process.env.LLAMA_TEST_MODEL) + : path.join(__dirname, '../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf'); + +const CTX_SIZE = 2048; + +if (!fs.existsSync(MODEL_PATH)) { + console.error('Test model not found:', MODEL_PATH); + process.exit(1); +} + +console.log('=== lloyal.node SC Agent Tests ===\n'); +console.log(`Model: ${path.basename(MODEL_PATH)}`); +console.log(`Size: ${(fs.statSync(MODEL_PATH).size / 1024 / 1024).toFixed(1)} MB\n`); + +let addon: NativeBinding; +try { + addon = require('../build/Release/lloyal.node') as NativeBinding; +} catch { + addon = loadBinary(); +} + +let passed = 0; +let failed = 0; + +function ok(msg: string): void { + passed++; + console.log(` [PASS] ${msg}`); +} + +function fail(msg: string): void { + failed++; + console.log(` [FAIL] ${msg}`); +} + +function assert(condition: boolean, msg: string): void { + if (condition) ok(msg); + else { fail(msg); throw new Error(msg); } +} + +// ── Test tools ──────────────────────────────────────────────────── + +class ThrowingTool extends Tool> { + readonly name = 'explode'; + readonly description = 'A tool that always throws'; + readonly parameters: JsonSchema = { + type: 'object', + properties: { input: { type: 'string' } }, + }; + async execute(): Promise { + throw new Error('intentional_tool_error'); + } +} + +// ── Helpers 
──────────────────────────────────────────────────────── + +async function createTestContext(): Promise { + return addon.createContext({ + modelPath: MODEL_PATH, + nCtx: CTX_SIZE, + nThreads: 4, + nSeqMax: 4, + typeK: 'f16', + typeV: 'f16', + }); +} + +function makeTasks(parent: Branch, count: number) { + return Array.from({ length: count }, (_, i) => ({ + systemPrompt: 'You are a test agent.', + content: `Test task ${i}`, + parent, + })); +} + +/** Bootstrap agent infra via initAgents + drain events to prevent backpressure */ +function* setupTest(ctx: SessionContext) { + const { events } = yield* initAgents(ctx); + yield* spawn(function*() { + for (const _ev of yield* each(events)) { + yield* each.next(); + } + }); +} + +// ═══════════════════════════════════════════════════════════════════ +// TEST 1: ensure() cleanup β€” runs on scope exit regardless of how +// ═══════════════════════════════════════════════════════════════════ + +async function testEnsureCleanup(): Promise { + console.log('\n--- ensure() cleanup: runs on normal exit and on error ---'); + + // Test A: ensure runs on normal exit + let cleanupRanNormal = false; + await run(function*() { + yield* ensure(() => { cleanupRanNormal = true; }); + }); + assert(cleanupRanNormal, 'ensure() ran on normal scope exit'); + + // Test B: ensure runs on error exit + let cleanupRanError = false; + try { + await run(function*() { + yield* ensure(() => { cleanupRanError = true; }); + throw new Error('intentional_test_error'); + }); + } catch { + // expected + } + assert(cleanupRanError, 'ensure() ran on error scope exit'); +} + +// ═══════════════════════════════════════════════════════════════════ +// TEST 2: Normal lifecycle β€” branches pruned after runAgents returns +// ═══════════════════════════════════════════════════════════════════ + +async function testNormalLifecycle(): Promise { + console.log('\n--- Normal lifecycle: branches pruned after runAgents ---'); + + await run(function*() { + const ctx: 
SessionContext = yield* call(() => createTestContext()); + yield* setupTest(ctx); + + yield* withSharedRoot( + { systemPrompt: 'You are a test agent.' }, + function*(root, prefixLen) { + assert(prefixLen > 0, `shared prefix has tokens (${prefixLen})`); + + const pool: AgentPoolResult = yield* runAgents({ + tasks: makeTasks(root, 2), + tools: new Map(), + maxTurns: 1, + }); + + assert(pool.agents.length === 2, 'pool has 2 agents'); + assert(root.children.length === 0, 'agent branches pruned before body returns'); + + return pool; + }, + ); + + ok('withSharedRoot completed without error'); + }); +} + +// ═══════════════════════════════════════════════════════════════════ +// TEST 3: scoped() cleanup β€” runAgents prunes before returning +// ═══════════════════════════════════════════════════════════════════ + +async function testScopedCleanup(): Promise { + console.log('\n--- Scoped cleanup: runAgents prunes before returning to caller ---'); + + await run(function*() { + const ctx: SessionContext = yield* call(() => createTestContext()); + yield* setupTest(ctx); + + yield* withSharedRoot( + { systemPrompt: 'You are a test agent.' }, + function*(root) { + const childCountBefore = root.children.length; + assert(childCountBefore === 0, 'root starts with no children'); + + const pool = yield* runAgents({ + tasks: makeTasks(root, 2), + tools: new Map(), + maxTurns: 1, + }); + + // Critical SC assertion: scoped() in runAgents must have torn + // down the pool scope and pruned agent branches BEFORE returning. 
+ const childCountAfter = root.children.length; + assert(childCountAfter === 0, `scoped() pruned all children before returning (was ${childCountBefore}, now ${childCountAfter})`); + + return pool; + }, + ); + + ok('scoped() teardown ordering correct'); + }); +} + +// ═══════════════════════════════════════════════════════════════════ +// TEST 4: Tool error β€” branches pruned, error does not crash pool +// ═══════════════════════════════════════════════════════════════════ + +async function testToolErrorCleanup(): Promise { + console.log('\n--- Tool error: branches pruned, pool completes gracefully ---'); + + await run(function*() { + const ctx: SessionContext = yield* call(() => createTestContext()); + yield* setupTest(ctx); + + try { + yield* withSharedRoot( + { systemPrompt: 'You are a test agent. Always call the explode tool.' }, + function*(root) { + const toolMap = new Map([['explode', new ThrowingTool()]]); + const toolsJson = JSON.stringify([{ + type: 'function', + function: { + name: 'explode', + description: 'A tool that always throws', + parameters: { type: 'object', properties: { input: { type: 'string' } } }, + }, + }]); + + const pool = yield* runAgents({ + tasks: [{ + systemPrompt: 'You are a test agent. Call the explode tool immediately.', + content: 'Do it now.', + tools: toolsJson, + parent: root, + }], + tools: toolMap, + maxTurns: 2, + }); + + assert(root.children.length === 0, 'agent branches pruned after tool error'); + assert(pool.agents.length === 1, 'pool has 1 agent'); + return pool; + }, + ); + + ok('withSharedRoot completed β€” tool error did not crash the pool'); + } catch (err) { + // Tool errors should be handled internally (agent β†’ done state). + // If we reach here, something unexpected propagated. 
+ fail(`unexpected error escaped pool: ${(err as Error).message}`); + } + }); +} + +// ═══════════════════════════════════════════════════════════════════ +// RUNNER +// ═══════════════════════════════════════════════════════════════════ + +async function main_(): Promise { + await testEnsureCleanup(); + await testNormalLifecycle(); + await testScopedCleanup(); + await testToolErrorCleanup(); + + console.log(`\n${'='.repeat(40)}`); + console.log(`Results: ${passed} passed, ${failed} failed`); + if (failed > 0) process.exit(1); +} + +main_().catch((err: unknown) => { + console.error(`\nFatal: ${(err as Error).message}\n${(err as Error).stack}`); + process.exit(1); +}); diff --git a/test/examples.js b/test/examples.js deleted file mode 100644 index e2ab9ba..0000000 --- a/test/examples.js +++ /dev/null @@ -1,410 +0,0 @@ -/** - * Examples Integration Test - * - * Runs examples with --jsonl flag and validates structured output. - * Each example emits JSONL events that we parse and assert on. - * - * Usage: - * node test/examples.js # Run all examples - * node test/examples.js entropy # Run specific example - * - * Environment variables: - * LLAMA_TEST_MODEL - Path to chat/instruct model (default: SmolLM2) - * EMBED_MODEL_PATH - Path to embedding model (default: nomic-embed) - */ - -const { spawn } = require('child_process'); -const path = require('path'); -const fs = require('fs'); - -// Model paths - use env var or default (resolve to absolute path) -const MODEL_PATH = process.env.LLAMA_TEST_MODEL - ? path.resolve(process.env.LLAMA_TEST_MODEL) - : path.join(__dirname, '../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf'); - -// Embedding model (separate from chat model, resolve to absolute path) -const EMBED_MODEL_PATH = process.env.EMBED_MODEL_PATH - ? 
path.resolve(process.env.EMBED_MODEL_PATH) - : path.join(__dirname, '../liblloyal/tests/fixtures/nomic-embed-text-v1.5.Q4_K_M.gguf'); - - -if (!fs.existsSync(MODEL_PATH)) { - console.error('❌ Test model not found!'); - console.error(` Expected: ${MODEL_PATH}`); - console.error(' Run: npm run download-models'); - process.exit(1); -} - -/** - * Run an example with --jsonl and collect events - */ -function runExample(scriptPath, timeout = 600000, extraArgs = [], modelPathOverride = null) { - return new Promise((resolve, reject) => { - const events = []; - let stderr = ''; - - const modelArg = modelPathOverride || MODEL_PATH; - - const child = spawn('node', [scriptPath, modelArg, '--jsonl', ...extraArgs], { - cwd: path.dirname(scriptPath), - stdio: ['ignore', 'pipe', 'pipe'], - }); - - child.stdout.on('data', (data) => { - const lines = data.toString().split('\n'); - for (const line of lines) { - if (line.startsWith('{')) { - try { - const event = JSON.parse(line); - events.push(event); - } catch { - // Ignore malformed JSON - } - } - } - }); - - child.stderr.on('data', (data) => { - stderr += data.toString(); - }); - - const timeoutId = setTimeout(() => { - child.kill('SIGTERM'); - reject(new Error('TIMEOUT')); - }, timeout); - - child.on('close', (code) => { - clearTimeout(timeoutId); - if (code === 0) { - resolve(events); - } else { - reject(new Error(`Exit code ${code}\n${stderr.slice(-500)}`)); - } - }); - - child.on('error', (err) => { - clearTimeout(timeoutId); - reject(err); - }); - }); -} - -/** - * Assert helper - */ -function assert(condition, message) { - if (!condition) { - throw new Error(`Assertion failed: ${message}`); - } -} - -/** - * Example test definitions - */ -const EXAMPLES = { - entropy: { - path: 'entropy/entropy.mjs', - timeout: 120000, - validate(events) { - const start = events.find(e => e.event === 'start'); - assert(start, 'should have start event'); - assert(start.model, 'start should have model'); - - const comparisons = events.filter(e 
=> e.event === 'comparison'); - assert(comparisons.length === 3, `should have 3 comparisons, got ${comparisons.length}`); - - for (const c of comparisons) { - assert(c.fixed && c.edt, 'comparison should have fixed and edt results'); - assert(c.fixed.tokenCount > 0, 'fixed should generate tokens'); - assert(c.edt.tokenCount > 0, 'edt should generate tokens'); - assert(typeof c.edt.avgTemp === 'number', 'edt should have avgTemp'); - } - - const complete = events.find(e => e.event === 'complete'); - assert(complete, 'should have complete event'); - assert(complete.comparisons === 3, 'should complete 3 comparisons'); - }, - }, - - speculative: { - path: 'speculative/speculative.mjs', - timeout: 120000, - validate(events) { - const start = events.find(e => e.event === 'start'); - assert(start, 'should have start event'); - assert(start.draftCount > 0, 'should have draftCount'); - - const iterations = events.filter(e => e.event === 'iteration'); - assert(iterations.length > 0, 'should have iterations'); - - for (const iter of iterations) { - assert(iter.drafted > 0, 'iteration should have drafted tokens'); - assert(iter.accepted >= 0, 'iteration should have accepted count'); - } - - const complete = events.find(e => e.event === 'complete'); - assert(complete, 'should have complete event'); - assert(complete.outputTokens > 0, 'should generate tokens'); - assert(complete.acceptRate >= 0 && complete.acceptRate <= 1, 'acceptRate should be 0-1'); - }, - }, - - grammar: { - path: 'grammar/grammar.mjs', - timeout: 120000, - validate(events) { - const start = events.find(e => e.event === 'start'); - assert(start, 'should have start event'); - - const branchPoint = events.find(e => e.event === 'branch_point'); - assert(branchPoint, 'should have branch_point event'); - assert(branchPoint.prefix.includes('"city"'), 'should branch at city field'); - - const branchCompletes = events.filter(e => e.event === 'branch_complete'); - assert(branchCompletes.length === 3, 'should complete 3 
branches'); - - const complete = events.find(e => e.event === 'complete'); - assert(complete, 'should have complete event'); - assert(complete.validJsonCount > 0, 'should produce valid JSON'); - }, - }, - - 'best-of-n': { - path: 'best-of-n/best-of-n.mjs', - timeout: 180000, - validate(events) { - const start = events.find(e => e.event === 'start'); - assert(start, 'should have start event'); - assert(start.n === 5, 'should have n=5 candidates'); - - const baseline = events.find(e => e.event === 'baseline'); - assert(baseline, 'should have baseline'); - assert(baseline.ppl > 0, 'baseline should have positive ppl'); - - const candidates = events.filter(e => e.event === 'candidate'); - assert(candidates.length === 5, 'should have 5 candidates'); - - for (const c of candidates) { - assert(c.ppl > 1 && c.ppl < 1000, `candidate ppl should be in (1, 1000), got ${c.ppl}`); - assert(c.tokenCount > 0, 'candidate should have tokens'); - } - - const complete = events.find(e => e.event === 'complete'); - assert(complete, 'should have complete event'); - assert(complete.bestPpl > 0, 'should have bestPpl'); - }, - }, - - streaming: { - path: 'streaming/streaming.mjs', - timeout: 120000, - extraArgs: ['--max-tokens=500'], - validate(events) { - const start = events.find(e => e.event === 'start'); - assert(start, 'should have start event'); - - const tokens = events.filter(e => e.event === 'token'); - assert(tokens.length > 50, 'should generate tokens'); - - for (const t of tokens.slice(0, 10)) { - assert(typeof t.surprisal === 'number', 'token should have surprisal'); - } - - const complete = events.find(e => e.event === 'complete'); - assert(complete, 'should have complete event'); - assert(complete.generatedTokens > 0, 'should generate tokens'); - assert(complete.finalPpl > 0, 'should have finalPpl'); - }, - }, - - 'streaming-tsampler': { - path: 'streaming/streaming-tsampler.mjs', - timeout: 120000, - extraArgs: ['--max-tokens=500'], - validate(events) { - const start = 
events.find(e => e.event === 'start'); - assert(start, 'should have start event'); - assert(start.ngramSize > 0, 'should have ngramSize'); - - const tokens = events.filter(e => e.event === 'token'); - assert(tokens.length > 0, 'should generate tokens'); - - const complete = events.find(e => e.event === 'complete'); - assert(complete, 'should have complete event'); - assert(complete.generatedTokens > 0, 'should generate tokens'); - assert(typeof complete.blockedCount === 'number', 'should track blocked count'); - assert(complete.uniqueNgrams > 0, 'should track unique ngrams'); - }, - }, - - 'streaming-summary': { - path: 'streaming/streaming-summary.mjs', - timeout: 180000, - extraArgs: ['--max-tokens=500'], - validate(events) { - const start = events.find(e => e.event === 'start'); - assert(start, 'should have start event'); - assert(start.summaryMode === 'self', 'should default to self-summary mode'); - - const tokens = events.filter(e => e.event === 'token'); - assert(tokens.length > 50, 'should generate tokens'); - - for (const t of tokens.slice(0, 10)) { - assert(t.source === 'main', 'token should have source=main'); - assert(typeof t.surprisal === 'number', 'token should have surprisal'); - } - - const complete = events.find(e => e.event === 'complete'); - assert(complete, 'should have complete event'); - assert(complete.generatedTokens > 0, 'should generate tokens'); - assert(complete.finalPpl > 0, 'should have finalPpl'); - }, - }, - - embed: { - path: 'embed/embed.mjs', - timeout: 60000, - modelPath: EMBED_MODEL_PATH, - skip: !fs.existsSync(EMBED_MODEL_PATH), - skipReason: 'nomic-embed-text model not found', - validate(events) { - const start = events.find(e => e.event === 'start'); - assert(start, 'should have start event'); - assert(start.embeddingDim > 0, 'should have embedding dimension'); - assert(start.hasPooling === true, 'should have pooling enabled'); - - const embeddings = events.filter(e => e.event === 'embedding'); - assert(embeddings.length === 
4, 'should embed 4 texts'); - - for (const e of embeddings) { - assert(e.dimension > 0, 'embedding should have dimension'); - assert(e.elapsed >= 0, 'embedding should have elapsed time'); - } - - const similarities = events.filter(e => e.event === 'similarity'); - assert(similarities.length === 6, 'should have 6 similarity pairs (4 choose 2)'); - - for (const s of similarities) { - assert(s.similarity >= -1 && s.similarity <= 1, 'similarity should be in [-1, 1]'); - } - - const search = events.find(e => e.event === 'search'); - assert(search, 'should have search event'); - assert(search.results.length === 4, 'search should rank all texts'); - - const complete = events.find(e => e.event === 'complete'); - assert(complete, 'should have complete event'); - }, - }, -}; - -async function runTest(name, config) { - const fullPath = path.join(__dirname, '../examples', config.path); - - if (config.skip) { - console.log(`⏭️ ${name}: SKIPPED`); - console.log(` Reason: ${config.skipReason}`); - return { name, skipped: true, skipReason: config.skipReason }; - } - - console.log(`\nπŸ“œ ${name}:`); - const startTime = Date.now(); - - try { - const modelPathToUse = config.modelPath || MODEL_PATH; - const extraArgs = config.extraArgs || []; - - const events = await runExample(fullPath, config.timeout, extraArgs, modelPathToUse); - config.validate(events); - - const elapsed = ((Date.now() - startTime) / 1000).toFixed(1); - - console.log(` βœ… PASSED (${elapsed}s)`); - console.log(` Events: ${events.length} total`); - - // Show key metrics from complete event if present - const complete = events.find(e => e.event === 'complete'); - if (complete) { - const metrics = []; - if (complete.generatedTokens) metrics.push(`tokens: ${complete.generatedTokens}`); - if (complete.outputTokens) metrics.push(`tokens: ${complete.outputTokens}`); - if (complete.finalPpl) metrics.push(`ppl: ${complete.finalPpl.toFixed(2)}`); - if (complete.reseeds !== undefined) metrics.push(`reseeds: 
${complete.reseeds}`); - if (complete.acceptRate !== undefined) metrics.push(`accept: ${(complete.acceptRate * 100).toFixed(0)}%`); - if (complete.validJsonCount !== undefined) metrics.push(`valid: ${complete.validJsonCount}/${complete.branchCount}`); - if (complete.bestPpl) metrics.push(`bestPpl: ${complete.bestPpl.toFixed(2)}`); - if (complete.embeddings) metrics.push(`embeddings: ${complete.embeddings}`); - if (metrics.length > 0) { - console.log(` Metrics: ${metrics.join(', ')}`); - } - } - - return { - name, - passed: true, - elapsed: parseFloat(elapsed), - eventCount: events.length, - metrics: complete || {} - }; - - } catch (err) { - const elapsed = ((Date.now() - startTime) / 1000).toFixed(1); - console.log(` ❌ FAILED (${elapsed}s)`); - console.log(` Error: ${err.message}`); - return { name, passed: false, elapsed: parseFloat(elapsed), error: err.message }; - } -} - -async function main() { - const filterName = process.argv[2]; - - console.log('=== Examples Integration Test ==='); - console.log(`Model: ${path.basename(MODEL_PATH)}`); - - const toRun = filterName - ? 
{ [filterName]: EXAMPLES[filterName] } - : EXAMPLES; - - if (filterName && !EXAMPLES[filterName]) { - console.error(`Unknown example: ${filterName}`); - console.error(`Available: ${Object.keys(EXAMPLES).join(', ')}`); - process.exit(1); - } - - const results = []; - - for (const [name, config] of Object.entries(toRun)) { - const result = await runTest(name, config); - results.push(result); - } - - // Summary - console.log('\n' + '═'.repeat(60)); - console.log('EXAMPLES TEST SUMMARY'); - console.log('═'.repeat(60)); - console.log(`Model: ${path.basename(MODEL_PATH)}`); - console.log(); - - const passed = results.filter(r => r.passed).length; - const failed = results.filter(r => !r.passed && !r.skipped).length; - const skipped = results.filter(r => r.skipped).length; - const totalTime = results.reduce((sum, r) => sum + (r.elapsed || 0), 0).toFixed(1); - - console.log('Results:'); - for (const r of results) { - const status = r.skipped ? '⏭️ ' : (r.passed ? 'βœ…' : '❌'); - const time = r.elapsed ? ` (${r.elapsed}s)` : ''; - const detail = r.skipped ? ` - ${r.skipReason}` : (r.error ? ` - ${r.error.slice(0, 50)}` : ''); - console.log(` ${status} ${r.name}${time}${detail}`); - } - - console.log(); - console.log(`Total: ${passed} passed, ${failed} failed, ${skipped} skipped in ${totalTime}s`); - - process.exit(failed > 0 ? 1 : 0); -} - -main().catch((err) => { - console.error('Fatal:', err); - process.exit(1); -}); diff --git a/test/examples.ts b/test/examples.ts new file mode 100644 index 0000000..3005fec --- /dev/null +++ b/test/examples.ts @@ -0,0 +1,339 @@ +/** + * Examples Integration Test + * + * Runs examples with --jsonl flag and validates structured output. + * Each example emits JSONL events that we parse and assert on. 
+ * + * Usage: + * npx tsx test/examples.ts # Run all examples + * npx tsx test/examples.ts entropy # Run specific example + * + * Environment variables: + * LLAMA_TEST_MODEL - Path to chat/instruct model (default: SmolLM2) + * EMBED_MODEL_PATH - Path to embedding model (default: nomic-embed) + */ + +import { spawn, ChildProcess } from 'node:child_process'; +import * as path from 'node:path'; +import * as fs from 'node:fs'; + +interface ExampleEvent { + event: string; + [key: string]: any; // eslint-disable-line @typescript-eslint/no-explicit-any -- dynamic JSONL fields +} + +interface ExampleConfig { + path: string; + timeout: number; + modelPath?: string; + extraArgs?: string[]; + skip?: boolean; + skipReason?: string; + validate: (events: ExampleEvent[]) => void; +} + +interface TestResult { + name: string; + passed?: boolean; + skipped?: boolean; + skipReason?: string; + elapsed?: number; + eventCount?: number; + metrics?: Record; + error?: string; +} + +const MODEL_PATH: string = process.env.LLAMA_TEST_MODEL + ? path.resolve(process.env.LLAMA_TEST_MODEL) + : path.join(__dirname, '../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf'); + +const EMBED_MODEL_PATH: string = process.env.EMBED_MODEL_PATH + ? path.resolve(process.env.EMBED_MODEL_PATH) + : path.join(__dirname, '../liblloyal/tests/fixtures/nomic-embed-text-v1.5.Q4_K_M.gguf'); + +const QWEN3_PATH: string = process.env.QWEN3_MODEL + ? path.resolve(process.env.QWEN3_MODEL) + : path.join(__dirname, '../models/Qwen3-4B-Instruct-2507-Q4_K_M.gguf'); + +const RERANKER_PATH: string = process.env.RERANKER_MODEL + ? 
path.resolve(process.env.RERANKER_MODEL) + : path.join(__dirname, '../models/qwen3-reranker-0.6b-q4_k_m.gguf'); + + +if (!fs.existsSync(MODEL_PATH)) { + console.error('❌ Test model not found!'); + console.error(` Expected: ${MODEL_PATH}`); + console.error(' Run: npm run download-models'); + process.exit(1); +} + +function runExample(scriptPath: string, timeout: number = 600000, extraArgs: string[] = [], modelPathOverride: string | null = null): Promise { + return new Promise((resolve: (value: ExampleEvent[]) => void, reject: (reason: Error) => void) => { + const events: ExampleEvent[] = []; + let stderr: string = ''; + + const modelArg: string = modelPathOverride || MODEL_PATH; + + const child: ChildProcess = spawn('npx', ['tsx', scriptPath, modelArg, '--jsonl', ...extraArgs], { + cwd: path.dirname(scriptPath), + stdio: ['ignore', 'pipe', 'pipe'], + }); + + let buf = ''; + child.stdout!.on('data', (data: Buffer) => { + buf += data.toString(); + const parts = buf.split('\n'); + buf = parts.pop()!; // carry partial line forward + for (const line of parts) { + if (line.startsWith('{')) { + try { + events.push(JSON.parse(line)); + } catch { /* malformed */ } + } + } + }); + + child.stderr!.on('data', (data: Buffer) => { + stderr += data.toString(); + }); + + const timeoutId: NodeJS.Timeout = setTimeout(() => { + child.kill('SIGTERM'); + reject(new Error('TIMEOUT')); + }, timeout); + + child.on('close', (code: number | null) => { + clearTimeout(timeoutId); + if (code === 0) { + resolve(events); + } else { + reject(new Error(`Exit code ${code}\n${stderr.slice(-500)}`)); + } + }); + + child.on('error', (err: Error) => { + clearTimeout(timeoutId); + reject(err); + }); + }); +} + +function assert(condition: unknown, message: string): asserts condition { + if (!condition) { + throw new Error(`Assertion failed: ${message}`); + } +} + +const EXAMPLES: Record = { + entropy: { + path: 'entropy/entropy.ts', + timeout: 120000, + validate(events: ExampleEvent[]): void { + const 
start: ExampleEvent | undefined = events.find(e => e.event === 'start'); + assert(start, 'should have start event'); + assert(start.model, 'start should have model'); + + const comparisons: ExampleEvent[] = events.filter(e => e.event === 'comparison'); + assert(comparisons.length === 3, `should have 3 comparisons, got ${comparisons.length}`); + + for (const c of comparisons) { + assert(c.fixed && c.edt, 'comparison should have fixed and edt results'); + assert(c.fixed.tokenCount > 0, 'fixed should generate tokens'); + assert(c.edt.tokenCount > 0, 'edt should generate tokens'); + assert(typeof c.edt.avgTemp === 'number', 'edt should have avgTemp'); + } + + const complete: ExampleEvent | undefined = events.find(e => e.event === 'complete'); + assert(complete, 'should have complete event'); + assert(complete.comparisons === 3, 'should complete 3 comparisons'); + }, + }, + + embed: { + path: 'embed/embed.ts', + timeout: 60000, + modelPath: EMBED_MODEL_PATH, + skip: !fs.existsSync(EMBED_MODEL_PATH), + skipReason: 'nomic-embed-text model not found', + validate(events: ExampleEvent[]): void { + const start: ExampleEvent | undefined = events.find(e => e.event === 'start'); + assert(start, 'should have start event'); + assert(start.embeddingDim > 0, 'should have embedding dimension'); + assert(start.hasPooling === true, 'should have pooling enabled'); + + const embeddings: ExampleEvent[] = events.filter(e => e.event === 'embedding'); + assert(embeddings.length === 4, 'should embed 4 texts'); + + for (const e of embeddings) { + assert(e.dimension > 0, 'embedding should have dimension'); + assert(e.elapsed >= 0, 'embedding should have elapsed time'); + } + + const similarities: ExampleEvent[] = events.filter(e => e.event === 'similarity'); + assert(similarities.length === 6, 'should have 6 similarity pairs (4 choose 2)'); + + for (const s of similarities) { + assert(s.similarity >= -1 && s.similarity <= 1, 'similarity should be in [-1, 1]'); + } + + const search: ExampleEvent 
| undefined = events.find(e => e.event === 'search'); + assert(search, 'should have search event'); + assert(search.results.length === 4, 'search should rank all texts'); + + const complete: ExampleEvent | undefined = events.find(e => e.event === 'complete'); + assert(complete, 'should have complete event'); + }, + }, + + 'deep-research': { + path: 'deep-research/deep-research.ts', + timeout: 300000, + modelPath: QWEN3_PATH, + extraArgs: [ + '--reranker', RERANKER_PATH, + '--corpus', process.env.DEEP_RESEARCH_CORPUS || '', + '--query', process.env.DEEP_RESEARCH_QUERY || '', + ], + skip: !fs.existsSync(QWEN3_PATH) || !fs.existsSync(RERANKER_PATH) + || !process.env.DEEP_RESEARCH_CORPUS || !process.env.DEEP_RESEARCH_QUERY, + skipReason: 'Requires QWEN3_MODEL, RERANKER_MODEL, DEEP_RESEARCH_CORPUS, and DEEP_RESEARCH_QUERY env vars', + validate(events: ExampleEvent[]): void { + const start: ExampleEvent | undefined = events.find(e => e.event === 'start'); + assert(start, 'should have start event'); + assert(start.agentCount === 3, 'should have 3 agents'); + assert(start.chunks > 0, 'should have corpus chunks'); + + const plan: ExampleEvent | undefined = events.find(e => e.event === 'plan'); + assert(plan, 'should have plan event'); + assert(plan.questions.length >= 2, 'should plan at least 2 sub-questions'); + + const researchStart: ExampleEvent | undefined = events.find(e => e.event === 'research_start'); + assert(researchStart, 'should have research_start event'); + assert(researchStart.sharedPrefixTokens > 0, 'should have shared prefix'); + + const toolCalls: ExampleEvent[] = events.filter(e => e.event === 'tool_call'); + assert(toolCalls.length > 0, 'should make at least one tool call'); + + const agentsDone: ExampleEvent[] = events.filter(e => e.event === 'agent_done'); + assert(agentsDone.length === 3, 'all 3 agents should finish'); + for (const a of agentsDone) { + assert(a.tokenCount > 0, `agent ${a.index} should generate tokens`); + } + + const complete: 
ExampleEvent | undefined = events.find(e => e.event === 'complete'); + assert(complete, 'should have complete event'); + assert(complete.totalToolCalls > 0, 'should have tool calls'); + assert(complete.wallTimeMs > 0, 'should have wall time'); + assert(complete.converged !== undefined, 'should have convergence result'); + }, + }, +}; + +async function runTest(name: string, config: ExampleConfig): Promise { + const fullPath: string = path.join(__dirname, '../examples', config.path); + + if (config.skip) { + console.log(`⏭️ ${name}: SKIPPED`); + console.log(` Reason: ${config.skipReason}`); + return { name, skipped: true, skipReason: config.skipReason }; + } + + console.log(`\nπŸ“œ ${name}:`); + const startTime: number = Date.now(); + + try { + const modelPathToUse: string = config.modelPath || MODEL_PATH; + const extraArgs: string[] = config.extraArgs || []; + + const events: ExampleEvent[] = await runExample(fullPath, config.timeout, extraArgs, modelPathToUse); + config.validate(events); + + const elapsed: string = ((Date.now() - startTime) / 1000).toFixed(1); + + console.log(` βœ… PASSED (${elapsed}s)`); + console.log(` Events: ${events.length} total`); + + const complete: ExampleEvent | undefined = events.find(e => e.event === 'complete'); + if (complete) { + const metrics: string[] = []; + if (complete.generatedTokens) metrics.push(`tokens: ${complete.generatedTokens}`); + if (complete.outputTokens) metrics.push(`tokens: ${complete.outputTokens}`); + if (complete.finalPpl) metrics.push(`ppl: ${complete.finalPpl.toFixed(2)}`); + if (complete.reseeds !== undefined) metrics.push(`reseeds: ${complete.reseeds}`); + if (complete.acceptRate !== undefined) metrics.push(`accept: ${(complete.acceptRate * 100).toFixed(0)}%`); + if (complete.validJsonCount !== undefined) metrics.push(`valid: ${complete.validJsonCount}/${complete.branchCount}`); + if (complete.bestPpl) metrics.push(`bestPpl: ${complete.bestPpl.toFixed(2)}`); + if (complete.embeddings) 
metrics.push(`embeddings: ${complete.embeddings}`); + if (metrics.length > 0) { + console.log(` Metrics: ${metrics.join(', ')}`); + } + } + + return { + name, + passed: true, + elapsed: parseFloat(elapsed), + eventCount: events.length, + metrics: complete || {} + }; + + } catch (err) { + const elapsed: string = ((Date.now() - startTime) / 1000).toFixed(1); + console.log(` ❌ FAILED (${elapsed}s)`); + console.log(` Error: ${(err as Error).message}`); + return { name, passed: false, elapsed: parseFloat(elapsed), error: (err as Error).message }; + } +} + +async function main(): Promise { + const filterName: string | undefined = process.argv[2]; + + console.log('=== Examples Integration Test ==='); + console.log(`Model: ${path.basename(MODEL_PATH)}`); + + const toRun: Record = filterName + ? { [filterName]: EXAMPLES[filterName] } + : EXAMPLES; + + if (filterName && !EXAMPLES[filterName]) { + console.error(`Unknown example: ${filterName}`); + console.error(`Available: ${Object.keys(EXAMPLES).join(', ')}`); + process.exit(1); + } + + const results: TestResult[] = []; + + for (const [name, config] of Object.entries(toRun)) { + const result: TestResult = await runTest(name, config); + results.push(result); + } + + console.log('\n' + '═'.repeat(60)); + console.log('EXAMPLES TEST SUMMARY'); + console.log('═'.repeat(60)); + console.log(`Model: ${path.basename(MODEL_PATH)}`); + console.log(); + + const passed: number = results.filter(r => r.passed).length; + const failed: number = results.filter(r => !r.passed && !r.skipped).length; + const skipped: number = results.filter(r => r.skipped).length; + const totalTime: string = results.reduce((sum: number, r: TestResult) => sum + (r.elapsed || 0), 0).toFixed(1); + + console.log('Results:'); + for (const r of results) { + const status: string = r.skipped ? '⏭️ ' : (r.passed ? 'βœ…' : '❌'); + const time: string = r.elapsed ? ` (${r.elapsed}s)` : ''; + const detail: string = r.skipped ? ` - ${r.skipReason}` : (r.error ? 
` - ${r.error.slice(0, 50)}` : ''); + console.log(` ${status} ${r.name}${time}${detail}`); + } + + console.log(); + console.log(`Total: ${passed} passed, ${failed} failed, ${skipped} skipped in ${totalTime}s`); + + process.exit(failed > 0 ? 1 : 0); +} + +main().catch((err: unknown) => { + console.error('Fatal:', err); + process.exit(1); +}); diff --git a/test/integration.js b/test/integration.ts similarity index 67% rename from test/integration.js rename to test/integration.ts index b063058..42b042c 100644 --- a/test/integration.js +++ b/test/integration.ts @@ -10,20 +10,29 @@ * * Optional embedding tests: * LLAMA_EMBED_MODEL=models/nomic-embed-text-v1.5.Q4_K_M.gguf npm run test:integration + * + * Optional rerank tests: + * LLAMA_RERANK_MODEL=models/bge-reranker-v2-m3-Q4_K_M.gguf npm run test:integration */ -const path = require('path'); -const fs = require('fs'); +import * as path from 'node:path'; +import * as fs from 'node:fs'; +import { loadBinary, Branch, BranchStore, Rerank } from '../dist/index.js'; +import type { SessionContext, NativeBinding, FormattedChatResult, Produced } from '../dist/index.js'; -const MODEL_PATH = process.env.LLAMA_TEST_MODEL +const MODEL_PATH: string = process.env.LLAMA_TEST_MODEL ? path.resolve(process.env.LLAMA_TEST_MODEL) : path.join(__dirname, '../models/SmolLM2-1.7B-Instruct-Q4_K_M.gguf'); -const EMBED_MODEL_PATH = process.env.LLAMA_EMBED_MODEL || +const EMBED_MODEL_PATH: string | null = process.env.LLAMA_EMBED_MODEL || (fs.existsSync(path.join(__dirname, '../models/nomic-embed-text-v1.5.Q4_K_M.gguf')) ? path.join(__dirname, '../models/nomic-embed-text-v1.5.Q4_K_M.gguf') : null); +const RERANK_MODEL_PATH: string | null = process.env.LLAMA_RERANK_MODEL || + (fs.existsSync(path.join(__dirname, '../models/qwen3-reranker-0.6b-q4_k_m.gguf')) + ? 
path.join(__dirname, '../models/qwen3-reranker-0.6b-q4_k_m.gguf') + : null); -const CTX_SIZE = parseInt(process.env.LLAMA_CTX_SIZE || '2048', 10); +const CTX_SIZE: number = parseInt(process.env.LLAMA_CTX_SIZE || '2048', 10); if (!fs.existsSync(MODEL_PATH)) { console.error('Test model not found:', MODEL_PATH); @@ -34,29 +43,28 @@ console.log('=== lloyal.node Integration Tests ===\n'); console.log(`Model: ${path.basename(MODEL_PATH)}`); console.log(`Size: ${(fs.statSync(MODEL_PATH).size / 1024 / 1024).toFixed(1)} MB\n`); -const { loadBinary, Branch, BranchStore } = require('..'); -let addon; +let addon: NativeBinding; try { - addon = require('../build/Release/lloyal.node'); + addon = require('../build/Release/lloyal.node') as NativeBinding; } catch { addon = loadBinary(); } // Test tracking -let passed = 0; -let failed = 0; +let passed: number = 0; +let failed: number = 0; -function ok(msg) { +function ok(msg: string): void { passed++; console.log(` [PASS] ${msg}`); } -function fail(msg) { +function fail(msg: string): void { failed++; console.log(` [FAIL] ${msg}`); } -function assert(condition, msg) { +function assert(condition: boolean, msg: string): void { if (condition) { ok(msg); } else { @@ -69,33 +77,33 @@ function assert(condition, msg) { // CORE API TESTS // ═══════════════════════════════════════════════════════════════════════════ -async function testCoreAPI(ctx) { +async function testCoreAPI(ctx: SessionContext): Promise { console.log('\n--- Core API ---'); // createContext validated by caller // tokenize / detokenize - const text = "Hello world"; - const tokens = await ctx.tokenize(text); + const text: string = "Hello world"; + const tokens: number[] = await ctx.tokenize(text); assert(tokens.length > 0, `tokenize("${text}") β†’ ${tokens.length} tokens`); - const reconstructed = await ctx.detokenize(tokens); + const reconstructed: string = await ctx.detokenize(tokens); assert(typeof reconstructed === 'string', `detokenize() β†’ "${reconstructed}"`); // 
tokenToText - const tokenText = ctx.tokenToText(tokens[0]); + const tokenText: string = ctx.tokenToText(tokens[0]); assert(typeof tokenText === 'string', `tokenToText(${tokens[0]}) β†’ "${tokenText}"`); // Branch-based prefill + getLogits const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); - const branchLogits = branch.getLogits(); + const branchLogits: Float32Array = branch.getLogits(); assert(branchLogits instanceof Float32Array, `branch.getLogits() β†’ Float32Array(${branchLogits.length})`); assert(branchLogits.length === ctx.vocabSize, `branchLogits.length === vocabSize (${ctx.vocabSize})`); // Validate logits are not garbage - let hasNonZero = false, hasNaN = false; + let hasNonZero: boolean = false, hasNaN: boolean = false; for (let i = 0; i < branchLogits.length; i++) { if (branchLogits[i] !== 0.0) hasNonZero = true; if (isNaN(branchLogits[i])) hasNaN = true; @@ -103,15 +111,15 @@ async function testCoreAPI(ctx) { assert(hasNonZero && !hasNaN, 'branch logits valid (non-zero, no NaN)'); // branch.modelEntropy - const entropy = branch.modelEntropy('nats'); + const entropy: number = branch.modelEntropy('nats'); assert(isFinite(entropy) && entropy >= 0, `branch.modelEntropy() β†’ ${entropy.toFixed(4)} nats`); // Branch greedy sampling (temperature: 0) - const greedy = branch.sample(); + const greedy: number = branch.sample(); assert(greedy >= 0 && greedy < ctx.vocabSize, `branch.sample() greedy β†’ ${greedy}`); // isStopToken - EOS should be a stop token - const eos = ctx.getEogToken(); + const eos: number = ctx.getEogToken(); assert(ctx.isStopToken(eos), `isStopToken(EOS=${eos}) β†’ true`); await branch.prune(); @@ -121,20 +129,20 @@ async function testCoreAPI(ctx) { // KV CACHE TESTS // ═══════════════════════════════════════════════════════════════════════════ -async function testKVCache(ctx) { +async function testKVCache(ctx: SessionContext): Promise { console.log('\n--- KV Cache ---'); await ctx.kvCacheClear(); - const 
tokens = await ctx.tokenize("Test prompt"); + const tokens: number[] = await ctx.tokenize("Test prompt"); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); - const sizeBefore = ctx.kvCacheSize(); + const sizeBefore: number = ctx.kvCacheSize(); assert(sizeBefore >= 0, `kvCacheSize() after prefill β†’ ${sizeBefore}`); await branch.prune(); await ctx.kvCacheClear(); - const sizeAfter = ctx.kvCacheSize(); + const sizeAfter: number = ctx.kvCacheSize(); assert(sizeAfter === -1, `kvCacheClear() β†’ size=${sizeAfter} (empty)`); } @@ -142,10 +150,10 @@ async function testKVCache(ctx) { // MULTI-SEQUENCE TESTS // ═══════════════════════════════════════════════════════════════════════════ -async function testMultiSequence() { +async function testMultiSequence(): Promise { console.log('\n--- Multi-Sequence KV ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, @@ -154,12 +162,12 @@ async function testMultiSequence() { try { // Use a branch to prefill tokens (populates KV on its seq_id) - const tokens = await ctx.tokenize("The quick brown fox"); + const tokens: number[] = await ctx.tokenize("The quick brown fox"); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); // Branch allocates a seq_id β€” check its KV is populated - const branchPos = branch.position; + const branchPos: number = branch.position; assert(branchPos === tokens.length, `branch position β†’ ${branchPos}`); // Fork creates a new sequence with copied KV @@ -167,7 +175,7 @@ async function testMultiSequence() { assert(forked.position === branchPos, `forked position matches parent β†’ ${forked.position}`); // Raw KV seq ops still work for advanced use - const seq1Before = ctx.kvSeqPosMax(3); // unused seq_id + const seq1Before: number = ctx.kvSeqPosMax(3); // unused seq_id assert(seq1Before === -1, `kvSeqPosMax(unused) β†’ ${seq1Before} 
(empty)`); await forked.prune(); @@ -181,10 +189,10 @@ async function testMultiSequence() { // GRAMMAR TESTS // ═══════════════════════════════════════════════════════════════════════════ -async function testGrammar() { +async function testGrammar(): Promise { console.log('\n--- Grammar Sampling ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, @@ -192,15 +200,15 @@ async function testGrammar() { }); try { - const grammar = `root ::= "{" ws "}" ws + const grammar: string = `root ::= "{" ws "}" ws ws ::= [ \\t\\n]*`; // Branch API with grammar - const prompt = await ctx.tokenize("Output: "); + const prompt: number[] = await ctx.tokenize("Output: "); const branch = Branch.create(ctx, 0, { temperature: 0 }, undefined, grammar); await branch.prefill(prompt); - const output = []; + const output: string[] = []; for (let i = 0; i < 10; i++) { const { token, text, isStop } = await branch.produce(); if (isStop) break; @@ -208,12 +216,12 @@ ws ::= [ \\t\\n]*`; output.push(text); } - const result = output.join(''); + const result: string = output.join(''); assert(/^\{\s*\}\s*$/.test(result), `Branch+grammar β†’ "${result}"`); // Grammar is cloned on fork β€” independent parser states await ctx.kvCacheClear(); - const prompt2 = await ctx.tokenize("Output: "); + const prompt2: number[] = await ctx.tokenize("Output: "); const root = Branch.create(ctx, 0, { temperature: 0 }, undefined, grammar); await root.prefill(prompt2); @@ -221,15 +229,15 @@ ws ::= [ \\t\\n]*`; const childB = await root.fork(); // Both children should produce grammar-valid output independently - const outA = [], outB = []; + const outA: string[] = [], outB: string[] = []; for (let i = 0; i < 10; i++) { - const pA = await childA.produce(); + const pA: Produced = await childA.produce(); if (!pA.isStop) { await childA.commit(pA.token); outA.push(pA.text); } - const pB = await childB.produce(); + const pB: 
Produced = await childB.produce(); if (!pB.isStop) { await childB.commit(pB.token); outB.push(pB.text); } } - const resultA = outA.join(''), resultB = outB.join(''); + const resultA: string = outA.join(''), resultB: string = outB.join(''); assert(/^\{\s*\}\s*$/.test(resultA), `Fork A grammar β†’ "${resultA}"`); assert(/^\{\s*\}\s*$/.test(resultB), `Fork B grammar β†’ "${resultB}"`); @@ -246,20 +254,20 @@ ws ::= [ \\t\\n]*`; // METRICS API TESTS // ═══════════════════════════════════════════════════════════════════════════ -async function testMetrics(ctx) { +async function testMetrics(ctx: SessionContext): Promise { console.log('\n--- Metrics API ---'); await ctx.kvCacheClear(); - const tokens = await ctx.tokenize("Hello"); + const tokens: number[] = await ctx.tokenize("Hello"); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); // branch.modelSurprisal - const token1 = branch.sample(); - const surprisal = branch.modelSurprisal(token1, "nats"); + const token1: number = branch.sample(); + const surprisal: number = branch.modelSurprisal(token1, "nats"); assert(surprisal >= 0, `branch.modelSurprisal() β†’ ${surprisal.toFixed(2)} nats`); - const surprisalBits = branch.modelSurprisal(token1, "bits"); + const surprisalBits: number = branch.modelSurprisal(token1, "bits"); assert(Math.abs(surprisalBits - surprisal / Math.log(2)) < 0.01, 'bits = nats / ln(2)'); // Branch perplexity β€” built-in, accumulates through commit() @@ -267,7 +275,7 @@ async function testMetrics(ctx) { const { token: token2 } = await branch.produce(); await branch.commit(token2); - const ppl = branch.perplexity; + const ppl: number = branch.perplexity; assert(isFinite(ppl) && ppl >= 1.0, `branch.perplexity β†’ ${ppl.toFixed(2)}`); await branch.prune(); @@ -277,10 +285,10 @@ async function testMetrics(ctx) { // BRANCH PREFILL TESTS // ═══════════════════════════════════════════════════════════════════════════ -async function testBranchPrefill() { +async function 
testBranchPrefill(): Promise { console.log('\n--- Branch.prefill Multi-Turn ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nBatch: 512, @@ -288,21 +296,21 @@ async function testBranchPrefill() { }); try { - const GEN_TOKENS = 5; - const turns = [ + const GEN_TOKENS: number = 5; + const turns: string[] = [ "What is the capital of France?", " Tell me more.", " What about transportation?" ]; - const messages = [{ role: 'user', content: turns[0] }]; + const messages: Array<{ role: string; content: string }> = [{ role: 'user', content: turns[0] }]; const { prompt } = await ctx.formatChat(JSON.stringify(messages)); - const promptToks = await ctx.tokenize(prompt); + const promptToks: number[] = await ctx.tokenize(prompt); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(promptToks); // Turn 1 - const gen1 = []; + const gen1: number[] = []; for (let i = 0; i < GEN_TOKENS; i++) { const { token, isStop } = await branch.produce(); if (isStop) break; @@ -312,11 +320,11 @@ async function testBranchPrefill() { assert(gen1.length > 0, `Turn 1: generated ${gen1.length} tokens`); // Track assistant response - const assistantText1 = await ctx.detokenize(gen1); + const assistantText1: string = await ctx.detokenize(gen1); messages.push({ role: 'assistant', content: assistantText1 }); // Warm continuation: format only new message + turn separator - const sep = ctx.getTurnSeparator(); + const sep: number[] = ctx.getTurnSeparator(); // Turn 2-3: prefill using format-only-new pattern + generate for (let t = 1; t < turns.length; t++) { @@ -325,15 +333,15 @@ async function testBranchPrefill() { { role: 'system', content: '' }, { role: 'user', content: turns[t] } ])); - const delta = await ctx.tokenize(prompt, false); - const prefillToks = [...sep, ...delta]; + const delta: number[] = await ctx.tokenize(prompt, false); + const prefillToks: number[] = [...sep, 
...delta]; - const posBefore = branch.position; + const posBefore: number = branch.position; await branch.prefill(prefillToks); assert(branch.position === posBefore + prefillToks.length, `Turn ${t + 1}: prefill ${prefillToks.length} tokens β†’ pos=${branch.position}`); - const gen = []; + const gen: number[] = []; for (let i = 0; i < GEN_TOKENS; i++) { const { token, isStop } = await branch.produce(); if (isStop) break; @@ -343,7 +351,7 @@ async function testBranchPrefill() { assert(gen.length > 0, `Turn ${t + 1}: generated ${gen.length} tokens`); // Track assistant response - const assistantText = await ctx.detokenize(gen); + const assistantText: string = await ctx.detokenize(gen); messages.push({ role: 'assistant', content: assistantText }); } @@ -358,10 +366,10 @@ async function testBranchPrefill() { // Mirrors liblloyal C++ test: chat_in_integration_test.cpp // ═══════════════════════════════════════════════════════════════════════════ -async function testWarmMultiTurnRecall() { +async function testWarmMultiTurnRecall(): Promise { console.log('\n--- Warm Multi-Turn Recall ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nBatch: 512, @@ -369,11 +377,11 @@ async function testWarmMultiTurnRecall() { }); try { - const sep = ctx.getTurnSeparator(); + const sep: number[] = ctx.getTurnSeparator(); // Helper: generate until EOG (matches C++ test pattern) - async function generate(branch) { - const gen = []; + async function generate(branch: InstanceType): Promise { + const gen: number[] = []; for (;;) { const { token, isStop } = await branch.produce(); if (isStop) break; @@ -384,25 +392,25 @@ async function testWarmMultiTurnRecall() { } // Helper: warm continuation β€” sep + format([{system,""},{user,msg}]) - async function warmTurn(branch, userContent) { + async function warmTurn(branch: InstanceType, userContent: string): Promise { const { prompt } = await 
ctx.formatChat(JSON.stringify([ { role: 'system', content: '' }, { role: 'user', content: userContent } ]), {}); - const delta = await ctx.tokenize(prompt, false); + const delta: number[] = await ctx.tokenize(prompt, false); await branch.prefill([...sep, ...delta]); return generate(branch); } // Turn 1 (COLD): introduce name - const msgs1 = [{ role: 'user', content: 'Hi, my name is Lloyal' }]; + const msgs1: Array<{ role: string; content: string }> = [{ role: 'user', content: 'Hi, my name is Lloyal' }]; const { prompt, format, reasoningFormat } = await ctx.formatChat(JSON.stringify(msgs1), {}); - const promptToks = await ctx.tokenize(prompt); + const promptToks: number[] = await ctx.tokenize(prompt); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(promptToks); // Helper: parse output and check content (not reasoning) for a term - function checkRecall(rawText, term) { + function checkRecall(rawText: string, term: string): boolean { const { content } = ctx.parseChatOutput(rawText, format, { reasoningFormat, isPartial: false, @@ -411,25 +419,25 @@ async function testWarmMultiTurnRecall() { return (content || '').toLowerCase().includes(term.toLowerCase()); } - const turn1 = await generate(branch); + const turn1: string = await generate(branch); console.log(` Turn 1: "${turn1.trim()}"`); assert(turn1.length > 0, 'Turn 1: generated response'); // Turn 2 (WARM): introduce favourite food - const turn2 = await warmTurn(branch, 'My favourite food is pizza'); + const turn2: string = await warmTurn(branch, 'My favourite food is pizza'); console.log(` Turn 2: "${turn2.trim()}"`); assert(turn2.length > 0, 'Turn 2: generated response'); // Turn 3 (WARM): recall name - const turn3 = await warmTurn(branch, 'Do you remember my name?'); + const turn3: string = await warmTurn(branch, 'Do you remember my name?'); console.log(` Turn 3 (name recall): "${turn3.trim()}"`); - const nameRecalled = checkRecall(turn3, 'lloyal'); + const nameRecalled: boolean = 
checkRecall(turn3, 'lloyal'); assert(nameRecalled, `Name recall: ${nameRecalled ? 'found "Lloyal"' : 'MISSING "Lloyal" in: ' + turn3.trim()}`); // Turn 4 (WARM): recall food - const turn4 = await warmTurn(branch, 'Do you remember my favourite food?'); + const turn4: string = await warmTurn(branch, 'Do you remember my favourite food?'); console.log(` Turn 4 (food recall): "${turn4.trim()}"`); - const foodRecalled = checkRecall(turn4, 'pizza'); + const foodRecalled: boolean = checkRecall(turn4, 'pizza'); assert(foodRecalled, `Food recall: ${foodRecalled ? 'found "pizza"' : 'MISSING "pizza" in: ' + turn4.trim()}`); await branch.prune(); @@ -442,7 +450,7 @@ async function testWarmMultiTurnRecall() { // WARM CONTINUATION SEMANTIC RECALL - Proves context survives delta-only prefill // ═══════════════════════════════════════════════════════════════════════════ -async function testWarmSemanticRecall() { +async function testWarmSemanticRecall(): Promise { if (!EMBED_MODEL_PATH) { console.log('\n--- Warm Semantic Recall (SKIPPED - no LLAMA_EMBED_MODEL) ---'); return; @@ -450,11 +458,11 @@ async function testWarmSemanticRecall() { console.log('\n--- Warm Semantic Recall ---'); - const GEN_TOKENS = 40; + const GEN_TOKENS: number = 40; // Helper: cosine similarity - function cosine(a, b) { - let dot = 0, na = 0, nb = 0; + function cosine(a: Float32Array, b: Float32Array): number { + let dot: number = 0, na: number = 0, nb: number = 0; for (let i = 0; i < a.length; i++) { dot += a[i] * b[i]; na += a[i] * a[i]; @@ -464,9 +472,9 @@ async function testWarmSemanticRecall() { } // Phase 1: Generate multi-turn conversation via warm continuation - let recallText; + let recallText: string; { - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nBatch: 512, @@ -474,28 +482,28 @@ async function testWarmSemanticRecall() { }); try { - const sep = ctx.getTurnSeparator(); - let branch; - const messages = []; 
+ const sep: number[] = ctx.getTurnSeparator(); + let branch: InstanceType; + const messages: Array<{ role: string; content: string }> = []; // Helper: format-only-new warm continuation - async function warmTurn(userContent) { + async function warmTurn(userContent: string): Promise { messages.push({ role: 'user', content: userContent }); const { prompt } = await ctx.formatChat(JSON.stringify([ { role: 'system', content: '' }, { role: 'user', content: userContent } ])); - const delta = await ctx.tokenize(prompt, false); + const delta: number[] = await ctx.tokenize(prompt, false); await branch.prefill([...sep, ...delta]); - const gen = []; + const gen: number[] = []; for (let i = 0; i < GEN_TOKENS; i++) { const { token, isStop } = await branch.produce(); if (isStop) break; await branch.commit(token); gen.push(token); } - const text = await ctx.detokenize(gen); + const text: string = await ctx.detokenize(gen); messages.push({ role: 'assistant', content: text }); return text; } @@ -503,19 +511,19 @@ async function testWarmSemanticRecall() { // Turn 1: Plant a specific, recallable fact messages.push({ role: 'user', content: 'Remember this: my dog is named Max.' 
}); const { prompt } = await ctx.formatChat(JSON.stringify(messages)); - const promptToks = await ctx.tokenize(prompt); + const promptToks: number[] = await ctx.tokenize(prompt); branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(promptToks); // Generate turn 1 response - const gen = []; + const gen: number[] = []; for (let i = 0; i < GEN_TOKENS; i++) { const { token, isStop } = await branch.produce(); if (isStop) break; await branch.commit(token); gen.push(token); } - const turn1Response = await ctx.detokenize(gen); + const turn1Response: string = await ctx.detokenize(gen); messages.push({ role: 'assistant', content: turn1Response }); // Turn 2: Distractor @@ -535,7 +543,7 @@ async function testWarmSemanticRecall() { // Phase 2: Score via embedding similarity (chat model fully released) { - const embedCtx = await addon.createContext({ + const embedCtx: SessionContext = await addon.createContext({ modelPath: EMBED_MODEL_PATH, nCtx: 512, nBatch: 512, @@ -545,8 +553,8 @@ async function testWarmSemanticRecall() { }); try { - async function embed(text) { - const tokens = await embedCtx.tokenize(text); + async function embed(text: string): Promise { + const tokens: number[] = await embedCtx.tokenize(text); await embedCtx.kvCacheClear(); await embedCtx.encode(tokens); return embedCtx.getEmbeddings(true); @@ -554,12 +562,12 @@ async function testWarmSemanticRecall() { console.log(` Recall response: "${recallText.trim()}"`); - const embResponse = await embed(recallText); - const embCorrect = await embed('The dog is named Max.'); - const embWrong = await embed('Red, blue, and green are three colors.'); + const embResponse: Float32Array = await embed(recallText); + const embCorrect: Float32Array = await embed('The dog is named Max.'); + const embWrong: Float32Array = await embed('Red, blue, and green are three colors.'); - const simCorrect = cosine(embResponse, embCorrect); - const simWrong = cosine(embResponse, embWrong); + const simCorrect: number = 
cosine(embResponse, embCorrect); + const simWrong: number = cosine(embResponse, embWrong); assert(simCorrect > simWrong, `Semantic recall: correct=${simCorrect.toFixed(3)} > wrong=${simWrong.toFixed(3)}`); @@ -573,10 +581,10 @@ async function testWarmSemanticRecall() { // BRANCH STEER TESTS - Dynamic per-sample logit manipulation // ═══════════════════════════════════════════════════════════════════════════ -async function testBranchSteer() { +async function testBranchSteer(): Promise { console.log('\n--- Branch.steer ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, @@ -584,25 +592,25 @@ async function testBranchSteer() { }); try { - const tokens = await ctx.tokenize("The quick brown"); + const tokens: number[] = await ctx.tokenize("The quick brown"); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); // Get the greedy token (what would be sampled without steer) - const greedyToken = branch.sample(); + const greedyToken: number = branch.sample(); assert(greedyToken >= 0, `Greedy sample β†’ ${greedyToken}`); // Block the greedy token with steer branch.steer([{ token: greedyToken, bias: -Infinity }]); // Sample again - should get a different token - const steeredToken = branch.sample(); + const steeredToken: number = branch.sample(); assert(steeredToken !== greedyToken, `steer() blocks greedy: ${greedyToken} β†’ ${steeredToken}`); // Clear steer - should get greedy token again branch.clearSteer(); - const afterClear = branch.sample(); + const afterClear: number = branch.sample(); assert(afterClear === greedyToken, `clearSteer() restores greedy: ${afterClear} === ${greedyToken}`); @@ -611,49 +619,49 @@ async function testBranchSteer() { { token: greedyToken, bias: -Infinity }, { token: steeredToken, bias: -Infinity }, ]); - const doubleBlocked = branch.sample(); + const doubleBlocked: number = branch.sample(); 
assert(doubleBlocked !== greedyToken && doubleBlocked !== steeredToken, `Multiple blocks: ${doubleBlocked} β‰  {${greedyToken}, ${steeredToken}}`); // Test boost (positive bias) branch.clearSteer(); branch.steer([{ token: 42, bias: 100.0 }]); // Massive boost to token 42 - const boosted = branch.sample(); + const boosted: number = branch.sample(); assert(boosted === 42, `Boost token 42 β†’ ${boosted}`); await branch.prune(); ok('steer()/clearSteer() work correctly'); // Test fork invariant: steer is NOT cloned on fork - const tokens2 = await ctx.tokenize("Hello world"); + const tokens2: number[] = await ctx.tokenize("Hello world"); const parent = Branch.create(ctx, 0, { temperature: 0 }); await parent.prefill(tokens2); - const parentGreedy = parent.sample(); + const parentGreedy: number = parent.sample(); // Apply steer to parent - block the greedy token parent.steer([{ token: parentGreedy, bias: -Infinity }]); - const parentSteered = parent.sample(); + const parentSteered: number = parent.sample(); assert(parentSteered !== parentGreedy, `Parent steered: ${parentSteered} β‰  ${parentGreedy}`); // Fork from parent - child should NOT inherit steer const child = await parent.fork(); - const childSample = child.sample(); + const childSample: number = child.sample(); assert(childSample === parentGreedy, `Fork does NOT inherit steer: child=${childSample} === greedy=${parentGreedy}`); // Verify parent still has steer active - const parentStillSteered = parent.sample(); + const parentStillSteered: number = parent.sample(); assert(parentStillSteered === parentSteered, `Parent retains steer after fork: ${parentStillSteered} === ${parentSteered}`); // Apply different steer to child - should not affect parent child.steer([{ token: 99, bias: 100.0 }]); - const childBoosted = child.sample(); + const childBoosted: number = child.sample(); assert(childBoosted === 99, `Child can set own steer: ${childBoosted} === 99`); // Parent should be unaffected by child's steer - const 
parentUnaffected = parent.sample(); + const parentUnaffected: number = parent.sample(); assert(parentUnaffected === parentSteered, `Parent unaffected by child steer: ${parentUnaffected} === ${parentSteered}`); @@ -669,14 +677,14 @@ async function testBranchSteer() { // NBATCH ABLATION - Chunk size must not affect output // ═══════════════════════════════════════════════════════════════════════════ -async function testNBatchAblation() { +async function testNBatchAblation(): Promise { console.log('\n--- nBatch Ablation ---'); - const nBatchValues = [32, 64, 128, 512]; - const results = {}; + const nBatchValues: number[] = [32, 64, 128, 512]; + const results: Record = {}; for (const nBatch of nBatchValues) { - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nBatch, @@ -684,16 +692,16 @@ async function testNBatchAblation() { }); try { - const messages = [{ role: 'user', content: "Hello, how are you today?" }]; + const messages: Array<{ role: string; content: string }> = [{ role: 'user', content: "Hello, how are you today?" 
}]; const { prompt } = await ctx.formatChat(JSON.stringify(messages)); - const promptToks = await ctx.tokenize(prompt); + const promptToks: number[] = await ctx.tokenize(prompt); const branch = Branch.create(ctx, 0, { temperature: 0 }, nBatch); await branch.prefill(promptToks); - const followUp = await ctx.tokenize(" What else?"); + const followUp: number[] = await ctx.tokenize(" What else?"); await branch.prefill(followUp); - const gen = []; + const gen: number[] = []; for (let i = 0; i < 5; i++) { const { token, isStop } = await branch.produce(); if (isStop) break; @@ -708,8 +716,8 @@ async function testNBatchAblation() { } } - const ref = results[nBatchValues[0]]; - let allMatch = true; + const ref: string = results[nBatchValues[0]]; + let allMatch: boolean = true; for (const nb of nBatchValues) { if (results[nb] !== ref) allMatch = false; } @@ -721,37 +729,37 @@ async function testNBatchAblation() { // TOKENIZER BEHAVIOR TESTS // ═══════════════════════════════════════════════════════════════════════════ -async function testTokenizer(ctx) { +async function testTokenizer(ctx: SessionContext): Promise { console.log('\n--- Tokenizer ---'); // getEogToken - const eog = ctx.getEogToken(); + const eog: number = ctx.getEogToken(); assert(Number.isInteger(eog), `getEogToken() β†’ ${eog}`); assert(ctx.isStopToken(eog), `EOS ${eog} is stop token`); - const eogText = ctx.tokenToText(eog); + const eogText: string = ctx.tokenToText(eog); assert(eogText.length > 0, `EOS text: "${eogText}"`); // tokenize with addSpecial - const withSpecial = await ctx.tokenize('Hello world', true); - const noSpecial = await ctx.tokenize('Hello world', false); + const withSpecial: number[] = await ctx.tokenize('Hello world', true); + const noSpecial: number[] = await ctx.tokenize('Hello world', false); assert(noSpecial.length <= withSpecial.length, `addSpecial=false (${noSpecial.length}) <= addSpecial=true (${withSpecial.length})`); // getTurnSeparator - const sep = ctx.getTurnSeparator(); + 
const sep: number[] = ctx.getTurnSeparator(); assert(Array.isArray(sep) && sep.length > 0, `getTurnSeparator() β†’ [${sep.join(',')}]`); - const hasStop = sep.some(t => ctx.isStopToken(t)); + const hasStop: boolean = sep.some((t: number) => ctx.isStopToken(t)); assert(hasStop, 'Separator contains stop token'); - const sepText = sep.map(t => ctx.tokenToText(t)).join(''); + const sepText: string = sep.map((t: number) => ctx.tokenToText(t)).join(''); ok(`Separator text: ${JSON.stringify(sepText)}`); // Caching - const sep2 = ctx.getTurnSeparator(); - assert(sep.length === sep2.length && sep.every((t, i) => t === sep2[i]), + const sep2: number[] = ctx.getTurnSeparator(); + assert(sep.length === sep2.length && sep.every((t: number, i: number) => t === sep2[i]), 'getTurnSeparator() cached'); } @@ -759,25 +767,25 @@ async function testTokenizer(ctx) { // DETERMINISM TEST - Same prompt must produce identical output // ═══════════════════════════════════════════════════════════════════════════ -async function testDeterminism() { +async function testDeterminism(): Promise { console.log('\n--- Determinism ---'); - async function generate(prompt) { - const ctx = await addon.createContext({ + async function generate(prompt: string): Promise { + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4 }); try { - const messages = [{ role: 'user', content: prompt }]; + const messages: Array<{ role: string; content: string }> = [{ role: 'user', content: prompt }]; const { prompt: formatted } = await ctx.formatChat(JSON.stringify(messages)); - const tokens = await ctx.tokenize(formatted); + const tokens: number[] = await ctx.tokenize(formatted); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); - const gen = []; + const gen: number[] = []; for (let i = 0; i < 20; i++) { const { token, isStop } = await branch.produce(); if (isStop) break; @@ -791,9 +799,9 @@ async function testDeterminism() { } } 
- const prompt = "Count from 1 to 5."; - const run1 = await generate(prompt); - const run2 = await generate(prompt); + const prompt: string = "Count from 1 to 5."; + const run1: string = await generate(prompt); + const run2: string = await generate(prompt); assert(run1 === run2, `Deterministic: run1 === run2 (${run1.split(',').length} tokens)`); } @@ -802,7 +810,7 @@ async function testDeterminism() { // EMBEDDING TESTS (optional) // ═══════════════════════════════════════════════════════════════════════════ -async function testEmbeddings() { +async function testEmbeddings(): Promise { if (!EMBED_MODEL_PATH) { console.log('\n--- Embeddings (SKIPPED - no LLAMA_EMBED_MODEL) ---'); return; @@ -811,7 +819,7 @@ async function testEmbeddings() { console.log('\n--- Embeddings ---'); console.log(` Model: ${path.basename(EMBED_MODEL_PATH)}`); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: EMBED_MODEL_PATH, nCtx: 512, nBatch: 512, @@ -823,28 +831,28 @@ async function testEmbeddings() { try { assert(ctx.hasPooling(), 'hasPooling() β†’ true'); - const dim = ctx.getEmbeddingDimension(); + const dim: number = ctx.getEmbeddingDimension(); assert(dim > 0, `getEmbeddingDimension() β†’ ${dim}`); - async function embed(text) { - const tokens = await ctx.tokenize(text); + async function embed(text: string): Promise { + const tokens: number[] = await ctx.tokenize(text); await ctx.kvCacheClear(); await ctx.encode(tokens); return ctx.getEmbeddings(true); } - const emb1 = await embed("Hello world"); + const emb1: Float32Array = await embed("Hello world"); assert(emb1.length === dim, `embed("Hello world") β†’ Float32Array(${emb1.length})`); // L2 norm should be ~1.0 - let norm = 0; + let norm: number = 0; for (let i = 0; i < emb1.length; i++) norm += emb1[i] * emb1[i]; norm = Math.sqrt(norm); assert(Math.abs(norm - 1.0) < 0.01, `L2 normalized: norm=${norm.toFixed(4)}`); // Cosine similarity - function cosine(a, b) { - let dot 
= 0, na = 0, nb = 0; + function cosine(a: Float32Array, b: Float32Array): number { + let dot: number = 0, na: number = 0, nb: number = 0; for (let i = 0; i < a.length; i++) { dot += a[i] * b[i]; na += a[i] * a[i]; @@ -853,16 +861,16 @@ async function testEmbeddings() { return dot / (Math.sqrt(na) * Math.sqrt(nb)); } - const emb1Copy = await embed("Hello world"); - const simIdentical = cosine(emb1, emb1Copy); + const emb1Copy: Float32Array = await embed("Hello world"); + const simIdentical: number = cosine(emb1, emb1Copy); assert(simIdentical > 0.99, `Identical texts similarity: ${simIdentical.toFixed(4)}`); - const embSimilar = await embed("The cat sat on the mat"); - const embDifferent = await embed("Stock prices rose sharply"); - const embCat = await embed("A cat rested on the rug"); + const embSimilar: Float32Array = await embed("The cat sat on the mat"); + const embDifferent: Float32Array = await embed("Stock prices rose sharply"); + const embCat: Float32Array = await embed("A cat rested on the rug"); - const simSimilar = cosine(embSimilar, embCat); - const simDifferent = cosine(embSimilar, embDifferent); + const simSimilar: number = cosine(embSimilar, embCat); + const simDifferent: number = cosine(embSimilar, embDifferent); assert(simSimilar > simDifferent, `Semantic: similar=${simSimilar.toFixed(3)} > different=${simDifferent.toFixed(3)}`); } finally { @@ -874,31 +882,31 @@ async function testEmbeddings() { // BRANCH PREFILL + GET LOGITS (replaces testDecodeAndCapture) // ═══════════════════════════════════════════════════════════════════════════ -async function testBranchPrefillAndLogits() { +async function testBranchPrefillAndLogits(): Promise { console.log('\n--- Branch prefill + getLogits ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4 }); try { - const tokens = await ctx.tokenize("Hello"); + const tokens: number[] = await ctx.tokenize("Hello"); 
const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); - const logits = branch.getLogits(); - let valid = false; + const logits: Float32Array = branch.getLogits(); + let valid: boolean = false; for (let i = 0; i < logits.length; i++) { if (logits[i] !== 0 && !isNaN(logits[i])) valid = true; } assert(valid, `branch.prefill() + getLogits() β†’ valid logits`); // Branch logits are an independent copy - const orig = logits[0]; + const orig: number = logits[0]; logits[0] = -999; - const logits2 = branch.getLogits(); + const logits2: Float32Array = branch.getLogits(); assert(logits2[0] !== -999, 'branch.getLogits() returns independent copy'); await branch.prune(); @@ -911,12 +919,12 @@ async function testBranchPrefillAndLogits() { // MAIN // ═══════════════════════════════════════════════════════════════════════════ -async function testChatInOut(ctx) { +async function testChatInOut(ctx: SessionContext): Promise { console.log('\n── chat_in / chat_out ──'); // formatChat with empty options object (new signature) - const messages = [{ role: 'user', content: 'Hello' }]; - const result = await ctx.formatChat(JSON.stringify(messages), {}); + const messages: Array<{ role: string; content: string }> = [{ role: 'user', content: 'Hello' }]; + const result: FormattedChatResult = await ctx.formatChat(JSON.stringify(messages), {}); assert(result.prompt.includes('Hello'), 'formatChat with options: prompt contains Hello'); assert(typeof result.format === 'number', 'formatChat returns format as number'); assert(typeof result.grammar === 'string', 'formatChat returns grammar as string'); @@ -928,12 +936,12 @@ async function testChatInOut(ctx) { ok('formatChat with options returns extended result'); // Backward compat: string second argument still works - const backCompat = await ctx.formatChat(JSON.stringify(messages)); + const backCompat: FormattedChatResult = await ctx.formatChat(JSON.stringify(messages)); assert(backCompat.prompt.includes('Hello'), 
'formatChat backward compat works'); ok('formatChat backward compat (no second arg)'); // formatChat with tools - const tools = [{ + const tools: Array<{ type: string; function: { name: string; description: string; parameters: object } }> = [{ type: 'function', function: { name: 'get_weather', @@ -941,7 +949,7 @@ async function testChatInOut(ctx) { parameters: { type: 'object', properties: { location: { type: 'string' } } } } }]; - const toolResult = await ctx.formatChat(JSON.stringify(messages), { + const toolResult: FormattedChatResult = await ctx.formatChat(JSON.stringify(messages), { tools: JSON.stringify(tools), toolChoice: 'auto' }); @@ -975,10 +983,10 @@ async function testChatInOut(ctx) { // wrapper surface and real-world workflows. // ═══════════════════════════════════════════════════════════════════════════ -async function testBranchStore() { +async function testBranchStore(): Promise { console.log('\n--- BranchStore ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nBatch: 512, @@ -987,7 +995,7 @@ async function testBranchStore() { }); try { - const promptToks = await ctx.tokenize("The quick brown fox jumps over the lazy"); + const promptToks: number[] = await ctx.tokenize("The quick brown fox jumps over the lazy"); const store = new BranchStore(ctx); // ── Test A: Best-of-N generation ── @@ -998,24 +1006,26 @@ async function testBranchStore() { { const root = Branch.create(ctx, 0, { temperature: 0.8 }); await root.prefill(promptToks); - const branches = [root, await root.fork(), await root.fork()]; + const branches: InstanceType[] = [root, await root.fork(), await root.fork()]; branches[1].reseedSampler(42); branches[2].reseedSampler(99); for (let step = 0; step < 10; step++) { - const produced = await Promise.all(branches.map(async b => [b, await b.produce()])); + const produced: Array<[InstanceType, Produced]> = await Promise.all( + branches.map(async (b): 
Promise<[InstanceType, Produced]> => [b, await b.produce()]) + ); const live = produced.filter(([, p]) => !p.isStop); if (!live.length) break; - await store.commit(live.map(([b, p]) => [b, p.token])); + await store.commit(live.map(([b, p]) => [b, p.token] as [InstanceType, number])); } - const ppls = branches.map(b => b.perplexity); - console.log(` best-of-N perplexities: [${ppls.map(p => p.toFixed(2)).join(', ')}]`); - assert(ppls.every(p => isFinite(p) && p >= 1.0), - `best-of-N: all perplexities valid [${ppls.map(p => p.toFixed(2))}]`); + const ppls: number[] = branches.map((b) => b.perplexity); + console.log(` best-of-N perplexities: [${ppls.map((p) => p.toFixed(2)).join(', ')}]`); + assert(ppls.every((p) => isFinite(p) && p >= 1.0), + `best-of-N: all perplexities valid [${ppls.map((p) => p.toFixed(2))}]`); - const best = ppls.reduce((a, b) => Math.min(a, b)); - const worst = ppls.reduce((a, b) => Math.max(a, b)); + const best: number = ppls.reduce((a, b) => Math.min(a, b)); + const worst: number = ppls.reduce((a, b) => Math.max(a, b)); console.log(` [PASS] best-of-N: best=${best.toFixed(2)}, worst=${worst.toFixed(2)}`); await root.pruneSubtree(); @@ -1031,15 +1041,15 @@ async function testBranchStore() { const b2 = await b1.fork(); // Phase 1: Rehydrate from "saved" histories - const history1 = await ctx.tokenize(" dog. The weather is nice today and I want to go", false); - const history2 = await ctx.tokenize(" cat. Let me explain how quantum entanglement works in", false); + const history1: number[] = await ctx.tokenize(" dog. The weather is nice today and I want to go", false); + const history2: number[] = await ctx.tokenize(" cat. Let me explain how quantum entanglement works in", false); await store.prefill([[b1, history1], [b2, history2]]); // Branches should be at different-length positions? No β€” same length coincidentally. 
// But logits must differ (different KV contents) - const logitsAfterPrefill1 = b1.getLogits(); - const logitsAfterPrefill2 = b2.getLogits(); - let prefillDiffer = false; + const logitsAfterPrefill1: Float32Array = b1.getLogits(); + const logitsAfterPrefill2: Float32Array = b2.getLogits(); + let prefillDiffer: boolean = false; for (let i = 0; i < logitsAfterPrefill1.length; i++) { if (logitsAfterPrefill1[i] !== logitsAfterPrefill2[i]) { prefillDiffer = true; break; } } @@ -1047,19 +1057,19 @@ async function testBranchStore() { `rehydrate: different histories β†’ different logits after prefill`); // Phase 2: Generate continuations - const gen1 = [], gen2 = []; + const gen1: number[] = [], gen2: number[] = []; for (let i = 0; i < 5; i++) { - const produced = [[b1, await b1.produce()], [b2, await b2.produce()]]; + const produced: Array<[InstanceType, Produced]> = [[b1, await b1.produce()], [b2, await b2.produce()]]; const live = produced.filter(([, p]) => !p.isStop); if (!live.length) break; - await store.commit(live.map(([b, p]) => [b, p.token])); + await store.commit(live.map(([b, p]) => [b, p.token] as [InstanceType, number])); for (const [b, p] of live) { (b === b1 ? 
gen1 : gen2).push(p.token); } } - const text1 = await ctx.detokenize(gen1); - const text2 = await ctx.detokenize(gen2); + const text1: string = await ctx.detokenize(gen1); + const text2: string = await ctx.detokenize(gen2); console.log(` rehydrate "weather" β†’ "${text1}"`); console.log(` rehydrate "quantum" β†’ "${text2}"`); @@ -1078,22 +1088,22 @@ async function testBranchStore() { const b1 = Branch.create(ctx, 0, { temperature: 0 }); await b1.prefill(promptToks); - const logits = b1.getLogits(); + const logits: Float32Array = b1.getLogits(); assert(logits instanceof Float32Array, `getLogits: returns Float32Array`); assert(logits.length === ctx.vocabSize, `getLogits: length=${logits.length} === vocabSize=${ctx.vocabSize}`); // branch.modelEntropy β€” proves the logits snapshot is a valid distribution - const entropyFromBranch = b1.modelEntropy("nats"); + const entropyFromBranch: number = b1.modelEntropy("nats"); assert(isFinite(entropyFromBranch) && entropyFromBranch > 0, `branch.modelEntropy: ${entropyFromBranch.toFixed(4)} nats`); // After store.commit, logits change β€” branch reflects new state - const p = await b1.produce(); + const p: Produced = await b1.produce(); assert(!p.isStop, `modelEntropy: produce() should not hit EOG on first token`); await store.commit([[b1, p.token]]); - const entropyAfter = b1.modelEntropy("nats"); + const entropyAfter: number = b1.modelEntropy("nats"); assert(isFinite(entropyAfter), `modelEntropy after commit: entropy=${entropyAfter.toFixed(4)} nats`); @@ -1109,10 +1119,10 @@ async function testBranchStore() { await b1.prefill(promptToks); const b2 = await b1.fork(); - const output = []; + const output: string[] = []; for (let i = 0; i < 5; i++) { // Inspect with produce() β€” does NOT advance state - const p1 = await b1.produce(), p2 = await b2.produce(); + const p1: Produced = await b1.produce(), p2: Produced = await b2.produce(); // Can inspect text and isStop before committing assert(typeof p1.text === 'string' && typeof 
p2.text === 'string', @@ -1143,27 +1153,27 @@ async function testBranchStore() { // Step 1-3: single-branch commit (decode::one path) for (let i = 0; i < 3; i++) { - const produced = [[b1, await b1.produce()], [b2, await b2.produce()]]; + const produced: Array<[InstanceType, Produced]> = [[b1, await b1.produce()], [b2, await b2.produce()]]; const live = produced.filter(([, p]) => !p.isStop); if (!live.length) break; for (const [b, p] of live) await b.commit(p.token); } - const posAfterSingle = b1.position; + const posAfterSingle: number = b1.position; // Step 4-6: batched commit (decode::each path) for (let i = 0; i < 3; i++) { - const produced = [[b1, await b1.produce()], [b2, await b2.produce()]]; + const produced: Array<[InstanceType, Produced]> = [[b1, await b1.produce()], [b2, await b2.produce()]]; const live = produced.filter(([, p]) => !p.isStop); if (!live.length) break; - await store.commit(live.map(([b, p]) => [b, p.token])); + await store.commit(live.map(([b, p]) => [b, p.token] as [InstanceType, number])); } - const posAfterBatched = b1.position; + const posAfterBatched: number = b1.position; assert(posAfterBatched === posAfterSingle + 3, `mixed ops: position correct after singleβ†’batched (${posAfterSingle}β†’${posAfterBatched})`); // Step 7-9: back to single-branch commit for (let i = 0; i < 3; i++) { - const produced = [[b1, await b1.produce()], [b2, await b2.produce()]]; + const produced: Array<[InstanceType, Produced]> = [[b1, await b1.produce()], [b2, await b2.produce()]]; const live = produced.filter(([, p]) => !p.isStop); if (!live.length) break; for (const [b, p] of live) await b.commit(p.token); @@ -1185,9 +1195,9 @@ async function testBranchStore() { await b1.prefill(promptToks); const b2 = await b1.fork(); - const eog = ctx.getEogToken(); - const gen1 = [], gen2 = []; - const stopped = [false, false]; + const eog: number = ctx.getEogToken(); + const gen1: number[] = [], gen2: number[] = []; + const stopped: [boolean, boolean] = [false, 
false]; for (let step = 0; step < 8; step++) { // At step 3, force b1 to hit EOG @@ -1195,9 +1205,9 @@ async function testBranchStore() { b1.steer([{ token: eog, bias: 100.0 }]); } - const pairs = [ - ...(!stopped[0] ? [[b1, await b1.produce()]] : []), - ...(!stopped[1] ? [[b2, await b2.produce()]] : []), + const pairs: Array<[InstanceType, Produced]> = [ + ...(!stopped[0] ? [[b1, await b1.produce()] as [InstanceType, Produced]] : []), + ...(!stopped[1] ? [[b2, await b2.produce()] as [InstanceType, Produced]] : []), ]; const live = pairs.filter(([, p]) => !p.isStop); @@ -1210,7 +1220,7 @@ async function testBranchStore() { } if (!live.length) break; - await store.commit(live.map(([b, p]) => [b, p.token])); + await store.commit(live.map(([b, p]) => [b, p.token] as [InstanceType, number])); for (const [b, p] of live) { (b === b1 ? gen1 : gen2).push(p.token); @@ -1225,7 +1235,7 @@ async function testBranchStore() { assert(gen2.length > gen1.length, `independent EOG: b2 continued past b1's EOG (b1=${gen1.length}, b2=${gen2.length})`); - const text2 = await ctx.detokenize(gen2); + const text2: string = await ctx.detokenize(gen2); console.log(` independent EOG: b1 stopped at step 3, b2 continued β†’ "${text2}"`); // b2's position should reflect all its tokens, not be truncated by b1's stop @@ -1243,19 +1253,19 @@ async function testBranchStore() { // PPL SANITY β€” commit() must produce sane perplexity (not millions) // ═══════════════════════════════════════════════════════════════════════════ -async function testPplSanity() { +async function testPplSanity(): Promise { console.log('\n--- PPL Sanity ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4 }); try { - const messages = [{ role: 'user', content: 'Tell me about the weather.' }]; + const messages: Array<{ role: string; content: string }> = [{ role: 'user', content: 'Tell me about the weather.' 
}]; const { prompt } = await ctx.formatChat(JSON.stringify(messages)); - const promptToks = await ctx.tokenize(prompt); + const promptToks: number[] = await ctx.tokenize(prompt); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(promptToks); @@ -1265,7 +1275,7 @@ async function testPplSanity() { await branch.commit(token); } - const ppl = branch.perplexity; + const ppl: number = branch.perplexity; console.log(` perplexity after 10 commits: ${ppl.toFixed(2)}`); assert(isFinite(ppl) && ppl >= 1.0 && ppl < 1000, `PPL sanity: ${ppl.toFixed(2)} is in [1, 1000)`); @@ -1280,14 +1290,14 @@ async function testPplSanity() { // COMMIT ROLLBACK β€” decode failure must restore sampler/grammar/metrics // ═══════════════════════════════════════════════════════════════════════════ -async function testCommitRollback() { +async function testCommitRollback(): Promise { console.log('\n--- Commit Rollback ---'); // Tiny KV (nCtx=32) with many branches (nSeqMax=8). Each branch consumes // 1 KV cell per commit. With 8 branches and ~5 shared prefix cells, the // 32-cell budget exhausts after ~3 commits per branch. decode_each returns // non-zero (find_slot fails) β†’ StoreCommitWorker throws β†’ rollback fires. 
- const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: 32, nBatch: 512, @@ -1296,10 +1306,10 @@ async function testCommitRollback() { }); try { - const promptToks = await ctx.tokenize("Hi"); + const promptToks: number[] = await ctx.tokenize("Hi"); const root = Branch.create(ctx, 0, { temperature: 1.0 }); await root.prefill(promptToks); - const branches = [root]; + const branches: InstanceType[] = [root]; for (let i = 1; i < 8; i++) { const b = await root.fork(); b.reseedSampler(1000 + i); // Divergent tokens β†’ separate KV cells @@ -1311,29 +1321,31 @@ async function testCommitRollback() { // Commit until decode fails from KV exhaustion // nCtx may be clamped to a model minimum (e.g. 256), so we need enough // rounds for 8 branches to exhaust ~256 cells: 256/8 = 32 rounds - let successfulRounds = 0; - let failedRound = false; + let successfulRounds: number = 0; + let failedRound: boolean = false; for (let round = 0; round < 50; round++) { - const produced = await Promise.all(branches.map(async b => [b, await b.produce()])); + const produced: Array<[InstanceType, Produced]> = await Promise.all( + branches.map(async (b): Promise<[InstanceType, Produced]> => [b, await b.produce()]) + ); const live = produced.filter(([, p]) => !p.isStop); if (!live.length) break; // Snapshot PPL before this round - const pplsBefore = live.map(([b]) => b.perplexity); + const pplsBefore: number[] = live.map(([b]) => b.perplexity); try { - await store.commit(live.map(([b, p]) => [b, p.token])); + await store.commit(live.map(([b, p]) => [b, p.token] as [InstanceType, number])); successfulRounds++; } catch { // Decode failed β€” verify PPL restored - const pplsAfter = live.map(([b]) => b.perplexity); - const allRestored = pplsBefore.every((p, i) => p === pplsAfter[i]); + const pplsAfter: number[] = live.map(([b]) => b.perplexity); + const allRestored: boolean = pplsBefore.every((p, i) => p === pplsAfter[i]); 
assert(allRestored, `rollback: all PPLs restored after decode failure at round ${round}`); // Branches still usable for single commits (1 token fits) const [b0, p0] = live[0]; - const posBefore = b0.position; + const posBefore: number = b0.position; try { await b0.commit(p0.token); assert(b0.position === posBefore + 1, @@ -1361,10 +1373,10 @@ async function testCommitRollback() { // ASYNC REJECTION β€” Worker failures must reject, branch state un-advanced // ═══════════════════════════════════════════════════════════════════════════ -async function testAsyncRejection() { +async function testAsyncRejection(): Promise { console.log('\n--- Async Rejection ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, @@ -1372,7 +1384,7 @@ async function testAsyncRejection() { }); try { - const tokens = await ctx.tokenize("Hello world"); + const tokens: number[] = await ctx.tokenize("Hello world"); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); @@ -1380,56 +1392,56 @@ async function testAsyncRejection() { const { token, isStop } = await branch.produce(); assert(!isStop, 'rejection: initial produce succeeds'); await branch.commit(token); - const posAfterCommit = branch.position; + const posAfterCommit: number = branch.position; // Prune the branch β€” frees native resources await branch.prune(); assert(branch.disposed, 'rejection: branch is disposed after prune'); // commit() on disposed branch β€” _ensureNotDisposed should throw synchronously - let threwOnCommit = false; + let threwOnCommit: boolean = false; try { await branch.commit(token); - } catch (e) { + } catch (err) { threwOnCommit = true; - assert(e.message.includes('disposed'), `rejection: commit error says "disposed": "${e.message}"`); + assert((err as Error).message.includes('disposed'), `rejection: commit error says "disposed": "${(err as Error).message}"`); } 
assert(threwOnCommit, 'rejection: commit on disposed branch throws'); // produce() on disposed branch β€” async version rejects - let threwOnProduce = false; + let threwOnProduce: boolean = false; try { await branch.produce(); - } catch (e) { + } catch (err) { threwOnProduce = true; } assert(threwOnProduce, 'rejection: produce on disposed branch rejects'); // produceSync() on disposed branch β€” throws synchronously - let threwOnProduceSync = false; + let threwOnProduceSync: boolean = false; try { branch.produceSync(); - } catch (e) { + } catch (err) { threwOnProduceSync = true; } assert(threwOnProduceSync, 'rejection: produceSync on disposed branch throws'); // fork() on disposed branch - let threwOnFork = false; + let threwOnFork: boolean = false; try { await branch.fork(); - } catch (e) { + } catch (err) { threwOnFork = true; } assert(threwOnFork, 'rejection: fork on disposed branch throws'); // Native AsyncWorker rejection: call _branchPrefill with invalid handle (0) - let nativeRejected = false; + let nativeRejected: boolean = false; try { await ctx._branchPrefill(0, [token]); - } catch (e) { + } catch (err) { nativeRejected = true; - assert(e instanceof Error, `rejection: native rejection is Error: ${e.constructor.name}`); + assert(err instanceof Error, `rejection: native rejection is Error: ${(err as Error).constructor.name}`); } assert(nativeRejected, 'rejection: invalid handle to AsyncWorker rejects promise'); } finally { @@ -1441,10 +1453,10 @@ async function testAsyncRejection() { // EMPTY INPUT EDGE CASES β€” Batch workers with empty arrays resolve cleanly // ═══════════════════════════════════════════════════════════════════════════ -async function testEmptyInputEdgeCases() { +async function testEmptyInputEdgeCases(): Promise { console.log('\n--- Empty Input Edge Cases ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, @@ -1452,12 +1464,12 @@ 
async function testEmptyInputEdgeCases() { }); try { - const tokens = await ctx.tokenize("Hello world"); + const tokens: number[] = await ctx.tokenize("Hello world"); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); const store = new BranchStore(ctx); - const posBefore = branch.position; + const posBefore: number = branch.position; // store.commit([]) β€” empty batch await store.commit([]); @@ -1490,10 +1502,10 @@ async function testEmptyInputEdgeCases() { // JSON SCHEMA TO GRAMMAR β€” AsyncWorker with zero prior coverage // ═══════════════════════════════════════════════════════════════════════════ -async function testJsonSchemaToGrammar() { +async function testJsonSchemaToGrammar(): Promise { console.log('\n--- jsonSchemaToGrammar ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4 @@ -1510,17 +1522,17 @@ async function testJsonSchemaToGrammar() { }; // Happy path: valid schema β†’ GBNF string - const grammar = await ctx.jsonSchemaToGrammar(JSON.stringify(schema)); + const grammar: string = await ctx.jsonSchemaToGrammar(JSON.stringify(schema)); assert(typeof grammar === 'string' && grammar.length > 0, `jsonSchemaToGrammar: returned ${grammar.length}-char grammar`); assert(grammar.includes('root'), 'jsonSchemaToGrammar: grammar contains "root" rule'); // Use the grammar with Branch.create to prove it's valid GBNF - const prompt = await ctx.tokenize("Output JSON: "); + const prompt: number[] = await ctx.tokenize("Output JSON: "); const branch = Branch.create(ctx, 0, { temperature: 0 }, undefined, grammar); await branch.prefill(prompt); - const output = []; + const output: string[] = []; for (let i = 0; i < 50; i++) { const { token, text, isStop } = await branch.produce(); if (isStop) break; @@ -1528,8 +1540,8 @@ async function testJsonSchemaToGrammar() { output.push(text); } - const result = output.join(''); - let parsed; 
+ const result: string = output.join(''); + let parsed: { name: string; age: number } | undefined; try { parsed = JSON.parse(result); } catch { @@ -1547,12 +1559,12 @@ async function testJsonSchemaToGrammar() { await branch.prune(); // Error path: invalid JSON β†’ promise rejects - let rejected = false; + let rejected: boolean = false; try { await ctx.jsonSchemaToGrammar('not valid json {{{'); - } catch (e) { + } catch (err) { rejected = true; - assert(e instanceof Error, `jsonSchemaToGrammar: rejection is Error: ${e.constructor.name}`); + assert(err instanceof Error, `jsonSchemaToGrammar: rejection is Error: ${(err as Error).constructor.name}`); } assert(rejected, 'jsonSchemaToGrammar: invalid JSON rejects'); } finally { @@ -1564,10 +1576,10 @@ async function testJsonSchemaToGrammar() { // DISPOSED-DURING-ASYNC β€” _disposed set synchronously prevents use-after-prune // ═══════════════════════════════════════════════════════════════════════════ -async function testDisposedDuringAsync() { +async function testDisposedDuringAsync(): Promise { console.log('\n--- Disposed During Async ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, @@ -1575,7 +1587,7 @@ async function testDisposedDuringAsync() { }); try { - const tokens = await ctx.tokenize("Test prompt"); + const tokens: number[] = await ctx.tokenize("Test prompt"); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(tokens); @@ -1584,13 +1596,13 @@ async function testDisposedDuringAsync() { await branch.commit(token); // Call prune() β€” DO NOT await yet - const prunePromise = branch.prune(); + const prunePromise: Promise = branch.prune(); // Immediately (before microtask resolves) check disposed assert(branch.disposed, 'disposed-during: _disposed is true synchronously after prune() call'); // produceSync() should throw synchronously - let threwProduce = false; + let threwProduce: 
boolean = false; try { branch.produceSync(); } catch { @@ -1599,7 +1611,7 @@ async function testDisposedDuringAsync() { assert(threwProduce, 'disposed-during: produceSync() throws before prune promise resolves'); // commit() should throw synchronously (the _ensureNotDisposed guard) - let threwCommit = false; + let threwCommit: boolean = false; try { await branch.commit(token); } catch { @@ -1623,10 +1635,10 @@ async function testDisposedDuringAsync() { // ASYNC ITERATOR β€” Branch as async iterable // ═══════════════════════════════════════════════════════════════════════════ -async function testAsyncIterator() { +async function testAsyncIterator(): Promise { console.log('\n--- Async Iterator ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, @@ -1634,13 +1646,13 @@ async function testAsyncIterator() { }); try { - const prompt = await ctx.tokenize("The quick brown fox"); + const prompt: number[] = await ctx.tokenize("The quick brown fox"); // Generate to EOG via for-await const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(prompt); - const tokens = []; + const tokens: number[] = []; for await (const { token, text } of branch) { assert(typeof token === 'number' && typeof text === 'string', `iterator: yields {token, text} (token=${token})`); @@ -1663,7 +1675,7 @@ async function testAsyncIterator() { await ctx.kvCacheClear(); const branchManual = Branch.create(ctx, 0, { temperature: 0 }); await branchManual.prefill(prompt); - const manualTokens = []; + const manualTokens: number[] = []; for (let i = 0; i < 10; i++) { const { token, isStop } = await branchManual.produce(); if (isStop) break; @@ -1672,7 +1684,7 @@ async function testAsyncIterator() { } assert(tokens.length === manualTokens.length && - tokens.every((t, i) => t === manualTokens[i]), + tokens.every((t: number, i: number) => t === manualTokens[i]), 'iterator: output matches 
manual produce/commit (deterministic)'); await branchManual.prune(); @@ -1685,27 +1697,27 @@ async function testAsyncIterator() { // HOT-SWAP TESTS (setSamplerParams / setGrammar) // ═══════════════════════════════════════════════════════════════════════════ -async function testSetSamplerParams() { +async function testSetSamplerParams(): Promise { console.log('\n--- setSamplerParams ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, }); try { - const prompt = await ctx.tokenize("The capital of France is"); + const prompt: number[] = await ctx.tokenize("The capital of France is"); // Greedy baseline const greedy = Branch.create(ctx, 0, { temperature: 0, topK: 0, topP: 1.0, minP: 0 }); await greedy.prefill(prompt); - const greedyTok = greedy.sample(); + const greedyTok: number = greedy.sample(); assert(greedyTok >= 0, `setSamplerParams: greedy token valid (${greedyTok})`); // Switch to stochastic β€” at high temp, should eventually diverge greedy.setSamplerParams({ temperature: 1.5, seed: 42, topK: 0, topP: 1.0, minP: 0 }); - let diverged = false; + let diverged: boolean = false; for (let i = 0; i < 20; i++) { if (greedy.sample() !== greedyTok) { diverged = true; break; } } @@ -1713,8 +1725,8 @@ async function testSetSamplerParams() { // Switch back to greedy β€” should be deterministic again greedy.setSamplerParams({ temperature: 0, topK: 0, topP: 1.0, minP: 0 }); - const tok2 = greedy.sample(); - const tok3 = greedy.sample(); + const tok2: number = greedy.sample(); + const tok3: number = greedy.sample(); assert(tok2 === tok3, `setSamplerParams: greedy restored (${tok2} === ${tok3})`); await greedy.prune(); @@ -1732,10 +1744,10 @@ async function testSetSamplerParams() { } } -async function testSetGrammar() { +async function testSetGrammar(): Promise { console.log('\n--- setGrammar ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = 
await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, @@ -1743,23 +1755,23 @@ async function testSetGrammar() { }); try { - const grammar = `root ::= "{" ws "}" ws + const grammar: string = `root ::= "{" ws "}" ws ws ::= [ \\t\\n]*`; // Hot-swap: create without grammar, then add one - const prompt = await ctx.tokenize("Output: "); + const prompt: number[] = await ctx.tokenize("Output: "); const branch = Branch.create(ctx, 0, { temperature: 0 }); await branch.prefill(prompt); branch.setGrammar(grammar); - const output = []; + const output: string[] = []; for (let i = 0; i < 10; i++) { const { token, text, isStop } = await branch.produce(); if (isStop) break; await branch.commit(token); output.push(text); } - const result = output.join(''); + const result: string = output.join(''); assert(/^\{\s*\}\s*$/.test(result), `setGrammar: hot-swap constrains β†’ "${result}"`); // Remove grammar @@ -1777,14 +1789,14 @@ ws ::= [ \\t\\n]*`; root.setGrammar(grammar); const child = await root.fork(); - const childOut = []; + const childOut: string[] = []; for (let i = 0; i < 10; i++) { - const p = await child.produce(); + const p: Produced = await child.produce(); if (p.isStop) break; await child.commit(p.token); childOut.push(p.text); } - const childResult = childOut.join(''); + const childResult: string = childOut.join(''); assert(/^\{\s*\}\s*$/.test(childResult), `setGrammar: fork inherits grammar β†’ "${childResult}"`); await child.prune(); @@ -1798,10 +1810,10 @@ ws ::= [ \\t\\n]*`; // BRANCH METRICS & LOGIT BIAS // ═══════════════════════════════════════════════════════════════════════════ -async function testBranchMetrics() { +async function testBranchMetrics(): Promise { console.log('\n--- Branch Metrics & Logit Bias ---'); - const ctx = await addon.createContext({ + const ctx: SessionContext = await addon.createContext({ modelPath: MODEL_PATH, nCtx: CTX_SIZE, nThreads: 4, @@ -1809,30 +1821,30 @@ async function testBranchMetrics() { }); try { - 
const tokens = await ctx.tokenize("The capital of France is"); + const tokens: number[] = await ctx.tokenize("The capital of France is"); const branch = Branch.create(ctx, 0, { temperature: 0.8, seed: 42 }); await branch.prefill(tokens); // branch.modelEntropy - const entropy = branch.modelEntropy('nats'); + const entropy: number = branch.modelEntropy('nats'); assert(isFinite(entropy) && entropy >= 0, `branch.modelEntropy('nats') β†’ ${entropy.toFixed(4)}`); - const entropyBits = branch.modelEntropy('bits'); + const entropyBits: number = branch.modelEntropy('bits'); assert(Math.abs(entropyBits - entropy / Math.log(2)) < 0.01, `branch.modelEntropy('bits') consistent with nats`); // branch.modelSurprisal - const token = branch.sample(); - const surprisal = branch.modelSurprisal(token, 'nats'); + const token: number = branch.sample(); + const surprisal: number = branch.modelSurprisal(token, 'nats'); assert(isFinite(surprisal) && surprisal >= 0, `branch.modelSurprisal(${token}, 'nats') β†’ ${surprisal.toFixed(4)}`); - const surprisalBits = branch.modelSurprisal(token, 'bits'); + const surprisalBits: number = branch.modelSurprisal(token, 'bits'); assert(Math.abs(surprisalBits - surprisal / Math.log(2)) < 0.01, `branch.modelSurprisal bits consistent with nats`); // branch.samplingPerplexity β€” before any commits, must be Infinity - const pplBefore = branch.samplingPerplexity; + const pplBefore: number = branch.samplingPerplexity; assert(pplBefore === Infinity, `branch.samplingPerplexity before commit should be Infinity, got ${pplBefore}`); @@ -1841,27 +1853,27 @@ async function testBranchMetrics() { const { token: t2 } = await branch.produce(); await branch.commit(t2); - const pplAfter = branch.samplingPerplexity; + const pplAfter: number = branch.samplingPerplexity; assert(isFinite(pplAfter) && pplAfter >= 1.0, `branch.samplingPerplexity after commits β†’ ${pplAfter.toFixed(4)}`); // setLogitBias β€” get greedy baseline, ban it, verify it changes const baseline = 
Branch.create(ctx, 0, { temperature: 0 }); await baseline.prefill(tokens); - const bannedToken = baseline.sample(); + const bannedToken: number = baseline.sample(); await baseline.prune(); const greedy = Branch.create(ctx, 0, { temperature: 0 }); await greedy.prefill(tokens); greedy.setLogitBias([{ token: bannedToken, bias: -Infinity }]); - const alternative = greedy.sample(); + const alternative: number = greedy.sample(); assert(alternative !== bannedToken, `setLogitBias: banned token ${bannedToken} not sampled (got ${alternative})`); // clearLogitBias β€” after clearing, the greedy baseline token should come back const greedy2 = Branch.create(ctx, 0, { temperature: 0 }); await greedy2.prefill(tokens); - const greedyToken = greedy2.sample(); + const greedyToken: number = greedy2.sample(); assert(greedyToken === bannedToken, `clearLogitBias: greedy token ${greedyToken} === baseline ${bannedToken}`); @@ -1870,7 +1882,7 @@ async function testBranchMetrics() { await parent.prefill(tokens); parent.setLogitBias([{ token: bannedToken, bias: -Infinity }]); const child = await parent.fork(); - const childToken = child.sample(); + const childToken: number = child.sample(); assert(childToken !== bannedToken, `setLogitBias cloned on fork: child doesn't sample banned token`); @@ -1883,12 +1895,227 @@ async function testBranchMetrics() { } } +// ═══════════════════════════════════════════════════════════════════════════ +// RERANK TESTS (optional) +// ═══════════════════════════════════════════════════════════════════════════ + +async function testRerank(): Promise { + if (!RERANK_MODEL_PATH) { + console.log('\n--- Rerank (SKIPPED - no LLAMA_RERANK_MODEL) ---'); + return; + } + + console.log('\n--- Rerank ---'); + console.log(` Model: ${path.basename(RERANK_MODEL_PATH)}`); + + const rerank = await Rerank.create({ modelPath: RERANK_MODEL_PATH }); + + try { + // Tokenize documents + const query = 'What is the capital of France?'; + const docs = [ + 'Berlin is the capital of 
Germany and its largest city.', + 'Paris is the capital and most populous city of France.', + 'The Amazon rainforest produces about 20% of the world\'s oxygen.', + 'France is a country in Western Europe, with its capital being Paris.', + ]; + const tokenized: number[][] = await Promise.all(docs.map(d => rerank.tokenize(d))); + + // Score all documents β€” drain async iterable to final progress + let results!: { score: number; index: number }[]; + let progressCount = 0; + for await (const p of rerank.score(query, tokenized)) { + progressCount++; + results = p.results; + } + assert(progressCount > 0, `rerank: received progress updates (got ${progressCount})`); + + // All results returned (no topK) + assert(results.length === docs.length, + `rerank: returns all ${docs.length} results when no topK`); + + // Scores are valid probabilities (0-1) and sorted descending + for (let i = 0; i < results.length; i++) { + assert(results[i].score >= 0 && results[i].score <= 1, + `rerank: score[${i}] = ${results[i].score} is in [0, 1]`); + assert(Number.isInteger(results[i].index) && results[i].index >= 0 && results[i].index < docs.length, + `rerank: index[${i}] = ${results[i].index} is valid`); + if (i > 0) { + assert(results[i].score <= results[i - 1].score, + `rerank: sorted descending (${results[i - 1].score} >= ${results[i].score})`); + } + } + + // Semantic: Paris docs (index 1, 3) should rank above Amazon doc (index 2) + const topIndices = results.slice(0, 2).map(r => r.index); + assert(topIndices.includes(1) || topIndices.includes(3), + `rerank: a Paris doc in top 2 (top indices: [${topIndices}])`); + + const amazonRank = results.findIndex(r => r.index === 2); + assert(amazonRank >= 2, + `rerank: Amazon doc not in top 2 (rank: ${amazonRank})`); + ok(`rerank: semantic ordering correct (top: [${topIndices}], amazon rank: ${amazonRank})`); + + // topK parameter + let top2!: { score: number; index: number }[]; + for await (const p of rerank.score(query, tokenized, 2)) { top2 = 
p.results; } + assert(top2.length === 2, `rerank: topK=2 returns 2 results`); + assert(top2[0].score === results[0].score && top2[0].index === results[0].index, + `rerank: topK=2 matches top of full results`); + + // tokenize() produces consistent output + const tokens1 = await rerank.tokenize('hello'); + const tokens2 = await rerank.tokenize('hello'); + assert(tokens1.length === tokens2.length && tokens1.every((t, i) => t === tokens2[i]), + `rerank: tokenize() is deterministic`); + } finally { + rerank.dispose(); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// LARGE CORPUS RERANK β€” >n_seq_max documents via C++ grouping +// ═══════════════════════════════════════════════════════════════════════════ + +async function testRerankLargeCorpus(): Promise { + if (!RERANK_MODEL_PATH) { + console.log('\n--- Rerank Large Corpus (SKIPPED - no LLAMA_RERANK_MODEL) ---'); + return; + } + + console.log('\n--- Rerank Large Corpus ---'); + console.log(` Model: ${path.basename(RERANK_MODEL_PATH)}`); + + // n_seq_max=8 so 20 documents requires 3 groups (8+8+4) + const rerank = await Rerank.create({ modelPath: RERANK_MODEL_PATH, nSeqMax: 8 }); + + try { + const query = 'What is the capital of France?'; + const relevantDoc = 'Paris is the capital and most populous city of France.'; + + // Build 20 documents: 1 relevant + 19 irrelevant + const docTexts: string[] = [ + relevantDoc, + 'The Amazon rainforest produces about 20% of the world\'s oxygen.', + 'Berlin is the capital of Germany and its largest city.', + 'The Great Wall of China is over 13,000 miles long.', + 'Tokyo is the most populous metropolitan area in the world.', + 'The Sahara Desert is the largest hot desert in the world.', + 'Mount Everest is the highest mountain above sea level.', + 'The Pacific Ocean is the largest and deepest ocean.', + 'Antarctica is the coldest continent on Earth.', + 'The Nile is traditionally considered the longest river.', + 'Australia is both a 
country and a continent.', + 'The human body contains approximately 206 bones.', + 'Jupiter is the largest planet in our solar system.', + 'The speed of light is approximately 299,792 kilometers per second.', + 'DNA was first identified by Friedrich Miescher in 1869.', + 'The International Space Station orbits Earth every 90 minutes.', + 'Honey never spoils due to its low moisture content.', + 'Venice is built on more than 100 small islands.', + 'The deepest point in the ocean is the Mariana Trench.', + 'Photosynthesis converts carbon dioxide and water into glucose.', + ]; + + const tokenized: number[][] = await Promise.all(docTexts.map(d => rerank.tokenize(d))); + assert(tokenized.length === 20, `large corpus: 20 documents tokenized`); + + // Single score() call β€” drain iterable, verify incremental progress + let results!: { score: number; index: number }[]; + let progressCount = 0; + for await (const p of rerank.score(query, tokenized)) { + progressCount++; + assert(p.total === 20, `large corpus: total is 20 (got ${p.total})`); + assert(p.filled <= p.total, `large corpus: filled ${p.filled} <= total ${p.total}`); + results = p.results; + } + assert(progressCount >= 3, `large corpus: β‰₯3 progress updates for 20 docs / n_seq_max=8 (got ${progressCount})`); + + assert(results.length === 20, `large corpus: all 20 results returned`); + + // Scores sorted descending + for (let i = 1; i < results.length; i++) { + assert(results[i].score <= results[i - 1].score, + `large corpus: sorted descending at index ${i}`); + } + + // Relevant doc (index 0) should rank in top 3 + const relevantRank = results.findIndex(r => r.index === 0); + assert(relevantRank < 3, + `large corpus: relevant doc ranks ${relevantRank} (expected < 3)`); + + // topK across group boundary + let top5!: { score: number; index: number }[]; + for await (const p of rerank.score(query, tokenized, 5)) { top5 = p.results; } + assert(top5.length === 5, `large corpus: topK=5 returns 5 results`); + 
assert(top5[0].score === results[0].score && top5[0].index === results[0].index, + `large corpus: topK=5 top result matches full ranking`); + + ok(`large corpus: 20 docs with n_seq_max=8 β†’ relevant doc at rank ${relevantRank}`); + } finally { + rerank.dispose(); + } +} + +async function testRerankConcurrent(): Promise { + if (!RERANK_MODEL_PATH) { + console.log('\n--- Rerank Concurrent (SKIPPED - no LLAMA_RERANK_MODEL) ---'); + return; + } + + console.log('\n--- Rerank Concurrent ---'); + + const rerank = await Rerank.create({ modelPath: RERANK_MODEL_PATH, nSeqMax: 4 }); + + try { + const docs = [ + 'Paris is the capital of France.', + 'Machine learning is a branch of artificial intelligence.', + 'The sun is a star at the center of the solar system.', + 'Deep learning uses neural networks with many layers.', + 'London is the capital of the United Kingdom.', + 'Gradient descent is an optimization algorithm.', + ]; + const tokenized: number[][] = await Promise.all(docs.map(d => rerank.tokenize(d))); + + // Drain helper β€” collects final results from score() async iterable + async function drain(iter: AsyncIterable<{ results: { score: number; index: number }[] }>) + : Promise<{ score: number; index: number }[]> { + let last!: { score: number; index: number }[]; + for await (const p of iter) last = p.results; + return last; + } + + // Fire both score calls concurrently β€” exercises the queue's round-robin + const [r1, r2] = await Promise.all([ + drain(rerank.score('What is the capital of France?', tokenized)), + drain(rerank.score('What is machine learning?', tokenized)), + ]); + + assert(r1.length === docs.length, 'concurrent: caller 1 gets all results'); + assert(r2.length === docs.length, 'concurrent: caller 2 gets all results'); + + // Paris doc (index 0) should rank high for query 1 + assert(r1[0].index === 0 || r1[1].index === 0, + `concurrent: Paris doc in top 2 for query 1 (got [${r1.slice(0, 2).map(r => r.index)}])`); + + // ML docs (index 1 or 3 or 5) 
should rank high for query 2 + const top2q2 = r2.slice(0, 2).map(r => r.index); + assert(top2q2.includes(1) || top2q2.includes(3) || top2q2.includes(5), + `concurrent: ML doc in top 2 for query 2 (got [${top2q2}])`); + + ok(`concurrent: two callers scored ${docs.length} docs each with independent results`); + } finally { + rerank.dispose(); + } +} + // ═══════════════════════════════════════════════════════════════════════════ // MAIN // ═══════════════════════════════════════════════════════════════════════════ -async function main() { - let mainCtx = null; +async function main(): Promise { + let mainCtx: SessionContext | null = null; try { // Create main context for reusable tests @@ -1927,6 +2154,9 @@ async function main() { await testSetSamplerParams(); await testSetGrammar(); await testBranchMetrics(); + await testRerank(); + await testRerankLargeCorpus(); + await testRerankConcurrent(); await testEmbeddings(); // Summary @@ -1942,8 +2172,8 @@ async function main() { process.exit(1); } } catch (err) { - console.error('\nFatal error:', err.message); - console.error(err.stack); + console.error('\nFatal error:', (err as Error).message); + console.error((err as Error).stack); process.exit(1); } finally { if (mainCtx) mainCtx.dispose(); diff --git a/tsconfig.json b/tsconfig.json new file mode 100644 index 0000000..40c7ec9 --- /dev/null +++ b/tsconfig.json @@ -0,0 +1,18 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "CommonJS", + "moduleResolution": "Node", + "lib": ["ES2022"], + "outDir": "dist", + "rootDir": "src", + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "strict": true, + "skipLibCheck": true, + "types": ["node"] + }, + "include": ["src/**/*.ts"], + "exclude": ["node_modules", "dist", "build"] +} diff --git a/tsconfig.test.json b/tsconfig.test.json new file mode 100644 index 0000000..7fd54b7 --- /dev/null +++ b/tsconfig.test.json @@ -0,0 +1,14 @@ +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "rootDir": 
".", + "outDir": ".", + "declaration": false, + "declarationMap": false, + "sourceMap": false, + "skipLibCheck": true, + "noEmitOnError": false + }, + "include": ["test/**/*.ts"], + "exclude": ["node_modules", "dist", "build"] +} diff --git a/typedoc.json b/typedoc.json index a638334..763d6ac 100644 --- a/typedoc.json +++ b/typedoc.json @@ -1,7 +1,7 @@ { "$schema": "https://typedoc.org/schema.json", "plugin": ["typedoc-rhineai-theme"], - "entryPoints": ["lib/index.d.ts"], + "entryPoints": ["src/index.ts"], "out": "docs/api", "name": "lloyal.node API Reference", "includeVersion": true, @@ -26,6 +26,7 @@ "Sampling", "Chat", "Branching", + "Agents", "*" ], "sort": ["kind", "instance-first", "required-first", "alphabetical"],