Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/grammar/grammar.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function* tokenGenerator(ctx, grammarHandle, maxTokens = 100) {
const logits = ctx.getLogits();
ctx.applySampler(grammarHandle, logits);

const token = ctx.sample({ temperature: 0.7 });
const token = ctx.sample();
if (ctx.isStopToken(token)) return;

// Advance grammar state
Expand Down
16 changes: 16 additions & 0 deletions lib/Branch.js
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,22 @@ class Branch {
this._ctx._branchCaptureLogits(this._handle);
}

/**
* Get a copy of this branch's captured logits snapshot
*
* Returns n_vocab floats — the raw logit distribution from the last
* decode_and_capture or captureLogits() call. Use for distributional
* analysis (KL divergence, entropy, top-k overlap) without crossing
* the sampling chain.
*
* @returns {Float32Array} Copy of the logits snapshot (n_vocab elements)
* @throws {Error} If no logits have been captured yet
*/
getLogits() {
this._ensureNotDisposed();
return this._ctx._branchGetLogits(this._handle);
}

/**
* Single-token forward pass with logit snapshot
*
Expand Down
34 changes: 34 additions & 0 deletions lib/BranchStore.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/**
* BranchStore - Batched multi-branch decode operations
*
* See index.d.ts for full API documentation.
*/
class BranchStore {
constructor(ctx) {
this._ctx = ctx;
}

// entries: [branch, token][] — binding is structural, not positional
commit(entries) {
const handles = [], tokens = [];
for (const [branch, token] of entries) {
if (branch.disposed) throw new Error('BranchStore.commit: branch is disposed');
handles.push(branch.handle);
tokens.push(token);
}
this._ctx._storeCommit(handles, tokens);
}
Comment on lines +12 to +20
Copy link

Copilot AI Feb 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BranchStore.commit()/prefill() are public JS entrypoints but currently rely on destructuring and downstream native validation for most input-shape errors. If a caller passes a non-iterable, malformed tuple, or wrong token types, the thrown error may be confusing. Consider adding lightweight upfront validation (e.g., Array.isArray(entries), tuple length checks, numeric token checks, and Array.isArray(tokens) for prefill) and throwing a TypeError with a clear message.

Copilot uses AI. Check for mistakes.

// entries: [branch, tokens[]][] — binding is structural, not positional
prefill(entries) {
const handles = [], tokenArrays = [];
for (const [branch, tokens] of entries) {
if (branch.disposed) throw new Error('BranchStore.prefill: branch is disposed');
handles.push(branch.handle);
tokenArrays.push(tokens);
}
this._ctx._storePrefill(handles, tokenArrays);
}
}

module.exports = { BranchStore };
93 changes: 93 additions & 0 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,9 @@ export interface SessionContext {
/** @internal Get branch's perplexity */
_branchGetPerplexity(handle: number): number;

/** @internal Get copy of branch's logits snapshot */
_branchGetLogits(handle: number): Float32Array;

/** @internal Prune branch (remove KV cache entries and free handle) */
_branchPrune(handle: number): void;

Expand All @@ -1708,6 +1711,14 @@ export interface SessionContext {

/** @internal Clear all dynamic logit biases from a branch */
_branchClearSteer(handle: number): void;

// ===== STORE API (internal, wrapped by BranchStore) =====

/** @internal Batched accept + decode_each + capture for N branches */
_storeCommit(handles: number[], tokens: number[]): void;

/** @internal Batched decode_scatter + capture for N branches with variable token counts */
_storePrefill(handles: number[], tokenArrays: number[][]): void;
}

/**
Expand Down Expand Up @@ -1945,6 +1956,17 @@ export class Branch {
/** Freeze the current logit distribution into this branch. Essential before fork(). */
captureLogits(): void;

/**
* Get a copy of this branch's captured logits snapshot.
*
* Returns n_vocab floats — the raw logit distribution from the last
* decode_and_capture or captureLogits() call.
*
* @returns Copy of the logits snapshot (n_vocab elements)
* @throws If no logits have been captured yet
*/
getLogits(): Float32Array;

/** Decode a single token, write to KV, and capture resulting logits */
decodeAndCaptureOne(token: number): void;

Expand Down Expand Up @@ -2103,3 +2125,74 @@ export class Branch {
/** Whether this branch has been disposed */
readonly disposed: boolean;
}

/**
* Batched multi-branch decode operations
*
* Packs multiple branches into a single llama_decode() call, reducing
* GPU dispatch overhead from N dispatches to 1.
*
* Both methods take an array of **`[branch, token(s)]` tuples** — the
* branch-to-token binding is structural, not positional. Each branch
* receives exactly the token(s) paired with it.
*
* - `commit()` calls accept_token per branch (updating repeat-penalty windows)
* before the batched decode. Use for model-generated tokens.
* - `prefill()` does NOT accept — use for external/replayed tokens where
* penalty tracking is unwanted.
*
* After either call, each branch's logits snapshot is updated with the
* output distribution from its decoded token(s), ready for the next
* `produce()`/`sample()` call.
*
* @example Best-of-N with batched commit
* ```typescript
* const store = new BranchStore(ctx);
* const branches = [1, 2, 3].map(id => root.fork(id));
*
* for (let step = 0; step < 50; step++) {
* const live = branches.map(b => [b, b.produce()] as const)
* .filter(([, p]) => !p.isStop);
* if (!live.length) break;
* store.commit(live.map(([b, p]) => [b, p.token]));
* }
* ```
*
* @example Rehydrate divergent histories with batched prefill
* ```typescript
* const store = new BranchStore(ctx);
* store.prefill([[b1, historyA], [b2, historyB]]);
* ```
*/
export class BranchStore {
constructor(ctx: SessionContext);

/**
* Batched single-token commit for model-generated tokens
*
* Each tuple `[branch, token]` binds one token to one branch.
* Accepts each token into its branch's repeat-penalty window,
* then decodes all N tokens in a single llama_decode() call via decode_each.
* Logits are captured per-branch after decode.
*
* @param entries - Array of `[branch, token]` tuples (branches must not be disposed)
* @throws If any branch is disposed
*/
commit(entries: [Branch, number][]): void;

/**
* Batched variable-length prefill for external tokens
*
* Each tuple `[branch, tokens]` binds a token array to one branch.
* Each branch can receive a different number of tokens — decode_scatter
* handles variable-length runs and auto-chunks to fit nBatch.
*
* Does NOT call accept_token — use for external/replayed tokens where
* repeat-penalty tracking is unwanted. For model-generated tokens,
* use {@link commit} instead.
*
* @param entries - Array of `[branch, tokens]` tuples (branches must not be disposed)
* @throws If any branch is disposed
*/
prefill(entries: [Branch, number[]][]): void;
}
6 changes: 6 additions & 0 deletions lib/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,19 @@ function withLogits(ctx, fn) {
}

const { Branch } = require('./Branch');
const { BranchStore } = require('./BranchStore');

module.exports = {
/**
* Branch class for parallel generation
* @see Branch.create()
*/
Branch,
/**
* BranchStore class for batched multi-branch decode
* @see BranchStore
*/
BranchStore,
/**
* Create a new inference context
*
Expand Down
156 changes: 143 additions & 13 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading