Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions ggml/llamacpp/llamacpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,53 @@ func (c *Context) Generate(prompt string, opts ...GenerateOption) (string, error
return sb.String(), nil
}

// streamState holds the callback state for a single generation request.
// It is looked up from goTokenCallback each time the C side reports a token.
type streamState struct {
	// cb receives each generated token as a Go string; returning false
	// signals that the caller wants generation to stop.
	cb func(string) bool
	// stopped is set to true by goTokenCallback once cb has returned false.
	stopped bool
}

// Handle registry for in-flight streaming callbacks.
//
// A streamState holds a Go closure, so handing its address to C would break
// the cgo pointer-passing rules. Each state is instead parked in a Go-side
// map and C is given only the integer key.
var (
	// callbackMu guards callbackRegistry and callbackNextID.
	callbackMu sync.Mutex
	// callbackRegistry maps an issued handle to its streamState.
	callbackRegistry = make(map[uintptr]*streamState)
	// callbackNextID is the most recently issued handle; it starts at zero.
	callbackNextID uintptr
)

// registerCallback stores state in the registry and returns an opaque
// handle for it that can be passed to C in place of a Go pointer.
//
// The returned handle is never 0 and never equal to a handle that is still
// registered, even if the ID counter wraps around. Pair every call with
// unregisterCallback once the handle is no longer needed.
func registerCallback(state *streamState) uintptr {
	callbackMu.Lock()
	defer callbackMu.Unlock()
	for {
		callbackNextID++
		if callbackNextID == 0 {
			// Counter wrapped. 0 becomes a nil pointer when the handle is
			// cast for C, so keep it reserved.
			continue
		}
		if _, inUse := callbackRegistry[callbackNextID]; !inUse {
			break
		}
	}
	callbackRegistry[callbackNextID] = state
	return callbackNextID
}

// unregisterCallback drops the registry entry for handle id.
// An id with no entry is left as a no-op.
func unregisterCallback(id uintptr) {
	callbackMu.Lock()
	delete(callbackRegistry, id)
	callbackMu.Unlock()
}

// lookupCallback returns the streamState registered under handle id,
// or nil when no entry exists for that handle.
func lookupCallback(id uintptr) *streamState {
	callbackMu.Lock()
	found := callbackRegistry[id]
	callbackMu.Unlock()
	return found
}

//export goTokenCallback
func goTokenCallback(token *C.char, length C.int, userData unsafe.Pointer) C.int {
state := (*streamState)(userData)
id := uintptr(userData)
state := lookupCallback(id)
if state == nil {
return 0
}
goToken := C.GoStringN(token, length)
if !state.cb(goToken) {
state.stopped = true
Expand Down Expand Up @@ -220,15 +258,22 @@ func (c *Context) GenerateStream(prompt string, cb func(token string) bool, opts
params.seed = C.int(cfg.seed)
params.penalty_last_n = C.int(cfg.penaltyLastN)

// Use a handle registry instead of passing Go pointers to C.
// streamState contains a Go function pointer (cb), so passing it
// directly to C via unsafe.Pointer violates CGo pointer rules:
// "cgo argument has Go pointer to unpinned Go pointer"
// Instead, store the state in a Go-side map and pass only an integer handle.
state := &streamState{cb: cb}
handle := registerCallback(state)
defer unregisterCallback(handle)

rc := C.go_llama_generate(
unsafe.Pointer(c.c),
unsafe.Pointer(c.model.c),
cprompt,
params,
C.go_llama_token_callback(C.goTokenCallbackBridge),
unsafe.Pointer(state),
unsafe.Pointer(handle),
)

if rc != 0 {
Expand Down
13 changes: 10 additions & 3 deletions ggml/llamacpp/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ int go_llama_generate(
return -1;
}

// Clear KV cache from any previous generation so positions don't collide.
// Use data=false to only clear metadata (positions/sequences), not the underlying buffers.
llama_memory_clear(llama_get_memory(ctx), false);

const llama_vocab* vocab = llama_model_get_vocab(model);
if (!vocab) {
set_error("failed to get vocab from model");
Expand Down Expand Up @@ -120,6 +124,7 @@ int go_llama_generate(
int n_cur = n_tokens;
const int n_ctx = (int)llama_n_ctx(ctx);
const int max_tokens = params.max_tokens > 0 ? params.max_tokens : 512;
int generated = 0;

for (int i = 0; i < max_tokens; i++) {
// Sample next token.
Expand All @@ -131,20 +136,22 @@ int go_llama_generate(
break;
}

generated++;

// Convert token to text.
int n_piece = llama_token_to_piece(vocab, new_token,
piece_buf, sizeof(piece_buf) - 1,
/*lstrip=*/0, /*special=*/false);
if (n_piece < 0) {
// Token requires more buffer space — skip it.
continue;
}
piece_buf[n_piece] = '\0';

// Stream token to callback.
if (callback) {
if (!callback(piece_buf, n_piece, user_data)) {
break; // caller requested stop
bool cb_ok = callback(piece_buf, n_piece, user_data);
if (!cb_ok) {
break;
}
}

Expand Down
Loading