Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -875,3 +875,65 @@ export interface SessionContext {
* ```
*/
export function createContext(options: ContextOptions): Promise<SessionContext>;

/**
* Safe logits access with Runtime Borrow Checker pattern
*
* Fetches the current logits via `ctx.getLogits()` and hands them to `fn`,
* which must finish all of its work synchronously.
* The callback MUST NOT:
* - Store the logits reference
* - Return a Promise (will throw)
* - Call decode() (would invalidate logits)
*
* This is a "runtime borrow checker" - it rejects callbacks that try to
* keep working with borrowed logits across an async boundary.
*
* Notes on what is (and is not) enforced:
* - The Promise check inspects the callback's return value (anything with a
*   `then` method) AFTER the callback has run, so the synchronous part of an
*   async callback has already executed by the time the error is thrown.
* - withLogits() does not revoke the buffer itself when the callback returns;
*   revocation happens on the next decode(), encode() or dispose().
*
* Pattern: "Memoized Step-Scoped Views with Explicit Revocation"
* - Memoization: If getLogits() called twice in same step, the views share
*   the same underlying buffer
* - Revocation: On decode() / encode() / dispose(), the previous buffer is detached
*
* @template T Return type of the callback
* @param ctx The session context
* @param fn Synchronous callback that uses logits - must not return a Promise
* @returns The result from the callback
* @throws Error if callback returns a Promise (async usage not allowed)
*
* @example Safe synchronous usage
* ```typescript
* // Compute entropy synchronously
* const entropy = withLogits(ctx, (logits) => {
* let maxLogit = logits[0];
* for (let i = 1; i < logits.length; i++) {
* if (logits[i] > maxLogit) maxLogit = logits[i];
* }
*
* let sumExp = 0;
* for (let i = 0; i < logits.length; i++) {
* sumExp += Math.exp(logits[i] - maxLogit);
* }
*
* let entropy = 0;
* for (let i = 0; i < logits.length; i++) {
* const p = Math.exp(logits[i] - maxLogit) / sumExp;
* if (p > 0) entropy -= p * Math.log(p);
* }
* return entropy;
* });
*
* // Now safe to decode (previous logits buffer is revoked)
* await ctx.decode([nextToken], position++);
* ```
*
* @example Error: async callback
* ```typescript
* // This will throw!
* withLogits(ctx, async (logits) => {
* await something(); // NOT ALLOWED
* return logits[0];
* });
* ```
*/
export function withLogits<T>(
ctx: SessionContext,
fn: (logits: Float32Array) => T
): T;
76 changes: 71 additions & 5 deletions lib/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const binary = require('node-gyp-build')(path.join(__dirname, '..'));
*
* @example
* ```js
* const { createContext } = require('liblloyal-node');
* const { createContext, withLogits } = require('lloyal.node');
*
* const ctx = await createContext({
* modelPath: './model.gguf',
Expand All @@ -23,18 +23,76 @@ const binary = require('node-gyp-build')(path.join(__dirname, '..'));
* // Decode
* await ctx.decode(tokens, 0);
*
* // Get raw logits (zero-copy Float32Array)
* const logits = ctx.getLogits();
* // Safe logits access (Runtime Borrow Checker pattern)
* const entropy = withLogits(ctx, (logits) => {
* // logits is valid here - use synchronously only!
* return computeEntropy(logits);
* });
*
* // Native reference implementations (for testing)
* const entropy = ctx.computeEntropy();
* // Or with native reference implementations (for testing)
* const nativeEntropy = ctx.computeEntropy();
* const token = ctx.greedySample();
*
* // Cleanup
* ctx.dispose();
* ```
*/

/**
 * Safe logits access with Runtime Borrow Checker pattern
 *
 * Fetches the current logits from `ctx.getLogits()` and passes them to `fn`,
 * which must complete all of its work synchronously.
 * The callback MUST NOT:
 * - Store the logits reference
 * - Return a Promise (will throw)
 * - Call decode() (would invalidate logits)
 *
 * The Promise check runs on the callback's return value after the callback
 * has executed: any value exposing a `then` method is treated as async usage.
 * When that happens the returned promise is discarded, so a no-op rejection
 * handler is attached first - otherwise a later rejection inside the async
 * callback would surface as an unhandled rejection on top of the error
 * thrown here.
 *
 * @template T
 * @param {SessionContext} ctx - The session context
 * @param {(logits: Float32Array) => T} fn - Synchronous callback that uses logits
 * @returns {T} The result from the callback
 * @throws {Error} If callback returns a Promise (async usage not allowed)
 *
 * @example
 * ```js
 * // Safe: synchronous computation
 * const entropy = withLogits(ctx, (logits) => {
 *   let sum = 0;
 *   for (let i = 0; i < logits.length; i++) {
 *     sum += Math.exp(logits[i]);
 *   }
 *   return Math.log(sum);
 * });
 *
 * // ERROR: callback returns Promise (will throw)
 * withLogits(ctx, async (logits) => {
 *   await something(); // NOT ALLOWED
 *   return logits[0];
 * });
 * ```
 */
function withLogits(ctx, fn) {
  // Get logits (memoized natively - same underlying buffer within a step)
  const logits = ctx.getLogits();

  // Execute user callback with logits
  const result = fn(logits);

  // Detect async usage (not allowed - logits would be invalidated)
  if (result && typeof result.then === 'function') {
    // The promise is dropped on the floor once we throw. Swallow its eventual
    // rejection so misuse is reported exactly once (by the Error below) rather
    // than also as an unhandled rejection.
    Promise.resolve(result).catch(() => {});

    throw new Error(
      'withLogits callback must be synchronous. ' +
      'Returning a Promise is not allowed because logits become invalid after decode(). ' +
      'Complete all logits processing synchronously within the callback.'
    );
  }

  return result;
}

module.exports = {
/**
* Create a new inference context
Expand All @@ -51,5 +109,13 @@ module.exports = {
return binary.createContext(options);
},

/**
* Safe logits access with Runtime Borrow Checker pattern
*
* Ensures logits are only accessed synchronously within the callback.
* See function JSDoc for full documentation.
*/
withLogits,

SessionContext: binary.SessionContext
};
105 changes: 83 additions & 22 deletions src/SessionContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <lloyal/grammar.hpp>
#include <lloyal/kv.hpp>
#include <lloyal/embedding.hpp>
#include <lloyal/logits.hpp>
#include <cmath>

namespace liblloyal_node {
Expand Down Expand Up @@ -628,6 +629,34 @@ void SessionContext::initializeContext(
std::cerr << " Shared refcount: " << _model.use_count() << std::endl;
}

// ===== LOGITS BUFFER MANAGEMENT =====

void SessionContext::invalidateLogits() {
// The Kill Switch: Detach any active logits buffer
//
// This is called before any operation that invalidates the logits pointer:
// - decode() - new forward pass overwrites logits
// - encode() - embedding pass overwrites logits
// - dispose() - context is destroyed
//
// After detach, any JS code holding a reference to the buffer will get
// a TypeError when trying to access it - exactly what we want.
if (!_logitsBufferRef.IsEmpty()) {
try {
Napi::ArrayBuffer buffer = _logitsBufferRef.Value();
if (!buffer.IsDetached()) {
buffer.Detach();
}
} catch (...) {
// Buffer may have been garbage collected - that's fine
}
_logitsBufferRef.Reset();
}

// Increment step counter - any new getLogits() call will create fresh buffer
_decodeStepId++;
}

Napi::Value SessionContext::getLogits(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
ensureNotDisposed();
Expand All @@ -636,23 +665,42 @@ Napi::Value SessionContext::getLogits(const Napi::CallbackInfo& info) {
throw Napi::Error::New(env, "Context not initialized");
}

// Get raw logits pointer (zero-copy)
float* logits = llama_get_logits_ith(_context, -1);
if (!logits) {
throw Napi::Error::New(env, "Failed to get logits");
// ===== MEMOIZATION: Return same buffer if already created for this step =====
//
// Pattern: "Memoized Step-Scoped Views"
// If caller calls getLogits() twice in the same step, return the same buffer.
// This avoids creating multiple views into the same memory.
if (_logitsStepId == _decodeStepId && !_logitsBufferRef.IsEmpty()) {
// Same step, reuse existing buffer
Napi::ArrayBuffer existingBuffer = _logitsBufferRef.Value();
const int n_vocab = lloyal::tokenizer::vocab_size(_model.get());
return Napi::Float32Array::New(env, n_vocab, existingBuffer, 0);
}

// ===== NEW BUFFER: Get logits via lloyal wrapper (handles null checks) =====
//
// lloyal::logits::get() throws descriptive errors if:
// - Context is null
// - Logits unavailable (decode() not called with logits=true)
float* logits;
try {
logits = lloyal::logits::get(_context, -1);
} catch (const std::exception& e) {
throw Napi::Error::New(env, e.what());
}

// Use model overload for vocab_size
const int n_vocab = lloyal::tokenizer::vocab_size(_model.get());

// Create Float32Array wrapping the logits (zero-copy!)
// WARNING: This is only valid until next decode() call
return Napi::Float32Array::New(
env,
n_vocab,
Napi::ArrayBuffer::New(env, logits, n_vocab * sizeof(float)),
0
);
// Create ArrayBuffer wrapping the logits (zero-copy!)
// Store reference for memoization and future revocation
Napi::ArrayBuffer buffer = Napi::ArrayBuffer::New(env, logits, n_vocab * sizeof(float));

// Store a strong reference (refcount = 1): keeps the buffer alive for memoization and so invalidateLogits() can Detach() it later
_logitsBufferRef = Napi::Reference<Napi::ArrayBuffer>::New(buffer, 1);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: comment says "weak reference" but reference count is 1 (strong reference). should use 0 for weak reference to avoid preventing garbage collection

Suggested change
_logitsBufferRef = Napi::Reference<Napi::ArrayBuffer>::New(buffer, 1);
_logitsBufferRef = Napi::Reference<Napi::ArrayBuffer>::New(buffer, 0);
Prompt To Fix With AI
This is a comment left during a code review.
Path: src/SessionContext.cpp
Line: 699:699

Comment:
**logic:** comment says "weak reference" but reference count is 1 (strong reference). should use `0` for weak reference to avoid preventing garbage collection

```suggestion
  _logitsBufferRef = Napi::Reference<Napi::ArrayBuffer>::New(buffer, 0);
```

How can I resolve this? If you propose a fix, please make it concise.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment in the header says "Weak reference to detach on revocation" but that's describing the intent (we hold it so we can detach it later), not the N-API reference semantics.

Why we need refcount=1 (strong reference):

  • We call buffer.Detach() later in invalidateLogits()
  • If refcount=0, the buffer could be GC'd before we call Detach
  • We need to keep it alive so we can explicitly invalidate it

Why refcount=0 would break things:

  • Buffer gets GC'd
  • _logitsBufferRef.Value() returns invalid/detached buffer
  • buffer.Detach() fails or crashes

The fix is to update the misleading comment in the header, not change the refcount

_logitsStepId = _decodeStepId;

// Return Float32Array view
return Napi::Float32Array::New(env, n_vocab, buffer, 0);
}

Napi::Value SessionContext::decode(const Napi::CallbackInfo& info) {
Expand All @@ -663,6 +711,9 @@ Napi::Value SessionContext::decode(const Napi::CallbackInfo& info) {
throw Napi::TypeError::New(env, "Expected (tokens: number[], position: number)");
}

// Revoke any active logits buffer before decode
invalidateLogits();

// Extract tokens
Napi::Array jsTokens = info[0].As<Napi::Array>();
std::vector<llama_token> tokens;
Expand Down Expand Up @@ -733,10 +784,12 @@ Napi::Value SessionContext::computeEntropy(const Napi::CallbackInfo& info) {
throw Napi::Error::New(env, "Context not initialized");
}

// Get logits
float* logits = llama_get_logits_ith(_context, -1);
if (!logits) {
throw Napi::Error::New(env, "Failed to get logits");
// Get logits via lloyal wrapper (handles null checks)
float* logits;
try {
logits = lloyal::logits::get(_context, -1);
} catch (const std::exception& e) {
throw Napi::Error::New(env, e.what());
}

// Use model overload for vocab_size
Expand Down Expand Up @@ -821,6 +874,9 @@ Napi::Value SessionContext::encode(const Napi::CallbackInfo& info) {
throw Napi::TypeError::New(env, "Expected (tokens: number[])");
}

// Revoke any active logits buffer before encode
invalidateLogits();

// Extract tokens
Napi::Array jsTokens = info[0].As<Napi::Array>();
std::vector<llama_token> tokens;
Expand Down Expand Up @@ -987,7 +1043,10 @@ Napi::Value SessionContext::dispose(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();

if (!_disposed) {
// Free grammar sampler first
// Revoke any active logits buffer before disposing
invalidateLogits();

// Free grammar sampler
if (_grammarSampler) {
llama_sampler_free(_grammarSampler);
_grammarSampler = nullptr;
Expand Down Expand Up @@ -1027,11 +1086,13 @@ Napi::Value SessionContext::getTokenScores(const Napi::CallbackInfo& info) {
throw Napi::Error::New(env, "Context not initialized");
}

// Get raw logits pointer from llama.cpp (last-step logits, index -1)
// Get raw logits pointer via lloyal wrapper (handles null checks)
// Returns mutable float* - we need to modify logits for grammar constraints
float* logits = llama_get_logits_ith(_context, -1);
if (!logits) {
throw Napi::Error::New(env, "Failed to get logits (ensure decode had logits=true)");
float* logits;
try {
logits = lloyal::logits::get(_context, -1);
} catch (const std::exception& e) {
throw Napi::Error::New(env, e.what());
}

// Get vocabulary size using model overload
Expand Down
25 changes: 25 additions & 0 deletions src/SessionContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@ class SessionContext : public Napi::ObjectWrap<SessionContext> {
llama_sampler* _grammarSampler = nullptr;
std::string _currentGrammar; // Track current grammar string to avoid re-initialization

// ===== LOGITS BUFFER MANAGEMENT (Memoization + Revocation) =====
//
// Pattern: "Memoized Step-Scoped Views with Explicit Revocation"
//
// - Memoization: If getLogits() called twice in same step, return same buffer
// - Revocation: On decode(), detach previous buffer to prevent use-after-invalidation
//
// See: lloyal::logits::get() for the underlying safe wrapper
uint64_t _decodeStepId = 0; // Incremented by invalidateLogits(), i.e. on each decode(), encode() and dispose()
uint64_t _logitsStepId = 0; // Value of _decodeStepId when _logitsBufferRef was created
Napi::Reference<Napi::ArrayBuffer> _logitsBufferRef; // Strong reference - kept alive so we can Detach() on revocation

// ===== INLINE HELPERS =====
// Pattern matches HybridSessionContext.hpp:170-176

Expand All @@ -218,6 +230,19 @@ class SessionContext : public Napi::ObjectWrap<SessionContext> {
inline llama_pos toPos(double pos) {
return static_cast<llama_pos>(pos);
}

/**
* Invalidate any active logits buffer (The Kill Switch)
*
* Called before any operation that would invalidate the logits pointer:
* - decode()
* - encode()
* - dispose()
*
* Detaches the ArrayBuffer so any JS code holding a reference
* will get a TypeError when trying to access it.
*/
void invalidateLogits();
};

/**
Expand Down
Loading
Loading