diff --git a/examples/entropy/entropy.mjs b/examples/entropy/entropy.mjs index 6453e22..7618567 100644 --- a/examples/entropy/entropy.mjs +++ b/examples/entropy/entropy.mjs @@ -73,8 +73,7 @@ async function generate(ctx, prompt, strategy, strategyName, maxTokens = 50) { const entropies = []; for (let i = 0; i < maxTokens; i++) { - const branchLogits = branch.getLogits(); - const entropy = ctx.modelEntropy('nats', branchLogits); + const entropy = branch.modelEntropy('nats'); entropies.push(entropy); const temp = strategy === 'edt' ? edtTemperature(entropy) : strategy; diff --git a/examples/speculative/speculative.mjs b/examples/speculative/speculative.mjs index cf927d7..93bc111 100644 --- a/examples/speculative/speculative.mjs +++ b/examples/speculative/speculative.mjs @@ -137,7 +137,7 @@ async function main() { for (let i = 0; i < DRAFT_COUNT && output.length + drafts.length < GENERATION_LENGTH; i++) { // Get entropy BEFORE sampling (from draft branch's logits snapshot) - const entropy = ctx.modelEntropy('nats', draft.getLogits()); + const entropy = draft.modelEntropy('nats'); // produce() samples from captured logits (no KV write yet) const { token, text, isStop } = draft.produceSync(); diff --git a/examples/streaming/streaming-summary.mjs b/examples/streaming/streaming-summary.mjs index 769b86c..27221ca 100644 --- a/examples/streaming/streaming-summary.mjs +++ b/examples/streaming/streaming-summary.mjs @@ -328,8 +328,7 @@ Begin: break; } - const branchLogits = branch.getLogits(); - const surprisal = ctx.modelSurprisal(token, 'nats', branchLogits); + const surprisal = branch.modelSurprisal(token, 'nats'); nllSum += Math.max(0, surprisal); nllCount++; diff --git a/examples/streaming/streaming-tsampler.mjs b/examples/streaming/streaming-tsampler.mjs index 96cc40e..ec41ec5 100644 --- a/examples/streaming/streaming-tsampler.mjs +++ b/examples/streaming/streaming-tsampler.mjs @@ -248,8 +248,8 @@ Begin: // tokenHistory.accept(token); // Disabled - matching baseline 
ngramTracker.accept(token); - // Track surprisal from original (unmodified) logits - const surprisal = ctx.modelSurprisal(token, 'nats', originalLogits); + // Track surprisal from branch's logits snapshot (before N-gram steering) + const surprisal = branch.modelSurprisal(token, 'nats'); nllSum += Math.max(0, surprisal); nllCount++; diff --git a/examples/streaming/streaming.mjs b/examples/streaming/streaming.mjs index 96e4e20..e877e64 100644 --- a/examples/streaming/streaming.mjs +++ b/examples/streaming/streaming.mjs @@ -121,8 +121,7 @@ Begin: } // Track surprisal from the logits used by produce() - const branchLogits = branch.getLogits(); - const surprisal = ctx.modelSurprisal(token, 'nats', branchLogits); + const surprisal = branch.modelSurprisal(token, 'nats'); nllSum += Math.max(0, surprisal); nllCount++; diff --git a/lib/Branch.js b/lib/Branch.js index c3b7a7c..b7ee396 100644 --- a/lib/Branch.js +++ b/lib/Branch.js @@ -336,6 +336,59 @@ class Branch { await this._ctx._storeCommit([this._handle], [token]); } + // ===== METRICS ===== + + /** + * Compute entropy of the branch's logits distribution + * + * @param {'nats'|'bits'} [base='nats'] + * @returns {number} + */ + modelEntropy(base = 'nats') { + this._ensureNotDisposed(); + return this._ctx._branchModelEntropy(this._handle, base); + } + + /** + * Compute surprisal for a specific token from the branch's logits + * + * @param {number} token + * @param {'nats'|'bits'} [base='nats'] + * @returns {number} + */ + modelSurprisal(token, base = 'nats') { + this._ensureNotDisposed(); + return this._ctx._branchModelSurprisal(this._handle, token, base); + } + + /** + * Sampling-level perplexity (from filtered distribution) + * + * @returns {number} + */ + get samplingPerplexity() { + this._ensureNotDisposed(); + return this._ctx._branchGetSamplingPerplexity(this._handle); + } + + /** + * Set static logit biases on this branch (cloned on fork) + * + * @param {Array<{token: number, bias: number}>} biases + */ + 
setLogitBias(biases) { + this._ensureNotDisposed(); + this._ctx._branchSetLogitBias(this._handle, biases); + } + + /** + * Clear all static logit biases from this branch + */ + clearLogitBias() { + this._ensureNotDisposed(); + this._ctx._branchClearLogitBias(this._handle); + } + // ===== ACCESSORS ===== /** @returns {number} Branch's current position (number of tokens decoded) */ diff --git a/lib/index.d.ts b/lib/index.d.ts index 07721a8..a769956 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -620,10 +620,10 @@ export interface SamplingParams { * per-branch KV sequences, sampler chains, and logits snapshots with O(1) * GPU dispatches via batched decode. * - * **Logits lifetime**: `getLogits()` returns a zero-copy Float32Array wrapping - * llama.cpp's internal buffer. It is invalidated (ArrayBuffer detached) on - * the next `encode()` or `dispose()`. Use {@link withLogits} for safe scoped - * access. For branch-level logits, use {@link Branch.getLogits} instead. + * **Logits**: For branch-level logits, use {@link Branch.getLogits} which + * returns an independent copy of the branch's snapshot. For metrics, use + * {@link Branch.modelEntropy} and {@link Branch.modelSurprisal} which + * operate directly on the branch's logits without JS round-trips. * * **KV cache**: Supports multi-sequence operation (`nSeqMax > 1`), per-sequence * copy/clear/eviction, file-based persistence, and context compression via @@ -640,32 +640,6 @@ export interface SamplingParams { */ export interface SessionContext { - /** - * Get logits (zero-copy view into model memory) - * - * Returns unnormalized scores for every possible next token. - * Higher score = model thinks this token is more likely. - * The returned Float32Array wraps llama.cpp's internal buffer directly - * (zero-copy). 
- * - * LIFETIME CONSTRAINTS: - * - Valid ONLY until the next encode() or dispose() call - * - The ArrayBuffer is detached on invalidation — accessing a stale - * buffer throws a TypeError - * - DO NOT retain references across async boundaries - * - * For a safe scoped access pattern, use {@link withLogits} instead. - * - * NOTE: This is the context-level logits view. For branch-level logits, - * use {@link Branch.getLogits} which returns an independent copy of the - * branch's snapshot (safe to hold, not invalidated by decode). - * - * Cost: ~0.5ms (zero-copy pointer, no data copied) - * - * @returns Float32Array of unnormalized logits (vocabSize elements) - */ - getLogits(): Float32Array; - /** * Convert token ID to text piece * @@ -1005,56 +979,6 @@ export interface SessionContext { */ kvSeqPosMax(seqId: number): number; - // ===== METRICS API ===== - - /** - * Compute surprisal (negative log-likelihood) for a specific token. - * - * Measures how "surprising" the model finds the given token: - * - Low surprisal: Model expected this token (high probability) - * - High surprisal: Model didn't expect this token (low probability) - * - * Pass captured logits (e.g., from {@link Branch.getLogits}) for - * offline computation, or omit to use the current context logits. - * - * @param pickedTokenId - Token ID to compute surprisal for - * @param base - Logarithm base: "nats" (default) or "bits" - * @param logits - Optional Float32Array of logits (uses current context logits if omitted) - * @returns Surprisal value in specified base - * - * @example With branch logits - * ```typescript - * const { token } = await branch.produce(); - * const surprisal = ctx.modelSurprisal(token, "bits", branch.getLogits()); - * ``` - * - * COST: O(n_vocab) - softmax normalization required - */ - modelSurprisal(pickedTokenId: number, base?: 'nats' | 'bits', logits?: Float32Array): number; - - /** - * Compute entropy of the entire logits distribution. 
- * - * Measures model uncertainty: - * - Low entropy: Model is confident (peaked distribution) - * - High entropy: Model is uncertain (flat distribution) - * - * Pass captured logits (e.g., from {@link Branch.getLogits}) for - * offline analysis, or omit to use the current context logits. - * - * @param base - Logarithm base: "nats" (default) or "bits" - * @param logits - Optional Float32Array of logits (uses current context logits if omitted) - * @returns Entropy value in specified base - * - * @example With branch logits - * ```typescript - * const entropy = ctx.modelEntropy("bits", branch.getLogits()); - * ``` - * - * COST: O(n_vocab) - must sum over all token probabilities - */ - modelEntropy(base?: 'nats' | 'bits', logits?: Float32Array): number; - // ===== KV CACHE FILE PERSISTENCE ===== /** @@ -1352,7 +1276,7 @@ export interface SessionContext { /** * Model vocabulary size (number of possible tokens) * - * This is the length of the logits array from getLogits(). + * This is the length of the logits array from Branch.getLogits(). 
*/ readonly vocabSize: number; @@ -1433,6 +1357,21 @@ export interface SessionContext { /** @internal Replace or remove grammar constraint */ _branchSetGrammar(handle: number, grammarStr: string): void; + /** @internal Compute entropy from branch's logits snapshot */ + _branchModelEntropy(handle: number, base?: string): number; + + /** @internal Compute surprisal from branch's logits snapshot */ + _branchModelSurprisal(handle: number, token: number, base?: string): number; + + /** @internal Get sampling-level perplexity */ + _branchGetSamplingPerplexity(handle: number): number; + + /** @internal Set static logit biases on a branch */ + _branchSetLogitBias(handle: number, biases: Array<{ token: number; bias: number }>): void; + + /** @internal Clear all static logit biases from a branch */ + _branchClearLogitBias(handle: number): void; + // ===== STORE API (internal, wrapped by BranchStore) ===== /** @internal Batched accept + decode_each + capture for N branches */ @@ -1564,30 +1503,6 @@ export function loadBinary(variant?: GpuVariant): { createContext(options: ContextOptions): Promise; }; -/** - * Safe logits access with automatic lifetime management - * - * Ensures logits are only accessed synchronously within the callback. - * The callback MUST NOT: - * - Store the logits reference - * - Return a Promise (will throw) - * - * This prevents common bugs where logits become invalid due to - * async operations between access and usage. 
- * - * @template T Return type of the callback - * @param ctx The session context - * @param fn Synchronous callback that uses logits - must not return a Promise - * @returns The result from the callback - * @throws Error if callback returns a Promise (async usage not allowed) - * - * @category Core - */ -export function withLogits( - ctx: SessionContext, - fn: (logits: Float32Array) => T -): T; - /** * Result from Branch.produce() * @@ -1681,11 +1596,9 @@ export class Branch { * Returns n_vocab floats — the raw logit distribution from the last * prefill() or commit() call. * - * Unlike {@link SessionContext.getLogits} (zero-copy view into shared - * model memory, invalidated by next decode), this returns an independent - * copy of the branch's internal snapshot. The returned Float32Array is - * safe to hold across async boundaries and is not affected by subsequent - * decode operations. + * Returns an independent copy of the branch's internal snapshot. + * The returned Float32Array is safe to hold across async boundaries + * and is not affected by subsequent decode operations. * * @returns Independent copy of the logits snapshot (n_vocab elements) * @throws If no logits have been captured yet @@ -1842,6 +1755,70 @@ export class Branch { */ clearSteer(): void; + /** + * Compute entropy of the branch's logits distribution + * + * Measures model uncertainty from the branch's captured logits snapshot: + * - Low entropy: Model is confident (peaked distribution) + * - High entropy: Model is uncertain (flat distribution) + * + * Operates directly on `state->logits_snapshot` — no JS round-trip. 
+ * + * @param base - Logarithm base: "nats" (default) or "bits" + * @returns Entropy value in specified base + * + * COST: O(n_vocab) - must sum over all token probabilities + */ + modelEntropy(base?: 'nats' | 'bits'): number; + + /** + * Compute surprisal (negative log-likelihood) for a specific token + * + * Measures how "surprising" the model finds the given token from + * the branch's captured logits snapshot: + * - Low surprisal: Model expected this token (high probability) + * - High surprisal: Model didn't expect this token (low probability) + * + * Operates directly on `state->logits_snapshot` — no JS round-trip. + * + * @param token - Token ID to compute surprisal for + * @param base - Logarithm base: "nats" (default) or "bits" + * @returns Surprisal value in specified base + * + * COST: O(n_vocab) - softmax normalization required + */ + modelSurprisal(token: number, base?: 'nats' | 'bits'): number; + + /** + * Sampling-level perplexity (from filtered distribution) + * + * Returns perplexity from the distribution actually sampled from + * (after top-k/p/temp/penalties). Useful for policy priors and + * monitoring sampler chain impact. + * + * Compare with {@link perplexity} which is model-level (raw logits). + */ + readonly samplingPerplexity: number; + + /** + * Set static logit biases on this branch + * + * Unlike {@link steer} (which is NOT inherited on fork), logit biases + * ARE cloned when forking. Use for persistent constraints that should + * propagate to child branches. + * + * Applied during sample() in order: Grammar -> Logit Bias -> Steer -> Sampler Chain + * + * @param biases - Array of token adjustments. Use `-Infinity` to block, + * positive to boost, negative to reduce. 
+ */ + setLogitBias(biases: Array<{ token: number; bias: number }>): void; + + /** + * Clear all static logit biases from this branch + */ + clearLogitBias(): void; + /** * Replace the sampler chain with new parameters (memoized) * @@ -1854,7 +1831,7 @@ export class Branch { * * @example Entropy-Driven Temperature * ```typescript - * const entropy = ctx.modelEntropy('nats', branch.getLogits()); + * const entropy = branch.modelEntropy('nats'); * branch.setSamplerParams({ temperature: edtTemperature(entropy) }); * const { token } = await branch.produce(); * await branch.commit(token); diff --git a/lib/index.js b/lib/index.js index 928c3f7..e7719bc 100644 --- a/lib/index.js +++ b/lib/index.js @@ -6,7 +6,7 @@ * * @example * ```js - * const { createContext, withLogits } = require('@lloyal-labs/lloyal.node'); + * const { createContext } = require('@lloyal-labs/lloyal.node'); * * const ctx = await createContext({ * modelPath: './model.gguf', @@ -195,61 +195,6 @@ const getBinary = () => { return _binary; }; -/** - * Safe logits access with Runtime Borrow Checker pattern - * - * Ensures logits are only accessed synchronously within the callback. - * The callback MUST NOT: - * - Store the logits reference - * - Return a Promise (will throw) - * - Call decode() (would invalidate logits) - * - * This is a "runtime borrow checker" - it prevents async mutations - * while you're working with borrowed logits. 
- * - * @template T - * @param {SessionContext} ctx - The session context - * @param {(logits: Float32Array) => T} fn - Synchronous callback that uses logits - * @returns {T} The result from the callback - * @throws {Error} If callback returns a Promise (async usage not allowed) - * - * @example - * ```js - * // Safe: synchronous computation - * const entropy = withLogits(ctx, (logits) => { - * let sum = 0; - * for (let i = 0; i < logits.length; i++) { - * sum += Math.exp(logits[i]); - * } - * return Math.log(sum); - * }); - * - * // ERROR: callback returns Promise (will throw) - * withLogits(ctx, async (logits) => { - * await something(); // NOT ALLOWED - * return logits[0]; - * }); - * ``` - */ -function withLogits(ctx, fn) { - // Get logits (memoized - same buffer if called twice in same step) - const logits = ctx.getLogits(); - - // Execute user callback with logits - const result = fn(logits); - - // Detect async usage (not allowed - logits would be invalidated) - if (result && typeof result.then === 'function') { - throw new Error( - 'withLogits callback must be synchronous. ' + - 'Returning a Promise is not allowed because logits become invalid after decode(). ' + - 'Complete all logits processing synchronously within the callback.' - ); - } - - return result; -} - const { Branch } = require('./Branch'); const { BranchStore } = require('./BranchStore'); @@ -311,10 +256,4 @@ module.exports = { * ``` */ loadBinary, - - /** - * Safe logits access with Runtime Borrow Checker pattern. - * See function JSDoc for full documentation. 
- */ - withLogits, }; diff --git a/liblloyal b/liblloyal index b0a30f6..5fc01c9 160000 --- a/liblloyal +++ b/liblloyal @@ -1 +1 @@ -Subproject commit b0a30f6bf9ad313fcb3a4d03fb58cc3b34934f7f +Subproject commit 5fc01c9d72d670865910696cc89c51661744b97a diff --git a/src/SessionContext.cpp b/src/SessionContext.cpp index baf945b..9579426 100644 --- a/src/SessionContext.cpp +++ b/src/SessionContext.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -738,7 +737,6 @@ class JsonSchemaToGrammarWorker : public Napi::AsyncWorker { Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) { Napi::Function func = DefineClass(env, "SessionContext", { // ===== CORE ===== - InstanceMethod("getLogits", &SessionContext::getLogits), InstanceMethod("tokenToText", &SessionContext::tokenToText), InstanceMethod("isStopToken", &SessionContext::isStopToken), InstanceMethod("getEogToken", &SessionContext::getEogToken), @@ -775,10 +773,6 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) { InstanceMethod("getEmbeddingDimension", &SessionContext::getEmbeddingDimension), InstanceMethod("hasPooling", &SessionContext::hasPooling), - // ===== METRICS API ===== - InstanceMethod("modelSurprisal", &SessionContext::modelSurprisal), - InstanceMethod("modelEntropy", &SessionContext::modelEntropy), - // ===== LIFECYCLE ===== InstanceMethod("dispose", &SessionContext::dispose), @@ -802,6 +796,11 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) { InstanceMethod("_branchClearSteer", &SessionContext::_branchClearSteer), InstanceMethod("_branchSetSamplerParams", &SessionContext::_branchSetSamplerParams), InstanceMethod("_branchSetGrammar", &SessionContext::_branchSetGrammar), + InstanceMethod("_branchModelEntropy", &SessionContext::_branchModelEntropy), + InstanceMethod("_branchModelSurprisal", &SessionContext::_branchModelSurprisal), + InstanceMethod("_branchGetSamplingPerplexity", 
&SessionContext::_branchGetSamplingPerplexity), + InstanceMethod("_branchSetLogitBias", &SessionContext::_branchSetLogitBias), + InstanceMethod("_branchClearLogitBias", &SessionContext::_branchClearLogitBias), // ===== STORE API (internal, wrapped by lib/BranchStore.js) ===== InstanceMethod("_storeCommit", &SessionContext::_storeCommit), @@ -858,80 +857,6 @@ void SessionContext::initializeContext( std::cerr << " Shared refcount: " << _model.use_count() << std::endl; } -// ===== LOGITS BUFFER MANAGEMENT ===== - -void SessionContext::invalidateLogits() { - // The Kill Switch: Detach any active logits buffer - // - // This is called before any operation that invalidates the logits pointer: - // - decode() - new forward pass overwrites logits - // - encode() - embedding pass overwrites logits - // - dispose() - context is destroyed - // - // After detach, any JS code holding a reference to the buffer will get - // a TypeError when trying to access it - exactly what we want. - if (!_logitsBufferRef.IsEmpty()) { - try { - Napi::ArrayBuffer buffer = _logitsBufferRef.Value(); - if (!buffer.IsDetached()) { - buffer.Detach(); - } - } catch (...) { - // Buffer may have been garbage collected - that's fine - } - _logitsBufferRef.Reset(); - } - - // Increment step counter - any new getLogits() call will create fresh buffer - _decodeStepId++; -} - -Napi::Value SessionContext::getLogits(const Napi::CallbackInfo& info) { - Napi::Env env = info.Env(); - ensureNotDisposed(); - - if (!_context) { - throw Napi::Error::New(env, "Context not initialized"); - } - - // ===== MEMOIZATION: Return same buffer if already created for this step ===== - // - // Pattern: "Memoized Step-Scoped Views" - // If caller calls getLogits() twice in the same step, return the same buffer. - // This avoids creating multiple views into the same memory. 
- if (_logitsStepId == _decodeStepId && !_logitsBufferRef.IsEmpty()) { - // Same step, reuse existing buffer - Napi::ArrayBuffer existingBuffer = _logitsBufferRef.Value(); - const int n_vocab = lloyal::tokenizer::vocab_size(_model.get()); - return Napi::Float32Array::New(env, n_vocab, existingBuffer, 0); - } - - // ===== NEW BUFFER: Get logits via lloyal wrapper (handles null checks) ===== - // - // lloyal::logits::get() throws descriptive errors if: - // - Context is null - // - Logits unavailable (decode() not called with logits=true) - float* logits; - try { - logits = lloyal::logits::get(_context, -1); - } catch (const std::exception& e) { - throw Napi::Error::New(env, e.what()); - } - - const int n_vocab = lloyal::tokenizer::vocab_size(_model.get()); - - // Create ArrayBuffer wrapping the logits (zero-copy!) - // Store reference for memoization and future revocation - Napi::ArrayBuffer buffer = Napi::ArrayBuffer::New(env, logits, n_vocab * sizeof(float)); - - // Store weak reference for memoization - _logitsBufferRef = Napi::Reference::New(buffer, 1); - _logitsStepId = _decodeStepId; - - // Return Float32Array view - return Napi::Float32Array::New(env, n_vocab, buffer, 0); -} - Napi::Value SessionContext::tokenize(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); ensureNotDisposed(); @@ -981,97 +906,6 @@ Napi::Value SessionContext::detokenize(const Napi::CallbackInfo& info) { return worker->GetPromise(); } -// ===== METRICS API ===== - -Napi::Value SessionContext::modelSurprisal(const Napi::CallbackInfo& info) { - Napi::Env env = info.Env(); - ensureNotDisposed(); - - // Argument validation - if (info.Length() < 1 || !info[0].IsNumber()) { - throw Napi::TypeError::New(env, "Expected number (pickedTokenId)"); - } - - int32_t pickedTokenId = info[0].As().Int32Value(); - - // Optional base parameter (default: "nats") - std::string baseStr = "nats"; - if (info.Length() >= 2 && info[1].IsString()) { - baseStr = info[1].As().Utf8Value(); - } - - 
lloyal::metrics::Base base = parseBase(baseStr); - - // Get logits - either from provided Float32Array or from current context - float* logits; - int n_vocab; - - if (info.Length() >= 3 && info[2].IsTypedArray()) { - // Use provided logits (for captured/arbitrary logits) - auto arr = info[2].As(); - if (arr.TypedArrayType() != napi_float32_array) { - throw Napi::TypeError::New(env, "Expected Float32Array for logits parameter"); - } - auto float32Arr = info[2].As(); - logits = float32Arr.Data(); - n_vocab = static_cast(float32Arr.ElementLength()); - } else { - // Use current context logits (default behavior) - try { - logits = lloyal::logits::get(_context, -1); - } catch (const std::exception& e) { - throw Napi::Error::New(env, e.what()); - } - n_vocab = lloyal::tokenizer::vocab_size(_model.get()); - } - - // Compute surprisal - float surprisal = lloyal::metrics::model_surprisal(logits, n_vocab, pickedTokenId, base); - - return Napi::Number::New(env, static_cast(surprisal)); -} - -Napi::Value SessionContext::modelEntropy(const Napi::CallbackInfo& info) { - Napi::Env env = info.Env(); - ensureNotDisposed(); - - // Optional base parameter (default: "nats") - std::string baseStr = "nats"; - if (info.Length() >= 1 && info[0].IsString()) { - baseStr = info[0].As().Utf8Value(); - } - - lloyal::metrics::Base base = parseBase(baseStr); - - // Get logits - either from provided Float32Array or from current context - float* logits; - int n_vocab; - - if (info.Length() >= 2 && info[1].IsTypedArray()) { - // Use provided logits (for captured/arbitrary logits) - auto arr = info[1].As(); - if (arr.TypedArrayType() != napi_float32_array) { - throw Napi::TypeError::New(env, "Expected Float32Array for logits parameter"); - } - auto float32Arr = info[1].As(); - logits = float32Arr.Data(); - n_vocab = static_cast(float32Arr.ElementLength()); - } else { - // Use current context logits (default behavior) - try { - logits = lloyal::logits::get(_context, -1); - } catch (const 
std::exception& e) { - throw Napi::Error::New(env, e.what()); - } - n_vocab = lloyal::tokenizer::vocab_size(_model.get()); - } - - // Compute entropy using metrics.hpp - float entropy = lloyal::metrics::model_entropy(logits, n_vocab, base); - - return Napi::Number::New(env, static_cast(entropy)); -} - Napi::Value SessionContext::tokenToText(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); ensureNotDisposed(); @@ -1098,9 +932,6 @@ Napi::Value SessionContext::encode(const Napi::CallbackInfo& info) { throw Napi::TypeError::New(env, "Expected (tokens: number[])"); } - // Revoke any active logits buffer before encode - invalidateLogits(); - // Extract tokens Napi::Array jsTokens = info[0].As(); std::vector tokens; @@ -1300,9 +1131,6 @@ Napi::Value SessionContext::dispose(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); if (!_disposed) { - // Revoke any active logits buffer before disposing - invalidateLogits(); - // Drain branch store while context is still alive _branchStore.drain(); @@ -1521,11 +1349,6 @@ Napi::Value SessionContext::kvCacheRemove(const Napi::CallbackInfo& info) { throw Napi::TypeError::New(env, "Expected (sequenceId: number, start: number, end: number)"); } - // CRITICAL: Invalidate logits before KV cache modification - // Logits may reference positions that will be evicted - // (matches pattern from decode() line 801, encode() line 1035) - invalidateLogits(); - double sequenceId = info[0].As().DoubleValue(); double start = info[1].As().DoubleValue(); double end = info[2].As().DoubleValue(); @@ -2114,6 +1937,137 @@ Napi::Value SessionContext::_branchSetGrammar(const Napi::CallbackInfo& info) { return env.Undefined(); } +// ===== BRANCH METRICS & LOGIT BIAS ===== + +Napi::Value SessionContext::_branchModelEntropy(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 1) { + throw Napi::TypeError::New(env, "_branchModelEntropy requires (handle[, base])"); + } + + auto 
handle = static_cast(info[0].As().Uint32Value()); + + std::string baseStr = "nats"; + if (info.Length() >= 2 && info[1].IsString()) { + baseStr = info[1].As().Utf8Value(); + } + + auto* state = _branchStore.get(handle); + if (!state) { + throw Napi::Error::New(env, "_branchModelEntropy: invalid handle"); + } + if (!state->has_logits) { + throw Napi::Error::New(env, "_branchModelEntropy: no logits captured (call prefill or commit first)"); + } + + float entropy = lloyal::metrics::model_entropy( + state->logits_snapshot.data(), state->n_vocab, parseBase(baseStr)); + + return Napi::Number::New(env, static_cast(entropy)); +} + +Napi::Value SessionContext::_branchModelSurprisal(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 2) { + throw Napi::TypeError::New(env, "_branchModelSurprisal requires (handle, token[, base])"); + } + + auto handle = static_cast(info[0].As().Uint32Value()); + auto token = static_cast(info[1].As().Int32Value()); + + std::string baseStr = "nats"; + if (info.Length() >= 3 && info[2].IsString()) { + baseStr = info[2].As().Utf8Value(); + } + + auto* state = _branchStore.get(handle); + if (!state) { + throw Napi::Error::New(env, "_branchModelSurprisal: invalid handle"); + } + if (!state->has_logits) { + throw Napi::Error::New(env, "_branchModelSurprisal: no logits captured (call prefill or commit first)"); + } + + float surprisal = lloyal::metrics::model_surprisal( + state->logits_snapshot.data(), state->n_vocab, token, parseBase(baseStr)); + + return Napi::Number::New(env, static_cast(surprisal)); +} + +Napi::Value SessionContext::_branchGetSamplingPerplexity(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 1) { + throw Napi::TypeError::New(env, "_branchGetSamplingPerplexity requires (handle)"); + } + + auto handle = static_cast(info[0].As().Uint32Value()); + float ppl = lloyal::branch::get_sampling_perplexity(handle, 
_branchStore); + + return Napi::Number::New(env, static_cast(ppl)); +} + +Napi::Value SessionContext::_branchSetLogitBias(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 2) { + throw Napi::TypeError::New(env, "_branchSetLogitBias requires (handle, biases[])"); + } + + auto handle = static_cast(info[0].As().Uint32Value()); + + if (!info[1].IsArray()) { + throw Napi::TypeError::New(env, "_branchSetLogitBias: biases must be an array"); + } + + Napi::Array biasArray = info[1].As(); + uint32_t length = biasArray.Length(); + + std::vector biases; + biases.reserve(length); + + for (uint32_t i = 0; i < length; i++) { + Napi::Value item = biasArray[i]; + if (!item.IsObject()) { + throw Napi::Error::New(env, "_branchSetLogitBias: each bias must be {token, bias}"); + } + Napi::Object obj = item.As(); + + if (!obj.Has("token") || !obj.Has("bias")) { + throw Napi::Error::New(env, "_branchSetLogitBias: each bias must have 'token' and 'bias' properties"); + } + + llama_logit_bias bias; + bias.token = static_cast(obj.Get("token").As().Int32Value()); + bias.bias = obj.Get("bias").As().FloatValue(); + biases.push_back(bias); + } + + lloyal::branch::set_logit_bias(handle, biases.data(), biases.size(), _branchStore); + + return env.Undefined(); +} + +Napi::Value SessionContext::_branchClearLogitBias(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ensureNotDisposed(); + + if (info.Length() < 1) { + throw Napi::TypeError::New(env, "_branchClearLogitBias requires (handle)"); + } + + auto handle = static_cast(info[0].As().Uint32Value()); + lloyal::branch::clear_logit_bias(handle, _branchStore); + + return env.Undefined(); +} + // ===== STORE API ===== Napi::Value SessionContext::_storeCommit(const Napi::CallbackInfo& info) { diff --git a/src/SessionContext.hpp b/src/SessionContext.hpp index 254065e..55eb35a 100644 --- a/src/SessionContext.hpp +++ b/src/SessionContext.hpp @@ -78,13 +78,6 @@ class 
SessionContext : public Napi::ObjectWrap { private: // ===== CORE PRIMITIVES ===== - /** - * Get raw logits (zero-copy Float32Array) - * Returns: Float32Array pointing directly to llama.cpp memory - * Lifetime: Valid until next decode() call - */ - Napi::Value getLogits(const Napi::CallbackInfo& info); - /** * Tokenize text to token IDs * Args: text (string) @@ -235,23 +228,6 @@ class SessionContext : public Napi::ObjectWrap { */ Napi::Value hasPooling(const Napi::CallbackInfo& info); - // ===== METRICS API ===== - - /** - * Compute surprisal for a specific token - * Args: pickedTokenId (number), base? (string: "nats" | "bits" | "base10") - * Returns: number (surprisal in specified base) - */ - Napi::Value modelSurprisal(const Napi::CallbackInfo& info); - - /** - * Compute entropy of logits distribution - * Args: base? (string: "nats" | "bits" | "base10") - * Returns: number (entropy in specified base) - */ - Napi::Value modelEntropy(const Napi::CallbackInfo& info); - - // ===== BRANCH API (internal, wrapped by lib/Branch.ts) ===== Napi::Value _branchCreate(const Napi::CallbackInfo& info); @@ -273,6 +249,11 @@ class SessionContext : public Napi::ObjectWrap { Napi::Value _branchClearSteer(const Napi::CallbackInfo& info); Napi::Value _branchSetSamplerParams(const Napi::CallbackInfo& info); Napi::Value _branchSetGrammar(const Napi::CallbackInfo& info); + Napi::Value _branchModelEntropy(const Napi::CallbackInfo& info); + Napi::Value _branchModelSurprisal(const Napi::CallbackInfo& info); + Napi::Value _branchGetSamplingPerplexity(const Napi::CallbackInfo& info); + Napi::Value _branchSetLogitBias(const Napi::CallbackInfo& info); + Napi::Value _branchClearLogitBias(const Napi::CallbackInfo& info); // ===== STORE API (internal, wrapped by lib/BranchStore.js) ===== @@ -296,18 +277,6 @@ class SessionContext : public Napi::ObjectWrap { std::vector _turnSeparatorCache; bool _turnSeparatorCached = false; - // ===== LOGITS BUFFER MANAGEMENT (Memoization + Revocation) ===== - // 
- // Pattern: "Memoized Step-Scoped Views with Explicit Revocation" - // - // - Memoization: If getLogits() called twice in same step, return same buffer - // - Revocation: On decode(), detach previous buffer to prevent use-after-invalidation - // - // See: lloyal::logits::get() for the underlying safe wrapper - uint64_t _decodeStepId = 0; // Incremented on each decode() - uint64_t _logitsStepId = 0; // Step when _logitsBuffer was created - Napi::Reference _logitsBufferRef; // Strong reference - kept alive so we can Detach() on revocation - // ===== INLINE HELPERS ===== // Pattern matches HybridSessionContext.hpp:170-176 @@ -327,19 +296,6 @@ class SessionContext : public Napi::ObjectWrap { // Parse base string ("nats", "bits", "base10") to lloyal::metrics::Base enum static lloyal::metrics::Base parseBase(const std::string& baseStr); - - /** - * Invalidate any active logits buffer (The Kill Switch) - * - * Called before any operation that would invalidate the logits pointer: - * - decode() - * - encode() - * - dispose() - * - * Detaches the ArrayBuffer so any JS code holding a reference - * will get a TypeError when trying to access it. 
- */ - void invalidateLogits(); }; /** diff --git a/test/integration.js b/test/integration.js index ef52b37..b063058 100644 --- a/test/integration.js +++ b/test/integration.js @@ -34,7 +34,7 @@ console.log('=== lloyal.node Integration Tests ===\n'); console.log(`Model: ${path.basename(MODEL_PATH)}`); console.log(`Size: ${(fs.statSync(MODEL_PATH).size / 1024 / 1024).toFixed(1)} MB\n`); -const { loadBinary, Branch, BranchStore, withLogits } = require('..'); +const { loadBinary, Branch, BranchStore } = require('..'); let addon; try { addon = require('../build/Release/lloyal.node'); @@ -102,9 +102,9 @@ async function testCoreAPI(ctx) { } assert(hasNonZero && !hasNaN, 'branch logits valid (non-zero, no NaN)'); - // modelEntropy with branch logits - const entropy = ctx.modelEntropy('nats', branchLogits); - assert(isFinite(entropy) && entropy >= 0, `modelEntropy(branchLogits) → ${entropy.toFixed(4)} nats`); + // branch.modelEntropy + const entropy = branch.modelEntropy('nats'); + assert(isFinite(entropy) && entropy >= 0, `branch.modelEntropy() → ${entropy.toFixed(4)} nats`); // Branch greedy sampling (temperature: 0) const greedy = branch.sample(); @@ -114,24 +114,6 @@ async function testCoreAPI(ctx) { const eos = ctx.getEogToken(); assert(ctx.isStopToken(eos), `isStopToken(EOS=${eos}) → true`); - // withLogits helper (context-level logits) - // Note: getLogits() reads from the shared context buffer, which is populated - // by branch decode operations - const maxLogit = withLogits(ctx, (l) => { - let max = l[0]; - for (let i = 1; i < l.length; i++) if (l[i] > max) max = l[i]; - return max; - }); - assert(isFinite(maxLogit), `withLogits() sync → max=${maxLogit.toFixed(2)}`); - - let asyncRejected = false; - try { - withLogits(ctx, async () => 1); - } catch { - asyncRejected = true; - } - assert(asyncRejected, 'withLogits() rejects async callbacks'); - await branch.prune(); } @@ -272,13 +254,12 @@ async function testMetrics(ctx) { const branch = Branch.create(ctx, 0, { 
temperature: 0 }); await branch.prefill(tokens); - // modelSurprisal with branch logits + // branch.modelSurprisal const token1 = branch.sample(); - const branchLogits = branch.getLogits(); - const surprisal = ctx.modelSurprisal(token1, "nats", branchLogits); - assert(surprisal >= 0, `modelSurprisal(branchLogits) → ${surprisal.toFixed(2)} nats`); + const surprisal = branch.modelSurprisal(token1, "nats"); + assert(surprisal >= 0, `branch.modelSurprisal() → ${surprisal.toFixed(2)} nats`); - const surprisalBits = ctx.modelSurprisal(token1, "bits", branchLogits); + const surprisalBits = branch.modelSurprisal(token1, "bits"); assert(Math.abs(surprisalBits - surprisal / Math.log(2)) < 0.01, 'bits = nats / ln(2)'); // Branch perplexity — built-in, accumulates through commit() @@ -1103,20 +1084,18 @@ async function testBranchStore() { assert(logits.length === ctx.vocabSize, `getLogits: length=${logits.length} === vocabSize=${ctx.vocabSize}`); - // Feed branch logits into ctx.modelEntropy() — proves the returned - // buffer is a valid logits distribution consumable by metrics API - const entropyFromBranch = ctx.modelEntropy("nats", logits); + // branch.modelEntropy — proves the logits snapshot is a valid distribution + const entropyFromBranch = b1.modelEntropy("nats"); assert(isFinite(entropyFromBranch) && entropyFromBranch > 0, - `getLogits→modelEntropy: ${entropyFromBranch.toFixed(4)} nats`); + `branch.modelEntropy: ${entropyFromBranch.toFixed(4)} nats`); - // After store.commit, logits change — getLogits() reflects new state + // After store.commit, logits change — branch reflects new state const p = await b1.produce(); - assert(!p.isStop, `getLogits: produce() should not hit EOG on first token`); + assert(!p.isStop, `modelEntropy: produce() should not hit EOG on first token`); await store.commit([[b1, p.token]]); - const logitsAfter = b1.getLogits(); - const entropyAfter = ctx.modelEntropy("nats", logitsAfter); + const entropyAfter = b1.modelEntropy("nats"); 
assert(isFinite(entropyAfter), - `getLogits after commit: entropy=${entropyAfter.toFixed(4)} nats`); + `modelEntropy after commit: entropy=${entropyAfter.toFixed(4)} nats`); await b1.prune(); } @@ -1815,6 +1794,102 @@ ws ::= [ \\t\\n]*`; } } +// ═══════════════════════════════════════════════════════════════════════════ +// BRANCH METRICS & LOGIT BIAS +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Integration test for Branch metrics (modelEntropy, modelSurprisal, + * samplingPerplexity) and logit bias (setLogitBias, clearLogitBias, fork + * cloning). Creates its own context and disposes it in `finally`. + */ +async function testBranchMetrics() { + console.log('\n--- Branch Metrics & Logit Bias ---'); + + const ctx = await addon.createContext({ + modelPath: MODEL_PATH, + nCtx: CTX_SIZE, + nThreads: 4, + nSeqMax: 8, + }); + + try { + const tokens = await ctx.tokenize("The capital of France is"); + const branch = Branch.create(ctx, 0, { temperature: 0.8, seed: 42 }); + await branch.prefill(tokens); + + // branch.modelEntropy + const entropy = branch.modelEntropy('nats'); + assert(isFinite(entropy) && entropy >= 0, `branch.modelEntropy('nats') → ${entropy.toFixed(4)}`); + + const entropyBits = branch.modelEntropy('bits'); + assert(Math.abs(entropyBits - entropy / Math.log(2)) < 0.01, + `branch.modelEntropy('bits') consistent with nats`); + + // branch.modelSurprisal + const token = branch.sample(); + const surprisal = branch.modelSurprisal(token, 'nats'); + assert(isFinite(surprisal) && surprisal >= 0, + `branch.modelSurprisal(${token}, 'nats') → ${surprisal.toFixed(4)}`); + + const surprisalBits = branch.modelSurprisal(token, 'bits'); + assert(Math.abs(surprisalBits - surprisal / Math.log(2)) < 0.01, + `branch.modelSurprisal bits consistent with nats`); + + // branch.samplingPerplexity — before any commits, must be Infinity + const pplBefore = branch.samplingPerplexity; + assert(pplBefore === Infinity, + `branch.samplingPerplexity before commit should be Infinity, got ${pplBefore}`); + + // Commit a few tokens to accumulate sampling perplexity + await branch.commit(token); + const { token: t2 } = await branch.produce(); + await
branch.commit(t2); + + const pplAfter = branch.samplingPerplexity; + assert(isFinite(pplAfter) && pplAfter >= 1.0, + `branch.samplingPerplexity after commits → ${pplAfter.toFixed(4)}`); + + // setLogitBias — get greedy baseline, ban it, verify it changes + const baseline = Branch.create(ctx, 0, { temperature: 0 }); + await baseline.prefill(tokens); + const bannedToken = baseline.sample(); + await baseline.prune(); + + const greedy = Branch.create(ctx, 0, { temperature: 0 }); + await greedy.prefill(tokens); + greedy.setLogitBias([{ token: bannedToken, bias: -Infinity }]); + const alternative = greedy.sample(); + assert(alternative !== bannedToken, + `setLogitBias: banned token ${bannedToken} not sampled (got ${alternative})`); + + // clearLogitBias — ban the baseline token, then clear; the greedy baseline token should come back + const greedy2 = Branch.create(ctx, 0, { temperature: 0 }); + await greedy2.prefill(tokens); + greedy2.setLogitBias([{ token: bannedToken, bias: -Infinity }]); + greedy2.clearLogitBias(); + const greedyToken = greedy2.sample(); + assert(greedyToken === bannedToken, + `clearLogitBias: greedy token ${greedyToken} === baseline ${bannedToken}`); + + // setLogitBias cloned on fork + const parent = Branch.create(ctx, 0, { temperature: 0 }); + await parent.prefill(tokens); + parent.setLogitBias([{ token: bannedToken, bias: -Infinity }]); + const child = await parent.fork(); + const childToken = child.sample(); + assert(childToken !== bannedToken, + `setLogitBias cloned on fork: child doesn't sample banned token`); + + await branch.prune(); + await greedy.prune(); + await greedy2.prune(); + await parent.pruneSubtree(); + } finally { + ctx.dispose(); + } +} + // ═══════════════════════════════════════════════════════════════════════════ // MAIN // ═══════════════════════════════════════════════════════════════════════════ @@ -1858,6 +1933,7 @@ async function main() { await testAsyncIterator(); await testSetSamplerParams(); await testSetGrammar(); + await testBranchMetrics(); await testEmbeddings(); // Summary