From be811421b7867cab410335b890537d49f7a96e27 Mon Sep 17 00:00:00 2001
From: hsinhoyeh
Date: Thu, 12 Mar 2026 16:38:31 +0800
Subject: [PATCH] fix panic: panic in streamState{cb: cb}

---
Review notes (free-form text between the "---" separator and the diffstat is
not part of the commit message and is skipped when the patch is applied):

* What changes (Go): GenerateStream no longer hands C a pointer to its
  streamState. It stores the state in a mutex-guarded map keyed by an integer
  handle, passes that handle through the user-data argument, and
  goTokenCallback looks the state up again. A deferred unregisterCallback
  removes the entry when GenerateStream returns.
* goTokenCallback returns 0 when the handle is not in the registry; in
  wrapper.cpp a false callback result breaks out of the generation loop.
* What changes (C++): go_llama_generate calls
  llama_memory_clear(llama_get_memory(ctx), false) before it fetches the
  vocab, and the callback result is stored in a named bool before testing it.
* NOTE(review): the three llamacpp.go hunks account for all 51 changed lines
  in the diffstat, so this patch adds no import line. Confirm "sync" is
  already imported by llamacpp.go, otherwise the new sync.Mutex will not
  compile.
* NOTE(review): unsafe.Pointer(handle) converts a uintptr to unsafe.Pointer,
  which is the pattern the `go vet` unsafeptr check reports. runtime/cgo.Handle
  is the standard-library facility for passing Go values through C; consider
  it if the module's Go version allows - TODO confirm.
* NOTE(review): `generated` is declared and incremented, but no hunk reads it,
  and the three wrapper.cpp hunks account for all 13 changed lines. It looks
  write-only - confirm it is needed or drop it.
* NOTE(review): line breaks and leading whitespace in this copy were
  reconstructed from a whitespace-collapsed rendering (tabs for Go, 4 spaces
  for C++). Hunk sizes match the headers, but verify context lines against
  the tree before applying.

 ggml/llamacpp/llamacpp.go | 51 ++++++++++++++++++++++++++++++++++++---
 ggml/llamacpp/wrapper.cpp | 13 +++++++---
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/ggml/llamacpp/llamacpp.go b/ggml/llamacpp/llamacpp.go
index 900357a..4ea44f2 100644
--- a/ggml/llamacpp/llamacpp.go
+++ b/ggml/llamacpp/llamacpp.go
@@ -173,15 +173,53 @@ func (c *Context) Generate(prompt string, opts ...GenerateOption) (string, error
 	return sb.String(), nil
 }
 
-// streamState holds the callback state passed through CGO.
+// streamState holds the callback state for a generation request.
 type streamState struct {
 	cb      func(string) bool
 	stopped bool
 }
 
+// callbackRegistry stores streamState objects keyed by an integer handle.
+// This avoids passing Go pointers (which contain other Go pointers like
+// function closures) to C, which violates CGo pointer rules.
+var (
+	callbackMu       sync.Mutex
+	callbackRegistry = map[uintptr]*streamState{}
+	callbackNextID   uintptr
+)
+
+// registerCallback stores a streamState and returns an opaque handle
+// that can safely be passed to C as a uintptr.
+func registerCallback(state *streamState) uintptr {
+	callbackMu.Lock()
+	defer callbackMu.Unlock()
+	callbackNextID++
+	id := callbackNextID
+	callbackRegistry[id] = state
+	return id
+}
+
+// unregisterCallback removes a streamState from the registry.
+func unregisterCallback(id uintptr) {
+	callbackMu.Lock()
+	defer callbackMu.Unlock()
+	delete(callbackRegistry, id)
+}
+
+// lookupCallback retrieves a streamState by handle.
+func lookupCallback(id uintptr) *streamState {
+	callbackMu.Lock()
+	defer callbackMu.Unlock()
+	return callbackRegistry[id]
+}
+
 //export goTokenCallback
 func goTokenCallback(token *C.char, length C.int, userData unsafe.Pointer) C.int {
-	state := (*streamState)(userData)
+	id := uintptr(userData)
+	state := lookupCallback(id)
+	if state == nil {
+		return 0
+	}
 	goToken := C.GoStringN(token, length)
 	if !state.cb(goToken) {
 		state.stopped = true
@@ -220,7 +258,14 @@ func (c *Context) GenerateStream(prompt string, cb func(token string) bool, opts
 	params.seed = C.int(cfg.seed)
 	params.penalty_last_n = C.int(cfg.penaltyLastN)
 
+	// Use a handle registry instead of passing Go pointers to C.
+	// streamState contains a Go function pointer (cb), so passing it
+	// directly to C via unsafe.Pointer violates CGo pointer rules:
+	// "cgo argument has Go pointer to unpinned Go pointer"
+	// Instead, store the state in a Go-side map and pass only an integer handle.
 	state := &streamState{cb: cb}
+	handle := registerCallback(state)
+	defer unregisterCallback(handle)
 
 	rc := C.go_llama_generate(
 		unsafe.Pointer(c.c),
@@ -228,7 +273,7 @@ func (c *Context) GenerateStream(prompt string, cb func(token string) bool, opts
 		cprompt,
 		params,
 		C.go_llama_token_callback(C.goTokenCallbackBridge),
-		unsafe.Pointer(state),
+		unsafe.Pointer(handle),
 	)
 
 	if rc != 0 {
diff --git a/ggml/llamacpp/wrapper.cpp b/ggml/llamacpp/wrapper.cpp
index 3345c46..074bb15 100644
--- a/ggml/llamacpp/wrapper.cpp
+++ b/ggml/llamacpp/wrapper.cpp
@@ -55,6 +55,10 @@ int go_llama_generate(
         return -1;
     }
 
+    // Clear KV cache from any previous generation so positions don't collide.
+    // Use data=false to only clear metadata (positions/sequences), not the underlying buffers.
+    llama_memory_clear(llama_get_memory(ctx), false);
+
     const llama_vocab* vocab = llama_model_get_vocab(model);
     if (!vocab) {
         set_error("failed to get vocab from model");
@@ -120,6 +124,7 @@ int go_llama_generate(
     int n_cur = n_tokens;
     const int n_ctx = (int)llama_n_ctx(ctx);
     const int max_tokens = params.max_tokens > 0 ? params.max_tokens : 512;
+    int generated = 0;
 
     for (int i = 0; i < max_tokens; i++) {
         // Sample next token.
@@ -131,20 +136,22 @@ int go_llama_generate(
             break;
         }
 
+        generated++;
+
         // Convert token to text.
         int n_piece = llama_token_to_piece(vocab, new_token, piece_buf,
                                            sizeof(piece_buf) - 1,
                                            /*lstrip=*/0, /*special=*/false);
         if (n_piece < 0) {
-            // Token requires more buffer space — skip it.
             continue;
         }
         piece_buf[n_piece] = '\0';
 
         // Stream token to callback.
         if (callback) {
-            if (!callback(piece_buf, n_piece, user_data)) {
-                break; // caller requested stop
+            bool cb_ok = callback(piece_buf, n_piece, user_data);
+            if (!cb_ok) {
+                break;
             }
         }
 