Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions lib/Branch.js
Original file line number Diff line number Diff line change
Expand Up @@ -304,19 +304,20 @@ class Branch {
}

/**
* Decode and advance: write token to KV and update branch state
* Accept and decode: update branch state, then write token to KV
*
* Decodes the token (writing to KV cache), captures the resulting logits
* for the next produce() call, then accepts into the sampler penalty window.
* Decode-first ordering: if decode throws, sampler state stays consistent.
* Accepts the token into the sampler penalty window (for correct PPL
* measurement), then decodes (writing to KV cache) and captures the
* resulting logits for the next produce() call. Accept-first ordering
* with rollback: if decode throws, sampler/grammar/metrics are restored
* from clones taken before the accept.
*
* @param {number} token - Token to commit (from produce())
* @returns {Promise<void>}
*/
async commit(token) {
this._ensureNotDisposed();
await this.decodeAndCaptureOne(token);
this.accept(token);
await this._ctx._storeCommit([this._handle], [token]);
}

// ===== ACCESSORS =====
Expand Down
20 changes: 11 additions & 9 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2347,12 +2347,13 @@ export class Branch {
produce(): Produced;

/**
* Decode and advance: write token to KV and update branch state
* Accept and decode: update branch state, then write token to KV
*
* Decodes the token (writing to KV cache via AsyncWorker on the libuv
* thread pool), captures the resulting logits for the next produce() call,
* then accepts into the sampler penalty window. Decode-first ordering:
* if decode throws, sampler state stays consistent.
* Accepts the token into the sampler penalty window (for correct PPL
* measurement), then decodes (writing to KV cache via AsyncWorker on
* the libuv thread pool) and captures the resulting logits for the next
* produce() call. Accept-first ordering with rollback: if decode throws,
* sampler/grammar/metrics are restored from clones.
*
* @param token Token to commit (from produce())
*/
Expand Down Expand Up @@ -2495,10 +2496,11 @@ export class BranchStore {
* Batched single-token commit for model-generated tokens
*
* Each tuple `[branch, token]` binds one token to one branch.
* Decodes all N tokens in a single llama_decode() call via decode_each,
* captures logits per-branch, then accepts each token into its branch's
* repeat-penalty window. Decode-first ordering ensures sampler state
* stays consistent if decode throws.
* Accepts each token into its branch's repeat-penalty window (for correct
* PPL measurement), then decodes all N tokens in a single llama_decode()
* call via decode_each and captures logits per-branch. Accept-first
* ordering with rollback: if decode throws, sampler/grammar/metrics are
* restored from clones taken before the accept.
*
* @param entries - Array of `[branch, token]` tuples (branches must not be disposed)
* @throws If any branch is disposed
Expand Down
64 changes: 59 additions & 5 deletions src/SessionContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,9 @@ class BranchDecodeAndCaptureBatchWorker : public Napi::AsyncWorker {
};

/**
* AsyncWorker for batched multi-branch commit (decode_each + accept)
* Decode-first ordering: if decode_each throws, accept_token never runs
* AsyncWorker for batched multi-branch commit (accept + decode_each)
* Accept-first ordering with rollback: accepts tokens for correct PPL measurement,
* then decodes. On decode failure, restores sampler/grammar/metrics from clones.
*/
class StoreCommitWorker : public Napi::AsyncWorker {
public:
Expand All @@ -660,12 +661,65 @@ class StoreCommitWorker : public Napi::AsyncWorker {
: AsyncWorker(env), _deferred(env), _store(store), _items(std::move(items)) {}

void Execute() override {
  // RAII snapshot of accept-mutable state. Destructor frees anything still
  // owned, so partial clones from a throwing OOM don't leak.
  struct Snapshot {
    llama_sampler* sampler = nullptr;
    llama_sampler* grammar = nullptr;
    lloyal::metrics::BranchMetricsHandle metrics = 0;

    ~Snapshot() {
      if (sampler) lloyal::sampler::free_chain(sampler);
      if (grammar) lloyal::grammar::free_sampler(grammar);
      if (metrics) lloyal::metrics::free_branch_metrics(metrics);
    }

    // Swap our clones back into the branch; the (possibly mutated)
    // post-accept state we receive in exchange is freed by ~Snapshot.
    void restore_into(lloyal::branch::BranchState& st) {
      std::swap(sampler, st.sampler_chain);
      std::swap(grammar, st.grammar);
      std::swap(metrics, st.metrics);
    }
  };

  // Pre-size with unique_ptr so the vector never needs to move elements
  std::vector<std::unique_ptr<Snapshot>> snaps(_items.size());

  try {
    // Phase 1: snapshot all accept-mutable state (no mutations yet)
    for (size_t i = 0; i < _items.size(); ++i) {
      auto* st = _store.get(_items[i].handle);
      if (!st) throw std::runtime_error("StoreCommitWorker: invalid handle");

      auto s = std::make_unique<Snapshot>();
      s->sampler = st->sampler_chain
          ? lloyal::sampler::clone_chain(st->sampler_chain) : nullptr;
      s->grammar = st->grammar
          ? lloyal::grammar::clone_sampler(st->grammar) : nullptr;
      s->metrics = st->metrics != 0
          ? lloyal::metrics::clone_branch_metrics(st->metrics) : 0;
      snaps[i] = std::move(s);
    }

    // Phase 2: accept all tokens (in-memory state mutation, won't throw)
    for (auto& item : _items)
      lloyal::branch::accept_token(item.handle, item.token, _store);

    // Phase 3: decode (single GPU batch — the only realistic failure point)
    _store.decode_each(_items);

    // Success — discard snapshots
    snaps.clear();

  } catch (const std::exception& e) {
    // Restore all branches — un-mutated branches get a harmless equivalent swap
    for (size_t i = 0; i < _items.size(); ++i) {
      auto* st = _store.get(_items[i].handle);
      if (st && snaps[i]) snaps[i]->restore_into(*st);
    }
    // ~Snapshot frees the swapped-out (post-accept) state
    SetError(e.what());
  }
}

// Success path: resolve the JS promise with undefined (commit has no result value).
void OnOK() override { _deferred.Resolve(Env().Undefined()); }
Expand Down
2 changes: 1 addition & 1 deletion test/examples.js
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ const EXAMPLES = {
assert(candidates.length === 5, 'should have 5 candidates');

for (const c of candidates) {
assert(c.ppl >= 1, 'candidate ppl should be >= 1');
assert(c.ppl > 1 && c.ppl < 1000, `candidate ppl should be in (1, 1000), got ${c.ppl}`);
assert(c.tokenCount > 0, 'candidate should have tokens');
}

Expand Down
125 changes: 125 additions & 0 deletions test/integration.js
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,129 @@ async function testBranchStore() {
}
}

// ═══════════════════════════════════════════════════════════════════════════
// PPL SANITY — commit() must produce sane perplexity (not millions)
// ═══════════════════════════════════════════════════════════════════════════

async function testPplSanity() {
  console.log('\n--- PPL Sanity ---');

  const ctx = await addon.createContext({
    modelPath: MODEL_PATH,
    nCtx: CTX_SIZE,
    nThreads: 4
  });

  try {
    // Prefill the KV cache with a formatted chat prompt at position 0.
    const chat = [{ role: 'user', content: 'Tell me about the weather.' }];
    const { prompt } = await ctx.formatChat(JSON.stringify(chat));
    const toks = await ctx.tokenize(prompt);
    await ctx.decode(toks, 0, 0);

    // Greedy branch (temperature 0) so the run is deterministic.
    const b = Branch.create(ctx, toks.length, { temperature: 0 });
    b.captureLogits();

    // Commit up to 10 model-chosen tokens, stopping early on EOS.
    let step = 0;
    while (step < 10) {
      const produced = b.produce();
      if (produced.isStop) break;
      await b.commit(produced.token);
      step++;
    }

    // A sane perplexity sits in a small range; millions would indicate
    // the accept/decode ordering corrupted the measurement.
    const ppl = b.perplexity;
    console.log(` perplexity after 10 commits: ${ppl.toFixed(2)}`);
    assert(isFinite(ppl) && ppl >= 1.0 && ppl < 1000,
      `PPL sanity: ${ppl.toFixed(2)} is in [1, 1000)`);

    await b.prune();
  } finally {
    ctx.dispose();
  }
}

// ═══════════════════════════════════════════════════════════════════════════
// COMMIT ROLLBACK — decode failure must restore sampler/grammar/metrics
// ═══════════════════════════════════════════════════════════════════════════

async function testCommitRollback() {
  console.log('\n--- Commit Rollback ---');

  // Tiny KV (nCtx=32) shared by many branches (nSeqMax=8): each commit burns
  // one KV cell per branch, so the cell budget runs out after a few rounds.
  // decode_each then fails (find_slot returns non-zero), StoreCommitWorker
  // throws, and the rollback path must fire.
  const ctx = await addon.createContext({
    modelPath: MODEL_PATH,
    nCtx: 32,
    nBatch: 512,
    nThreads: 4,
    nSeqMax: 8
  });

  try {
    const promptToks = await ctx.tokenize("Hi");
    await ctx.decode(promptToks, 0, 0);

    const root = Branch.create(ctx, promptToks.length, { temperature: 1.0 });
    root.captureLogits();

    // Seven forks with distinct seeds so each branch samples divergent
    // tokens and therefore occupies its own KV cells.
    const branches = [root];
    for (let i = 1; i < 8; i++) {
      const fork = await root.fork();
      fork.reseedSampler(1000 + i);
      branches.push(fork);
    }

    const store = new BranchStore(ctx);

    // Keep committing until KV exhaustion makes a batched decode fail.
    // nCtx may be clamped upward by the model (e.g. to 256); 50 rounds of
    // 8 branches is enough to exhaust even that.
    let successfulRounds = 0;
    let failedRound = false;

    for (let round = 0; round < 50; round++) {
      // One produced candidate per branch that hasn't hit a stop token.
      const live = [];
      for (const b of branches) {
        const p = b.produce();
        if (!p.isStop) live.push([b, p]);
      }
      if (live.length === 0) break;

      // Record per-branch PPL so rollback can be verified exactly.
      const pplsBefore = live.map(([b]) => b.perplexity);

      let committed = false;
      try {
        await store.commit(live.map(([b, p]) => [b, p.token]));
        committed = true;
      } catch {
        // Expected eventually: decode failed from KV exhaustion.
      }

      if (committed) {
        successfulRounds++;
        continue;
      }

      // Rollback check: every PPL must be bit-identical to its snapshot.
      const pplsAfter = live.map(([b]) => b.perplexity);
      const allRestored = pplsBefore.every((p, i) => p === pplsAfter[i]);
      assert(allRestored,
        `rollback: all PPLs restored after decode failure at round ${round}`);

      // The branches must remain usable: a 1-token commit may still fit.
      const [b0, p0] = live[0];
      const posBefore = b0.position;
      try {
        await b0.commit(p0.token);
        assert(b0.position === posBefore + 1,
          `rollback: single commit succeeds after failed batch (pos ${b0.position})`);
      } catch {
        // KV may be truly full even for 1 token — that's OK, test the PPL assertion above
      }

      failedRound = true;
      break;
    }

    console.log(` ${successfulRounds} successful rounds before KV exhaustion`);
    assert(failedRound,
      `rollback: decode failure triggered (nCtx=32, 8 branches, ${successfulRounds} rounds)`);

    await root.pruneSubtree();
  } finally {
    ctx.dispose();
  }
}

// ═══════════════════════════════════════════════════════════════════════════
// ASYNC REJECTION — Worker failures must reject, branch state un-advanced
// ═══════════════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -1662,6 +1785,8 @@ async function main() {
await testDeterminism();
await testDecodeAndCapture();
await testBranchStore();
await testPplSanity();
await testCommitRollback();
await testAsyncRejection();
await testEmptyInputEdgeCases();
await testJsonSchemaToGrammar();
Expand Down
Loading