examples/entropy/entropy.mjs (3 changes: 1 addition & 2 deletions)

@@ -73,8 +73,7 @@ async function generate(ctx, prompt, strategy, strategyName, maxTokens = 50) {
   const entropies = [];
 
   for (let i = 0; i < maxTokens; i++) {
-    const branchLogits = branch.getLogits();
-    const entropy = ctx.modelEntropy('nats', branchLogits);
+    const entropy = branch.modelEntropy('nats');
     entropies.push(entropy);
 
     const temp = strategy === 'edt' ? edtTemperature(entropy) : strategy;
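The `edtTemperature` helper the hunk calls is defined elsewhere in entropy.mjs and is not shown in this diff. A minimal sketch of the entropy-driven temperature idea, with illustrative constants that may differ from the real implementation:

```js
// Hypothetical sketch, not the code from entropy.mjs: map entropy (nats)
// to a sampling temperature, so a confident model samples cooler and an
// uncertain one samples hotter.
function edtTemperature(entropy, base = 0.6, scale = 0.4, maxEntropy = 5.0) {
  const h = Math.min(Math.max(entropy / maxEntropy, 0), 1); // normalize to [0, 1]
  return base + scale * h;
}
```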
examples/speculative/speculative.mjs (2 changes: 1 addition & 1 deletion)

@@ -137,7 +137,7 @@ async function main() {
 
   for (let i = 0; i < DRAFT_COUNT && output.length + drafts.length < GENERATION_LENGTH; i++) {
     // Get entropy BEFORE sampling (from draft branch's logits snapshot)
-    const entropy = ctx.modelEntropy('nats', draft.getLogits());
+    const entropy = draft.modelEntropy('nats');
 
     // produce() samples from captured logits (no KV write yet)
     const { token, text, isStop } = draft.produceSync();
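Reading entropy before sampling is what makes entropy-gated speculation possible. A hedged sketch of one way the value could be used inside this loop (the threshold and break policy are invented for illustration, not taken from this PR):

```js
// Hypothetical gate: stop drafting once the draft model becomes uncertain,
// since high-entropy draft tokens are more likely to be rejected later.
const ENTROPY_GATE = 3.0; // nats; illustrative threshold
if (draft.modelEntropy('nats') > ENTROPY_GATE) {
  break; // hand the remaining tokens back to the target model
}
```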
examples/streaming/streaming-summary.mjs (3 changes: 1 addition & 2 deletions)

@@ -328,8 +328,7 @@ Begin:
       break;
     }
 
-    const branchLogits = branch.getLogits();
-    const surprisal = ctx.modelSurprisal(token, 'nats', branchLogits);
+    const surprisal = branch.modelSurprisal(token, 'nats');
     nllSum += Math.max(0, surprisal);
     nllCount++;
 
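The `nllSum`/`nllCount` pair maintained here (and in the other streaming examples below) is the standard running-perplexity accumulator. Since the surprisal is in nats, the estimate at any point in the stream is:

```js
// Model-level perplexity over the tokens generated so far (surprisal in nats).
const perplexity = Math.exp(nllSum / Math.max(1, nllCount));
```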
examples/streaming/streaming-tsampler.mjs (4 changes: 2 additions & 2 deletions)

@@ -248,8 +248,8 @@ Begin:
     // tokenHistory.accept(token); // Disabled - matching baseline
     ngramTracker.accept(token);
 
-    // Track surprisal from original (unmodified) logits
-    const surprisal = ctx.modelSurprisal(token, 'nats', originalLogits);
+    // Track surprisal from branch's logits snapshot (before N-gram steering)
+    const surprisal = branch.modelSurprisal(token, 'nats');
     nllSum += Math.max(0, surprisal);
     nllCount++;
 

Review comment from Copilot AI (Feb 21, 2026) on the added comment line: The comment states 'before N-gram steering' but the logits snapshot is captured after prefill/commit, which occurs after N-gram steering has been applied. Consider clarifying whether this comment accurately reflects when the snapshot is taken relative to steering operations.

Suggested change:
-    // Track surprisal from branch's logits snapshot (before N-gram steering)
+    // Track surprisal from branch's logits snapshot (after N-gram steering has been applied)
examples/streaming/streaming.mjs (3 changes: 1 addition & 2 deletions)

@@ -121,8 +121,7 @@ Begin:
     }
 
     // Track surprisal from the logits used by produce()
-    const branchLogits = branch.getLogits();
-    const surprisal = ctx.modelSurprisal(token, 'nats', branchLogits);
+    const surprisal = branch.modelSurprisal(token, 'nats');
     nllSum += Math.max(0, surprisal);
     nllCount++;
 
lib/Branch.js (53 changes: 53 additions & 0 deletions)

@@ -336,6 +336,59 @@ class Branch {
     await this._ctx._storeCommit([this._handle], [token]);
   }
 
+  // ===== METRICS =====
+
+  /**
+   * Compute entropy of the branch's logits distribution
+   *
+   * @param {'nats'|'bits'} [base='nats']
+   * @returns {number}
+   */
+  modelEntropy(base = 'nats') {
+    this._ensureNotDisposed();
+    return this._ctx._branchModelEntropy(this._handle, base);
+  }
+
+  /**
+   * Compute surprisal for a specific token from the branch's logits
+   *
+   * @param {number} token
+   * @param {'nats'|'bits'} [base='nats']
+   * @returns {number}
+   */
+  modelSurprisal(token, base = 'nats') {
+    this._ensureNotDisposed();
+    return this._ctx._branchModelSurprisal(this._handle, token, base);
+  }
+
+  /**
+   * Sampling-level perplexity (from filtered distribution)
+   *
+   * @returns {number}
+   */
+  get samplingPerplexity() {
+    this._ensureNotDisposed();
+    return this._ctx._branchGetSamplingPerplexity(this._handle);
+  }
+
+  /**
+   * Set static logit biases on this branch (cloned on fork)
+   *
+   * @param {Array<{token: number, bias: number}>} biases
+   */
+  setLogitBias(biases) {
+    this._ensureNotDisposed();
+    this._ctx._branchSetLogitBias(this._handle, biases);
+  }
+
+  /**
+   * Clear all static logit biases from this branch
+   */
+  clearLogitBias() {
+    this._ensureNotDisposed();
+    this._ctx._branchClearLogitBias(this._handle);
+  }
+
   // ===== ACCESSORS =====
 
   /** @returns {number} Branch's current position (number of tokens decoded) */
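A usage sketch for the methods added above. The metric calls, `produce()`, and `samplingPerplexity` come from this PR's API; the branch construction (`ctx.createBranch()`) and `promptTokens` are assumptions standing in for the library's actual session setup:

```js
// Hypothetical setup: the factory and prefill names are assumed here.
const branch = ctx.createBranch();
await branch.prefill(promptTokens);

const entropy = branch.modelEntropy('bits');            // uncertainty before sampling
const { token } = await branch.produce();               // sample from the snapshot
const surprisal = branch.modelSurprisal(token, 'bits'); // how unexpected that token was
console.log({ entropy, surprisal, ppl: branch.samplingPerplexity });
```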
lib/index.d.ts (199 changes: 88 additions & 111 deletions)
@@ -620,10 +620,10 @@ export interface SamplingParams {
  * per-branch KV sequences, sampler chains, and logits snapshots with O(1)
  * GPU dispatches via batched decode.
  *
- * **Logits lifetime**: `getLogits()` returns a zero-copy Float32Array wrapping
- * llama.cpp's internal buffer. It is invalidated (ArrayBuffer detached) on
- * the next `encode()` or `dispose()`. Use {@link withLogits} for safe scoped
- * access. For branch-level logits, use {@link Branch.getLogits} instead.
+ * **Logits**: For branch-level logits, use {@link Branch.getLogits} which
+ * returns an independent copy of the branch's snapshot. For metrics, use
+ * {@link Branch.modelEntropy} and {@link Branch.modelSurprisal} which
+ * operate directly on the branch's logits without JS round-trips.
  *
  * **KV cache**: Supports multi-sequence operation (`nSeqMax > 1`), per-sequence
  * copy/clear/eviction, file-based persistence, and context compression via
@@ -640,32 +640,6 @@ export interface SamplingParams {
  */
 export interface SessionContext {
 
-  /**
-   * Get logits (zero-copy view into model memory)
-   *
-   * Returns unnormalized scores for every possible next token.
-   * Higher score = model thinks this token is more likely.
-   * The returned Float32Array wraps llama.cpp's internal buffer directly
-   * (zero-copy).
-   *
-   * LIFETIME CONSTRAINTS:
-   * - Valid ONLY until the next encode() or dispose() call
-   * - The ArrayBuffer is detached on invalidation — accessing a stale
-   *   buffer throws a TypeError
-   * - DO NOT retain references across async boundaries
-   *
-   * For a safe scoped access pattern, use {@link withLogits} instead.
-   *
-   * NOTE: This is the context-level logits view. For branch-level logits,
-   * use {@link Branch.getLogits} which returns an independent copy of the
-   * branch's snapshot (safe to hold, not invalidated by decode).
-   *
-   * Cost: ~0.5ms (zero-copy pointer, no data copied)
-   *
-   * @returns Float32Array of unnormalized logits (vocabSize elements)
-   */
-  getLogits(): Float32Array;
-
   /**
    * Convert token ID to text piece
    *
Expand Down Expand Up @@ -1005,56 +979,6 @@ export interface SessionContext {
*/
kvSeqPosMax(seqId: number): number;

// ===== METRICS API =====

/**
* Compute surprisal (negative log-likelihood) for a specific token.
*
* Measures how "surprising" the model finds the given token:
* - Low surprisal: Model expected this token (high probability)
* - High surprisal: Model didn't expect this token (low probability)
*
* Pass captured logits (e.g., from {@link Branch.getLogits}) for
* offline computation, or omit to use the current context logits.
*
* @param pickedTokenId - Token ID to compute surprisal for
* @param base - Logarithm base: "nats" (default) or "bits"
* @param logits - Optional Float32Array of logits (uses current context logits if omitted)
* @returns Surprisal value in specified base
*
* @example With branch logits
* ```typescript
* const { token } = await branch.produce();
* const surprisal = ctx.modelSurprisal(token, "bits", branch.getLogits());
* ```
*
* COST: O(n_vocab) - softmax normalization required
*/
modelSurprisal(pickedTokenId: number, base?: 'nats' | 'bits', logits?: Float32Array): number;

/**
* Compute entropy of the entire logits distribution.
*
* Measures model uncertainty:
* - Low entropy: Model is confident (peaked distribution)
* - High entropy: Model is uncertain (flat distribution)
*
* Pass captured logits (e.g., from {@link Branch.getLogits}) for
* offline analysis, or omit to use the current context logits.
*
* @param base - Logarithm base: "nats" (default) or "bits"
* @param logits - Optional Float32Array of logits (uses current context logits if omitted)
* @returns Entropy value in specified base
*
* @example With branch logits
* ```typescript
* const entropy = ctx.modelEntropy("bits", branch.getLogits());
* ```
*
* COST: O(n_vocab) - must sum over all token probabilities
*/
modelEntropy(base?: 'nats' | 'bits', logits?: Float32Array): number;

// ===== KV CACHE FILE PERSISTENCE =====

/**
Expand Down Expand Up @@ -1352,7 +1276,7 @@ export interface SessionContext {
/**
* Model vocabulary size (number of possible tokens)
*
* This is the length of the logits array from getLogits().
* This is the length of the logits array from Branch.getLogits().
*/
readonly vocabSize: number;

Expand Down Expand Up @@ -1433,6 +1357,21 @@ export interface SessionContext {
/** @internal Replace or remove grammar constraint */
_branchSetGrammar(handle: number, grammarStr: string): void;

/** @internal Compute entropy from branch's logits snapshot */
_branchModelEntropy(handle: number, base?: string): number;

/** @internal Compute surprisal from branch's logits snapshot */
_branchModelSurprisal(handle: number, token: number, base?: string): number;

/** @internal Get sampling-level perplexity */
_branchGetSamplingPerplexity(handle: number): number;

/** @internal Set static logit biases on a branch */
_branchSetLogitBias(handle: number, biases: Array<{ token: number; bias: number }>): void;

/** @internal Clear all static logit biases from a branch */
_branchClearLogitBias(handle: number): void;

// ===== STORE API (internal, wrapped by BranchStore) =====

/** @internal Batched accept + decode_each + capture for N branches */
Expand Down Expand Up @@ -1564,30 +1503,6 @@ export function loadBinary(variant?: GpuVariant): {
createContext(options: ContextOptions): Promise<SessionContext>;
};

/**
* Safe logits access with automatic lifetime management
*
* Ensures logits are only accessed synchronously within the callback.
* The callback MUST NOT:
* - Store the logits reference
* - Return a Promise (will throw)
*
* This prevents common bugs where logits become invalid due to
* async operations between access and usage.
*
* @template T Return type of the callback
* @param ctx The session context
* @param fn Synchronous callback that uses logits - must not return a Promise
* @returns The result from the callback
* @throws Error if callback returns a Promise (async usage not allowed)
*
* @category Core
*/
export function withLogits<T>(
ctx: SessionContext,
fn: (logits: Float32Array) => T
): T;

/**
* Result from Branch.produce()
*
Expand Down Expand Up @@ -1681,11 +1596,9 @@ export class Branch {
* Returns n_vocab floats — the raw logit distribution from the last
* prefill() or commit() call.
*
* Unlike {@link SessionContext.getLogits} (zero-copy view into shared
* model memory, invalidated by next decode), this returns an independent
* copy of the branch's internal snapshot. The returned Float32Array is
* safe to hold across async boundaries and is not affected by subsequent
* decode operations.
* Returns an independent copy of the branch's internal snapshot.
* The returned Float32Array is safe to hold across async boundaries
* and is not affected by subsequent decode operations.
*
* @returns Independent copy of the logits snapshot (n_vocab elements)
* @throws If no logits have been captured yet
Expand Down Expand Up @@ -1842,6 +1755,70 @@ export class Branch {
*/
clearSteer(): void;

/**
* Compute entropy of the branch's logits distribution
*
* Measures model uncertainty from the branch's captured logits snapshot:
* - Low entropy: Model is confident (peaked distribution)
* - High entropy: Model is uncertain (flat distribution)
*
* Operates directly on `state->logits_snapshot` — no JS round-trip.
*
* @param base - Logarithm base: "nats" (default) or "bits"
* @returns Entropy value in specified base
*
* COST: O(n_vocab) - must sum over all token probabilities
*/
modelEntropy(base?: 'nats' | 'bits'): number;

/**
* Compute surprisal (negative log-likelihood) for a specific token
*
* Measures how "surprising" the model finds the given token from
* the branch's captured logits snapshot:
* - Low surprisal: Model expected this token (high probability)
* - High surprisal: Model didn't expect this token (low probability)
*
* Operates directly on `state->logits_snapshot` — no JS round-trip.
*
* @param token - Token ID to compute surprisal for
* @param base - Logarithm base: "nats" (default) or "bits"
* @returns Surprisal value in specified base
*
* COST: O(n_vocab) - softmax normalization required
*/
modelSurprisal(token: number, base?: 'nats' | 'bits'): number;

/**
* Sampling-level perplexity (from filtered distribution)
*
* Returns perplexity from the distribution actually sampled from
* (after top-k/p/temp/penalties). Useful for policy priors and
* monitoring sampler chain impact.
*
* Compare with {@link perplexity} which is model-level (raw logits).
*/
readonly samplingPerplexity: number;

/**
* Set static logit biases on this branch
*
* Unlike {@link steer} (which is NOT inherited on fork), logit biases
* ARE cloned when forking. Use for persistent constraints that should
* propagate to child branches.
*
* Applied during sample() in order: Grammar -> Logit Bias -> Steer -> Sampler Chain
*
* @param biases - Array of token adjustments. Use `-Infinity` to block,
* positive to boost, negative to reduce.
*/
setLogitBias(biases: Array<{ token: number; bias: number }>): void;

/**
* Clear all static logit biases from this branch
*/
clearLogitBias(): void;

/**
* Replace the sampler chain with new parameters (memoized)
*
Expand All @@ -1854,7 +1831,7 @@ export class Branch {
*
* @example Entropy-Driven Temperature
* ```typescript
* const entropy = ctx.modelEntropy('nats', branch.getLogits());
* const entropy = branch.modelEntropy('nats');
* branch.setSamplerParams({ temperature: edtTemperature(entropy) });
* const { token } = await branch.produce();
* await branch.commit(token);
Expand Down
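The fork-inheritance contract documented for `setLogitBias` above, as a short sketch. The bias semantics and the `setLogitBias`/`clearLogitBias` names come from this diff; `fork()` and the token-ID constants are assumptions for illustration:

```js
// Hypothetical usage: token IDs below are placeholders for real vocab entries.
parent.setLogitBias([
  { token: EOS_TOKEN_ID, bias: -Infinity }, // hard-block early stopping
  { token: BULLET_TOKEN_ID, bias: 1.5 },    // softly encourage list output
]);

const child = parent.fork(); // logit biases ARE cloned into the child
child.clearLogitBias();      // and can be dropped there independently;
                             // steer() adjustments would NOT have been inherited
```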