Skip to content

Commit 5b04fd5

Browse files
Merge pull request #16 from lloyal-ai/feat/throughput
Adopt the new BranchStore API from liblloyal for high throughput multi-branch operations
2 parents 2359fb1 + 0fc0d35 commit 5b04fd5

10 files changed

Lines changed: 703 additions & 20 deletions

File tree

examples/grammar/grammar.mjs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function* tokenGenerator(ctx, grammarHandle, maxTokens = 100) {
4242
const logits = ctx.getLogits();
4343
ctx.applySampler(grammarHandle, logits);
4444

45-
const token = ctx.sample({ temperature: 0.7 });
45+
const token = ctx.sample();
4646
if (ctx.isStopToken(token)) return;
4747

4848
// Advance grammar state

lib/Branch.js

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,22 @@ class Branch {
111111
this._ctx._branchCaptureLogits(this._handle);
112112
}
113113

114+
/**
115+
* Get a copy of this branch's captured logits snapshot
116+
*
117+
* Returns n_vocab floats — the raw logit distribution from the last
118+
* decode_and_capture or captureLogits() call. Use for distributional
119+
* analysis (KL divergence, entropy, top-k overlap) without crossing
120+
* the sampling chain.
121+
*
122+
* @returns {Float32Array} Copy of the logits snapshot (n_vocab elements)
123+
* @throws {Error} If no logits have been captured yet
124+
*/
125+
getLogits() {
126+
this._ensureNotDisposed();
127+
return this._ctx._branchGetLogits(this._handle);
128+
}
129+
114130
/**
115131
* Single-token forward pass with logit snapshot
116132
*

lib/BranchStore.js

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/**
2+
* BranchStore - Batched multi-branch decode operations
3+
*
4+
* See index.d.ts for full API documentation.
5+
*/
6+
class BranchStore {
7+
constructor(ctx) {
8+
this._ctx = ctx;
9+
}
10+
11+
// entries: [branch, token][] — binding is structural, not positional
12+
commit(entries) {
13+
const handles = [], tokens = [];
14+
for (const [branch, token] of entries) {
15+
if (branch.disposed) throw new Error('BranchStore.commit: branch is disposed');
16+
handles.push(branch.handle);
17+
tokens.push(token);
18+
}
19+
this._ctx._storeCommit(handles, tokens);
20+
}
21+
22+
// entries: [branch, tokens[]][] — binding is structural, not positional
23+
prefill(entries) {
24+
const handles = [], tokenArrays = [];
25+
for (const [branch, tokens] of entries) {
26+
if (branch.disposed) throw new Error('BranchStore.prefill: branch is disposed');
27+
handles.push(branch.handle);
28+
tokenArrays.push(tokens);
29+
}
30+
this._ctx._storePrefill(handles, tokenArrays);
31+
}
32+
}
33+
34+
module.exports = { BranchStore };

lib/index.d.ts

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,9 @@ export interface SessionContext {
16941694
/** @internal Get branch's perplexity */
16951695
_branchGetPerplexity(handle: number): number;
16961696

1697+
/** @internal Get copy of branch's logits snapshot */
1698+
_branchGetLogits(handle: number): Float32Array;
1699+
16971700
/** @internal Prune branch (remove KV cache entries and free handle) */
16981701
_branchPrune(handle: number): void;
16991702

@@ -1708,6 +1711,14 @@ export interface SessionContext {
17081711

17091712
/** @internal Clear all dynamic logit biases from a branch */
17101713
_branchClearSteer(handle: number): void;
1714+
1715+
// ===== STORE API (internal, wrapped by BranchStore) =====
1716+
1717+
/** @internal Batched accept + decode_each + capture for N branches */
1718+
_storeCommit(handles: number[], tokens: number[]): void;
1719+
1720+
/** @internal Batched decode_scatter + capture for N branches with variable token counts */
1721+
_storePrefill(handles: number[], tokenArrays: number[][]): void;
17111722
}
17121723

17131724
/**
@@ -1945,6 +1956,17 @@ export class Branch {
19451956
/** Freeze the current logit distribution into this branch. Essential before fork(). */
19461957
captureLogits(): void;
19471958

1959+
/**
1960+
* Get a copy of this branch's captured logits snapshot.
1961+
*
1962+
* Returns n_vocab floats — the raw logit distribution from the last
1963+
* decode_and_capture or captureLogits() call.
1964+
*
1965+
* @returns Copy of the logits snapshot (n_vocab elements)
1966+
* @throws If no logits have been captured yet
1967+
*/
1968+
getLogits(): Float32Array;
1969+
19481970
/** Decode a single token, write to KV, and capture resulting logits */
19491971
decodeAndCaptureOne(token: number): void;
19501972

@@ -2103,3 +2125,74 @@ export class Branch {
21032125
/** Whether this branch has been disposed */
21042126
readonly disposed: boolean;
21052127
}
2128+
2129+
/**
2130+
* Batched multi-branch decode operations
2131+
*
2132+
* Packs multiple branches into a single llama_decode() call, reducing
2133+
* GPU dispatch overhead from N dispatches to 1.
2134+
*
2135+
* Both methods take an array of **`[branch, token(s)]` tuples** — the
2136+
* branch-to-token binding is structural, not positional. Each branch
2137+
* receives exactly the token(s) paired with it.
2138+
*
2139+
* - `commit()` calls accept_token per branch (updating repeat-penalty windows)
2140+
* before the batched decode. Use for model-generated tokens.
2141+
* - `prefill()` does NOT accept — use for external/replayed tokens where
2142+
* penalty tracking is unwanted.
2143+
*
2144+
* After either call, each branch's logits snapshot is updated with the
2145+
* output distribution from its decoded token(s), ready for the next
2146+
* `produce()`/`sample()` call.
2147+
*
2148+
* @example Best-of-N with batched commit
2149+
* ```typescript
2150+
* const store = new BranchStore(ctx);
2151+
* const branches = [1, 2, 3].map(id => root.fork(id));
2152+
*
2153+
* for (let step = 0; step < 50; step++) {
2154+
* const live = branches.map(b => [b, b.produce()] as const)
2155+
* .filter(([, p]) => !p.isStop);
2156+
* if (!live.length) break;
2157+
* store.commit(live.map(([b, p]) => [b, p.token]));
2158+
* }
2159+
* ```
2160+
*
2161+
* @example Rehydrate divergent histories with batched prefill
2162+
* ```typescript
2163+
* const store = new BranchStore(ctx);
2164+
* store.prefill([[b1, historyA], [b2, historyB]]);
2165+
* ```
2166+
*/
2167+
export class BranchStore {
2168+
constructor(ctx: SessionContext);
2169+
2170+
/**
2171+
* Batched single-token commit for model-generated tokens
2172+
*
2173+
* Each tuple `[branch, token]` binds one token to one branch.
2174+
* Accepts each token into its branch's repeat-penalty window,
2175+
* then decodes all N tokens in a single llama_decode() call via decode_each.
2176+
* Logits are captured per-branch after decode.
2177+
*
2178+
* @param entries - Array of `[branch, token]` tuples (branches must not be disposed)
2179+
* @throws If any branch is disposed
2180+
*/
2181+
commit(entries: [Branch, number][]): void;
2182+
2183+
/**
2184+
* Batched variable-length prefill for external tokens
2185+
*
2186+
* Each tuple `[branch, tokens]` binds a token array to one branch.
2187+
* Each branch can receive a different number of tokens — decode_scatter
2188+
* handles variable-length runs and auto-chunks to fit nBatch.
2189+
*
2190+
* Does NOT call accept_token — use for external/replayed tokens where
2191+
* repeat-penalty tracking is unwanted. For model-generated tokens,
2192+
* use {@link commit} instead.
2193+
*
2194+
* @param entries - Array of `[branch, tokens]` tuples (branches must not be disposed)
2195+
* @throws If any branch is disposed
2196+
*/
2197+
prefill(entries: [Branch, number[]][]): void;
2198+
}

lib/index.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,19 @@ function withLogits(ctx, fn) {
256256
}
257257

258258
const { Branch } = require('./Branch');
259+
const { BranchStore } = require('./BranchStore');
259260

260261
module.exports = {
261262
/**
262263
* Branch class for parallel generation
263264
* @see Branch.create()
264265
*/
265266
Branch,
267+
/**
268+
* BranchStore class for batched multi-branch decode
269+
* @see BranchStore
270+
*/
271+
BranchStore,
266272
/**
267273
* Create a new inference context
268274
*

package-lock.json

Lines changed: 143 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)