1010#include < lloyal/grammar.hpp>
1111#include < lloyal/kv.hpp>
1212#include < lloyal/embedding.hpp>
13+ #include < lloyal/logits.hpp>
1314#include < cmath>
1415
1516namespace liblloyal_node {
@@ -628,6 +629,34 @@ void SessionContext::initializeContext(
628629 std::cerr << " Shared refcount: " << _model.use_count () << std::endl;
629630}
630631
632+ // ===== LOGITS BUFFER MANAGEMENT =====
633+
634+ void SessionContext::invalidateLogits () {
635+ // The Kill Switch: Detach any active logits buffer
636+ //
637+ // This is called before any operation that invalidates the logits pointer:
638+ // - decode() - new forward pass overwrites logits
639+ // - encode() - embedding pass overwrites logits
640+ // - dispose() - context is destroyed
641+ //
642+ // After detach, any JS code holding a reference to the buffer will get
643+ // a TypeError when trying to access it - exactly what we want.
644+ if (!_logitsBufferRef.IsEmpty ()) {
645+ try {
646+ Napi::ArrayBuffer buffer = _logitsBufferRef.Value ();
647+ if (!buffer.IsDetached ()) {
648+ buffer.Detach ();
649+ }
650+ } catch (...) {
651+ // Buffer may have been garbage collected - that's fine
652+ }
653+ _logitsBufferRef.Reset ();
654+ }
655+
656+ // Increment step counter - any new getLogits() call will create fresh buffer
657+ _decodeStepId++;
658+ }
659+
631660Napi::Value SessionContext::getLogits (const Napi::CallbackInfo& info) {
632661 Napi::Env env = info.Env ();
633662 ensureNotDisposed ();
@@ -636,23 +665,42 @@ Napi::Value SessionContext::getLogits(const Napi::CallbackInfo& info) {
636665 throw Napi::Error::New (env, " Context not initialized" );
637666 }
638667
639- // Get raw logits pointer (zero-copy)
640- float * logits = llama_get_logits_ith (_context, -1 );
641- if (!logits) {
642- throw Napi::Error::New (env, " Failed to get logits" );
668+ // ===== MEMOIZATION: Return same buffer if already created for this step =====
669+ //
670+ // Pattern: "Memoized Step-Scoped Views"
671+ // If caller calls getLogits() twice in the same step, return the same buffer.
672+ // This avoids creating multiple views into the same memory.
673+ if (_logitsStepId == _decodeStepId && !_logitsBufferRef.IsEmpty ()) {
674+ // Same step, reuse existing buffer
675+ Napi::ArrayBuffer existingBuffer = _logitsBufferRef.Value ();
676+ const int n_vocab = lloyal::tokenizer::vocab_size (_model.get ());
677+ return Napi::Float32Array::New (env, n_vocab, existingBuffer, 0 );
678+ }
679+
680+ // ===== NEW BUFFER: Get logits via lloyal wrapper (handles null checks) =====
681+ //
682+ // lloyal::logits::get() throws descriptive errors if:
683+ // - Context is null
684+ // - Logits unavailable (decode() not called with logits=true)
685+ float * logits;
686+ try {
687+ logits = lloyal::logits::get (_context, -1 );
688+ } catch (const std::exception& e) {
689+ throw Napi::Error::New (env, e.what ());
643690 }
644691
645- // Use model overload for vocab_size
646692 const int n_vocab = lloyal::tokenizer::vocab_size (_model.get ());
647693
648- // Create Float32Array wrapping the logits (zero-copy!)
649- // WARNING: This is only valid until next decode() call
650- return Napi::Float32Array::New (
651- env,
652- n_vocab,
653- Napi::ArrayBuffer::New (env, logits, n_vocab * sizeof (float )),
654- 0
655- );
694+ // Create ArrayBuffer wrapping the logits (zero-copy!)
695+ // Store reference for memoization and future revocation
696+ Napi::ArrayBuffer buffer = Napi::ArrayBuffer::New (env, logits, n_vocab * sizeof (float ));
697+
698+ // Store weak reference for memoization
699+ _logitsBufferRef = Napi::Reference<Napi::ArrayBuffer>::New (buffer, 1 );
700+ _logitsStepId = _decodeStepId;
701+
702+ // Return Float32Array view
703+ return Napi::Float32Array::New (env, n_vocab, buffer, 0 );
656704}
657705
658706Napi::Value SessionContext::decode (const Napi::CallbackInfo& info) {
@@ -663,6 +711,9 @@ Napi::Value SessionContext::decode(const Napi::CallbackInfo& info) {
663711 throw Napi::TypeError::New (env, " Expected (tokens: number[], position: number)" );
664712 }
665713
714+ // Revoke any active logits buffer before decode
715+ invalidateLogits ();
716+
666717 // Extract tokens
667718 Napi::Array jsTokens = info[0 ].As <Napi::Array>();
668719 std::vector<llama_token> tokens;
@@ -733,10 +784,12 @@ Napi::Value SessionContext::computeEntropy(const Napi::CallbackInfo& info) {
733784 throw Napi::Error::New (env, " Context not initialized" );
734785 }
735786
736- // Get logits
737- float * logits = llama_get_logits_ith (_context, -1 );
738- if (!logits) {
739- throw Napi::Error::New (env, " Failed to get logits" );
787+ // Get logits via lloyal wrapper (handles null checks)
788+ float * logits;
789+ try {
790+ logits = lloyal::logits::get (_context, -1 );
791+ } catch (const std::exception& e) {
792+ throw Napi::Error::New (env, e.what ());
740793 }
741794
742795 // Use model overload for vocab_size
@@ -821,6 +874,9 @@ Napi::Value SessionContext::encode(const Napi::CallbackInfo& info) {
821874 throw Napi::TypeError::New (env, " Expected (tokens: number[])" );
822875 }
823876
877+ // Revoke any active logits buffer before encode
878+ invalidateLogits ();
879+
824880 // Extract tokens
825881 Napi::Array jsTokens = info[0 ].As <Napi::Array>();
826882 std::vector<llama_token> tokens;
@@ -987,7 +1043,10 @@ Napi::Value SessionContext::dispose(const Napi::CallbackInfo& info) {
9871043 Napi::Env env = info.Env ();
9881044
9891045 if (!_disposed) {
990- // Free grammar sampler first
1046+ // Revoke any active logits buffer before disposing
1047+ invalidateLogits ();
1048+
1049+ // Free grammar sampler
9911050 if (_grammarSampler) {
9921051 llama_sampler_free (_grammarSampler);
9931052 _grammarSampler = nullptr ;
@@ -1027,11 +1086,13 @@ Napi::Value SessionContext::getTokenScores(const Napi::CallbackInfo& info) {
10271086 throw Napi::Error::New (env, " Context not initialized" );
10281087 }
10291088
1030- // Get raw logits pointer from llama.cpp (last-step logits, index -1 )
1089+ // Get raw logits pointer via lloyal wrapper (handles null checks )
10311090 // Returns mutable float* - we need to modify logits for grammar constraints
1032- float * logits = llama_get_logits_ith (_context, -1 );
1033- if (!logits) {
1034- throw Napi::Error::New (env, " Failed to get logits (ensure decode had logits=true)" );
1091+ float * logits;
1092+ try {
1093+ logits = lloyal::logits::get (_context, -1 );
1094+ } catch (const std::exception& e) {
1095+ throw Napi::Error::New (env, e.what ());
10351096 }
10361097
10371098 // Get vocabulary size using model overload
0 commit comments