Skip to content

Commit b7e02e7

Browse files
Merge pull request #2 from lloyal-ai/feat/api-expansion
feat(logits): add safety for zero copy access
2 parents a674655 + 2e02c26 commit b7e02e7

11 files changed

Lines changed: 534 additions & 58 deletions

File tree

lib/index.d.ts

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,3 +875,65 @@ export interface SessionContext {
875875
* ```
876876
*/
877877
export function createContext(options: ContextOptions): Promise<SessionContext>;
878+
879+
/**
880+
* Safe logits access with Runtime Borrow Checker pattern
881+
*
882+
* Ensures logits are only accessed synchronously within the callback.
883+
* The callback MUST NOT:
884+
* - Store the logits reference
885+
* - Return a Promise (will throw)
886+
* - Call decode() (would invalidate logits)
887+
*
888+
* This is a "runtime borrow checker" - it prevents async mutations
889+
* while you're working with borrowed logits.
890+
*
891+
* Pattern: "Memoized Step-Scoped Views with Explicit Revocation"
892+
* - Memoization: If getLogits() called twice in same step, returns same buffer
893+
* - Revocation: On decode(), the previous buffer is detached
894+
*
895+
* @template T Return type of the callback
896+
* @param ctx The session context
897+
* @param fn Synchronous callback that uses logits - must not return a Promise
898+
* @returns The result from the callback
899+
* @throws Error if callback returns a Promise (async usage not allowed)
900+
*
901+
* @example Safe synchronous usage
902+
* ```typescript
903+
* // Compute entropy synchronously
904+
* const entropy = withLogits(ctx, (logits) => {
905+
* let maxLogit = logits[0];
906+
* for (let i = 1; i < logits.length; i++) {
907+
* if (logits[i] > maxLogit) maxLogit = logits[i];
908+
* }
909+
*
910+
* let sumExp = 0;
911+
* for (let i = 0; i < logits.length; i++) {
912+
* sumExp += Math.exp(logits[i] - maxLogit);
913+
* }
914+
*
915+
* let entropy = 0;
916+
* for (let i = 0; i < logits.length; i++) {
917+
* const p = Math.exp(logits[i] - maxLogit) / sumExp;
918+
* if (p > 0) entropy -= p * Math.log(p);
919+
* }
920+
* return entropy;
921+
* });
922+
*
923+
* // Now safe to decode (previous logits buffer is revoked)
924+
* await ctx.decode([nextToken], position++);
925+
* ```
926+
*
927+
* @example Error: async callback
928+
* ```typescript
929+
* // This will throw!
930+
* withLogits(ctx, async (logits) => {
931+
* await something(); // NOT ALLOWED
932+
* return logits[0];
933+
* });
934+
* ```
935+
*/
936+
export function withLogits<T>(
937+
ctx: SessionContext,
938+
fn: (logits: Float32Array) => T
939+
): T;

lib/index.js

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ const binary = require('node-gyp-build')(path.join(__dirname, '..'));
99
*
1010
* @example
1111
* ```js
12-
* const { createContext } = require('liblloyal-node');
12+
* const { createContext, withLogits } = require('lloyal.node');
1313
*
1414
* const ctx = await createContext({
1515
* modelPath: './model.gguf',
@@ -23,18 +23,76 @@ const binary = require('node-gyp-build')(path.join(__dirname, '..'));
2323
* // Decode
2424
* await ctx.decode(tokens, 0);
2525
*
26-
* // Get raw logits (zero-copy Float32Array)
27-
* const logits = ctx.getLogits();
26+
* // Safe logits access (Runtime Borrow Checker pattern)
27+
* const entropy = withLogits(ctx, (logits) => {
28+
* // logits is valid here - use synchronously only!
29+
* return computeEntropy(logits);
30+
* });
2831
*
29-
* // Native reference implementations (for testing)
30-
* const entropy = ctx.computeEntropy();
32+
* // Or with native reference implementations (for testing)
33+
* const nativeEntropy = ctx.computeEntropy();
3134
* const token = ctx.greedySample();
3235
*
3336
* // Cleanup
3437
* ctx.dispose();
3538
* ```
3639
*/
3740

41+
/**
42+
* Safe logits access with Runtime Borrow Checker pattern
43+
*
44+
* Ensures logits are only accessed synchronously within the callback.
45+
* The callback MUST NOT:
46+
* - Store the logits reference
47+
* - Return a Promise (will throw)
48+
* - Call decode() (would invalidate logits)
49+
*
50+
* This is a "runtime borrow checker" - it prevents async mutations
51+
* while you're working with borrowed logits.
52+
*
53+
* @template T
54+
* @param {SessionContext} ctx - The session context
55+
* @param {(logits: Float32Array) => T} fn - Synchronous callback that uses logits
56+
* @returns {T} The result from the callback
57+
* @throws {Error} If callback returns a Promise (async usage not allowed)
58+
*
59+
* @example
60+
* ```js
61+
* // Safe: synchronous computation
62+
* const entropy = withLogits(ctx, (logits) => {
63+
* let sum = 0;
64+
* for (let i = 0; i < logits.length; i++) {
65+
* sum += Math.exp(logits[i]);
66+
* }
67+
* return Math.log(sum);
68+
* });
69+
*
70+
* // ERROR: callback returns Promise (will throw)
71+
* withLogits(ctx, async (logits) => {
72+
* await something(); // NOT ALLOWED
73+
* return logits[0];
74+
* });
75+
* ```
76+
*/
77+
function withLogits(ctx, fn) {
78+
// Get logits (memoized - same buffer if called twice in same step)
79+
const logits = ctx.getLogits();
80+
81+
// Execute user callback with logits
82+
const result = fn(logits);
83+
84+
// Detect async usage (not allowed - logits would be invalidated)
85+
if (result && typeof result.then === 'function') {
86+
throw new Error(
87+
'withLogits callback must be synchronous. ' +
88+
'Returning a Promise is not allowed because logits become invalid after decode(). ' +
89+
'Complete all logits processing synchronously within the callback.'
90+
);
91+
}
92+
93+
return result;
94+
}
95+
3896
module.exports = {
3997
/**
4098
* Create a new inference context
@@ -51,5 +109,13 @@ module.exports = {
51109
return binary.createContext(options);
52110
},
53111

112+
/**
113+
* Safe logits access with Runtime Borrow Checker pattern
114+
*
115+
* Ensures logits are only accessed synchronously within the callback.
116+
* See function JSDoc for full documentation.
117+
*/
118+
withLogits,
119+
54120
SessionContext: binary.SessionContext
55121
};

src/SessionContext.cpp

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <lloyal/grammar.hpp>
1111
#include <lloyal/kv.hpp>
1212
#include <lloyal/embedding.hpp>
13+
#include <lloyal/logits.hpp>
1314
#include <cmath>
1415

1516
namespace liblloyal_node {
@@ -628,6 +629,34 @@ void SessionContext::initializeContext(
628629
std::cerr << " Shared refcount: " << _model.use_count() << std::endl;
629630
}
630631

632+
// ===== LOGITS BUFFER MANAGEMENT =====
633+
634+
void SessionContext::invalidateLogits() {
635+
// The Kill Switch: Detach any active logits buffer
636+
//
637+
// This is called before any operation that invalidates the logits pointer:
638+
// - decode() - new forward pass overwrites logits
639+
// - encode() - embedding pass overwrites logits
640+
// - dispose() - context is destroyed
641+
//
642+
// After detach, any JS code holding a reference to the buffer will get
643+
// a TypeError when trying to access it - exactly what we want.
644+
if (!_logitsBufferRef.IsEmpty()) {
645+
try {
646+
Napi::ArrayBuffer buffer = _logitsBufferRef.Value();
647+
if (!buffer.IsDetached()) {
648+
buffer.Detach();
649+
}
650+
} catch (...) {
651+
// Best-effort revocation: a failure while detaching must not abort the caller
652+
}
653+
_logitsBufferRef.Reset();
654+
}
655+
656+
// Increment step counter - any new getLogits() call will create fresh buffer
657+
_decodeStepId++;
658+
}
659+
631660
Napi::Value SessionContext::getLogits(const Napi::CallbackInfo& info) {
632661
Napi::Env env = info.Env();
633662
ensureNotDisposed();
@@ -636,23 +665,42 @@ Napi::Value SessionContext::getLogits(const Napi::CallbackInfo& info) {
636665
throw Napi::Error::New(env, "Context not initialized");
637666
}
638667

639-
// Get raw logits pointer (zero-copy)
640-
float* logits = llama_get_logits_ith(_context, -1);
641-
if (!logits) {
642-
throw Napi::Error::New(env, "Failed to get logits");
668+
// ===== MEMOIZATION: Return same buffer if already created for this step =====
669+
//
670+
// Pattern: "Memoized Step-Scoped Views"
671+
// If caller calls getLogits() twice in the same step, return the same buffer.
672+
// This avoids creating multiple views into the same memory.
673+
if (_logitsStepId == _decodeStepId && !_logitsBufferRef.IsEmpty()) {
674+
// Same step, reuse existing buffer
675+
Napi::ArrayBuffer existingBuffer = _logitsBufferRef.Value();
676+
const int n_vocab = lloyal::tokenizer::vocab_size(_model.get());
677+
return Napi::Float32Array::New(env, n_vocab, existingBuffer, 0);
678+
}
679+
680+
// ===== NEW BUFFER: Get logits via lloyal wrapper (handles null checks) =====
681+
//
682+
// lloyal::logits::get() throws descriptive errors if:
683+
// - Context is null
684+
// - Logits unavailable (decode() not called with logits=true)
685+
float* logits;
686+
try {
687+
logits = lloyal::logits::get(_context, -1);
688+
} catch (const std::exception& e) {
689+
throw Napi::Error::New(env, e.what());
643690
}
644691

645-
// Use model overload for vocab_size
646692
const int n_vocab = lloyal::tokenizer::vocab_size(_model.get());
647693

648-
// Create Float32Array wrapping the logits (zero-copy!)
649-
// WARNING: This is only valid until next decode() call
650-
return Napi::Float32Array::New(
651-
env,
652-
n_vocab,
653-
Napi::ArrayBuffer::New(env, logits, n_vocab * sizeof(float)),
654-
0
655-
);
694+
// Create ArrayBuffer wrapping the logits (zero-copy!)
695+
// Store reference for memoization and future revocation
696+
Napi::ArrayBuffer buffer = Napi::ArrayBuffer::New(env, logits, n_vocab * sizeof(float));
697+
698+
// Store strong reference (initial refcount 1) for memoization and later Detach()
699+
_logitsBufferRef = Napi::Reference<Napi::ArrayBuffer>::New(buffer, 1);
700+
_logitsStepId = _decodeStepId;
701+
702+
// Return Float32Array view
703+
return Napi::Float32Array::New(env, n_vocab, buffer, 0);
656704
}
657705

658706
Napi::Value SessionContext::decode(const Napi::CallbackInfo& info) {
@@ -663,6 +711,9 @@ Napi::Value SessionContext::decode(const Napi::CallbackInfo& info) {
663711
throw Napi::TypeError::New(env, "Expected (tokens: number[], position: number)");
664712
}
665713

714+
// Revoke any active logits buffer before decode
715+
invalidateLogits();
716+
666717
// Extract tokens
667718
Napi::Array jsTokens = info[0].As<Napi::Array>();
668719
std::vector<llama_token> tokens;
@@ -733,10 +784,12 @@ Napi::Value SessionContext::computeEntropy(const Napi::CallbackInfo& info) {
733784
throw Napi::Error::New(env, "Context not initialized");
734785
}
735786

736-
// Get logits
737-
float* logits = llama_get_logits_ith(_context, -1);
738-
if (!logits) {
739-
throw Napi::Error::New(env, "Failed to get logits");
787+
// Get logits via lloyal wrapper (handles null checks)
788+
float* logits;
789+
try {
790+
logits = lloyal::logits::get(_context, -1);
791+
} catch (const std::exception& e) {
792+
throw Napi::Error::New(env, e.what());
740793
}
741794

742795
// Use model overload for vocab_size
@@ -821,6 +874,9 @@ Napi::Value SessionContext::encode(const Napi::CallbackInfo& info) {
821874
throw Napi::TypeError::New(env, "Expected (tokens: number[])");
822875
}
823876

877+
// Revoke any active logits buffer before encode
878+
invalidateLogits();
879+
824880
// Extract tokens
825881
Napi::Array jsTokens = info[0].As<Napi::Array>();
826882
std::vector<llama_token> tokens;
@@ -987,7 +1043,10 @@ Napi::Value SessionContext::dispose(const Napi::CallbackInfo& info) {
9871043
Napi::Env env = info.Env();
9881044

9891045
if (!_disposed) {
990-
// Free grammar sampler first
1046+
// Revoke any active logits buffer before disposing
1047+
invalidateLogits();
1048+
1049+
// Free grammar sampler
9911050
if (_grammarSampler) {
9921051
llama_sampler_free(_grammarSampler);
9931052
_grammarSampler = nullptr;
@@ -1027,11 +1086,13 @@ Napi::Value SessionContext::getTokenScores(const Napi::CallbackInfo& info) {
10271086
throw Napi::Error::New(env, "Context not initialized");
10281087
}
10291088

1030-
// Get raw logits pointer from llama.cpp (last-step logits, index -1)
1089+
// Get raw logits pointer via lloyal wrapper (handles null checks)
10311090
// Returns mutable float* - we need to modify logits for grammar constraints
1032-
float* logits = llama_get_logits_ith(_context, -1);
1033-
if (!logits) {
1034-
throw Napi::Error::New(env, "Failed to get logits (ensure decode had logits=true)");
1091+
float* logits;
1092+
try {
1093+
logits = lloyal::logits::get(_context, -1);
1094+
} catch (const std::exception& e) {
1095+
throw Napi::Error::New(env, e.what());
10351096
}
10361097

10371098
// Get vocabulary size using model overload

src/SessionContext.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,18 @@ class SessionContext : public Napi::ObjectWrap<SessionContext> {
202202
llama_sampler* _grammarSampler = nullptr;
203203
std::string _currentGrammar; // Track current grammar string to avoid re-initialization
204204

205+
// ===== LOGITS BUFFER MANAGEMENT (Memoization + Revocation) =====
206+
//
207+
// Pattern: "Memoized Step-Scoped Views with Explicit Revocation"
208+
//
209+
// - Memoization: If getLogits() called twice in same step, return same buffer
210+
// - Revocation: On decode(), detach previous buffer to prevent use-after-invalidation
211+
//
212+
// See: lloyal::logits::get() for the underlying safe wrapper
213+
uint64_t _decodeStepId = 0; // Incremented by every invalidateLogits() call (decode(), encode(), dispose())
214+
uint64_t _logitsStepId = 0; // Value of _decodeStepId when _logitsBufferRef was created
215+
Napi::Reference<Napi::ArrayBuffer> _logitsBufferRef; // Strong reference - kept alive so we can Detach() on revocation
216+
205217
// ===== INLINE HELPERS =====
206218
// Pattern matches HybridSessionContext.hpp:170-176
207219

@@ -218,6 +230,19 @@ class SessionContext : public Napi::ObjectWrap<SessionContext> {
218230
inline llama_pos toPos(double pos) {
219231
return static_cast<llama_pos>(pos);
220232
}
233+
234+
/**
235+
* Invalidate any active logits buffer (The Kill Switch)
236+
*
237+
* Called before any operation that would invalidate the logits pointer:
238+
* - decode()
239+
* - encode()
240+
* - dispose()
241+
*
242+
* Detaches the ArrayBuffer so any JS code holding a reference
243+
* will get a TypeError when trying to access it.
244+
*/
245+
void invalidateLogits();
221246
};
222247

223248
/**

0 commit comments

Comments
 (0)