From 01f117889fe15dfe28d2ba7f3a67370449423921 Mon Sep 17 00:00:00 2001 From: kunchenguid Date: Wed, 1 Apr 2026 10:45:10 -0700 Subject: [PATCH 1/2] fix(repl): prevent hanging predictions in multiline input --- cmd/gsh/defaults/models.gsh | 1 + docs/script/17-model-declarations.md | 3 + docs/sdk/02-models.md | 9 ++- internal/repl/input/multiline_test.go | 51 ++++++++++++++++ internal/repl/input/prediction.go | 11 ++++ internal/repl/input/prediction_history.go | 2 +- internal/repl/predict/event_provider.go | 11 +++- internal/repl/repl.go | 2 +- internal/script/interpreter/context_local.go | 51 ++++++++++++++++ internal/script/interpreter/interpreter.go | 22 +++---- .../script/interpreter/interpreter_test.go | 51 +++++++++++++++- .../script/interpreter/provider_openai.go | 32 ++++++++++ .../interpreter/provider_openai_test.go | 60 +++++++++++++++++++ 13 files changed, 284 insertions(+), 22 deletions(-) create mode 100644 internal/script/interpreter/context_local.go diff --git a/cmd/gsh/defaults/models.gsh b/cmd/gsh/defaults/models.gsh index de069cdf..82b1b90b 100644 --- a/cmd/gsh/defaults/models.gsh +++ b/cmd/gsh/defaults/models.gsh @@ -7,6 +7,7 @@ model lite { apiKey: "ollama", model: "gemma3:1b", baseURL: "http://localhost:11434/v1", + timeout: 15000, } # Use for agent interactions (more capable) diff --git a/docs/script/17-model-declarations.md b/docs/script/17-model-declarations.md index d46d792a..5c955f4f 100644 --- a/docs/script/17-model-declarations.md +++ b/docs/script/17-model-declarations.md @@ -84,6 +84,7 @@ model localLlama { apiKey: "ollama", baseURL: "http://localhost:11434/v1", model: "devstral-small-2", + timeout: 15000, } ``` @@ -95,6 +96,7 @@ The key points here: - `apiKey` is literally the string `"ollama"` (not a real key) - `baseURL` points to your local Ollama server - `model` is the name of the model you've pulled into Ollama +- `timeout` is in milliseconds and protects callers from hanging forever if the backend stops responding --- @@ -262,6 +264,7 @@ Every model declaration needs: - **`temperature`** (default: 0.7) - Controls randomness in responses (0.0-1.0) - **`baseURL`** - For Ollama or self-hosted services, the URL to the API endpoint +- **`timeout`** - Request timeout in milliseconds for model API calls ### Practical Example: Choosing the Right Parameters diff --git a/docs/sdk/02-models.md b/docs/sdk/02-models.md index 63d54eed..dbdfe87d 100644 --- a/docs/sdk/02-models.md +++ b/docs/sdk/02-models.md @@ -63,9 +63,10 @@ model myModel { ### Optional Fields -| Field | Type | Description | -| --------- | -------- | ------------------------------------------- | -| `baseURL` | `string` | API endpoint URL (defaults to OpenAI's API) | +| Field | Type | Description | +| --------- | -------- | -------------------------------------------------------- | +| `baseURL` | `string` | API endpoint URL (defaults to OpenAI's API) | +| `timeout` | `number` | Request timeout in milliseconds for model API calls | ## Provider Examples @@ -89,6 +90,7 @@ model gemma { apiKey: "ollama", baseURL: "http://localhost:11434/v1", model: "gemma3:1b", + timeout: 15000, } gsh.models.lite = gemma @@ -100,6 +102,7 @@ gsh.models.lite = gemma - Set `apiKey: "ollama"` (required placeholder) - Use `baseURL: "http://localhost:11434/v1"` - Model name should match output from `ollama list` +- Set `timeout` to bound requests so hung model backends do not stall REPL features indefinitely ### OpenRouter diff --git a/internal/repl/input/multiline_test.go b/internal/repl/input/multiline_test.go index b8822042..56f0a7fe 100644 --- a/internal/repl/input/multiline_test.go +++ b/internal/repl/input/multiline_test.go @@ -1,14 +1,38 @@ package input import ( + "context" "strings" + "sync" "testing" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/kunchenguid/gsh/internal/script/interpreter" "github.com/muesli/termenv" ) +type countingPredictionProvider struct { + mu sync.Mutex + inputs []string +} + +func (p *countingPredictionProvider) Predict(ctx context.Context, input string, trigger interpreter.PredictTrigger, existingPrediction string) (string, error) { + p.mu.Lock() + defer p.mu.Unlock() + p.inputs = append(p.inputs, input) + return "", nil +} + +func (p *countingPredictionProvider) Inputs() []string { + p.mu.Lock() + defer p.mu.Unlock() + out := make([]string, len(p.inputs)) + copy(out, p.inputs) + return out +} + func TestIsInputComplete(t *testing.T) { tests := []struct { name string @@ -259,6 +283,33 @@ func TestRenderMultiLineNoPrediction(t *testing.T) { } } +func TestMultiLineInputDoesNotStartHiddenPrediction(t *testing.T) { + provider := &countingPredictionProvider{} + predictionState := NewPredictionState(PredictionStateConfig{ + DebounceDelay: 10 * time.Millisecond, + Provider: provider, + }) + + // Simulate normal single-line editing that marks prediction state dirty. + predictionState.OnInputChanged(`echo "test`) + + m := New(Config{ + PredictionState: predictionState, + }) + m.SetValue(`echo "test`) + + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + m = newModel.(Model) + + time.Sleep(30 * time.Millisecond) + + for _, input := range provider.Inputs() { + if input == "" { + t.Fatalf("multiline editing should cancel prediction without starting an empty-input prediction") + } + } +} + func TestSplitHighlightedByNewlines(t *testing.T) { // Force ANSI output so highlighting produces escape codes in tests oldProfile := lipgloss.DefaultRenderer().ColorProfile() diff --git a/internal/repl/input/prediction.go b/internal/repl/input/prediction.go index cb74db13..de074634 100644 --- a/internal/repl/input/prediction.go +++ b/internal/repl/input/prediction.go @@ -144,6 +144,17 @@ func (ps *PredictionState) Reset() { ps.cancelPendingLocked() } +// Cancel clears any in-flight prediction work without scheduling a replacement. +// It increments the state ID so stale async results are rejected. +func (ps *PredictionState) Cancel() { + ps.mu.Lock() + defer ps.mu.Unlock() + ps.cancelPendingLocked() + ps.stateID.Add(1) + ps.prediction = "" + ps.inputForPrediction = "" +} + // cancelPendingLocked cancels any pending prediction request. // Must be called with mu held. func (ps *PredictionState) cancelPendingLocked() { diff --git a/internal/repl/input/prediction_history.go b/internal/repl/input/prediction_history.go index 2a8f6651..57886497 100644 --- a/internal/repl/input/prediction_history.go +++ b/internal/repl/input/prediction_history.go @@ -19,7 +19,7 @@ func (m Model) onTextChanged() (tea.Model, tea.Cmd) { if strings.Contains(text, "\n") { m.currentPrediction = "" if m.prediction != nil { - m.prediction.OnInputChanged("") + m.prediction.Cancel() } return m, nil } diff --git a/internal/repl/predict/event_provider.go b/internal/repl/predict/event_provider.go index 21252130..35e27068 100644 --- a/internal/repl/predict/event_provider.go +++ b/internal/repl/predict/event_provider.go @@ -3,11 +3,14 @@ package predict import ( "context" "sync" + "time" "github.com/kunchenguid/gsh/internal/script/interpreter" "go.uber.org/zap" ) +const defaultPredictionRequestTimeout = 30 * time.Second + // EventPredictionProvider calls the repl.predict middleware chain to obtain predictions. // If no middleware is registered or middleware returns null, no prediction is returned. type EventPredictionProvider struct { @@ -65,6 +68,12 @@ func (p *EventPredictionProvider) emitPredictEvent(ctx context.Context, input st return "", nil } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, defaultPredictionRequestTimeout) + defer cancel() + } + // For instant predictions, use TryLock to avoid blocking the UI thread. // If a debounced prediction is running (holding the mutex for expensive // operations like git diff or LLM calls), we skip rather than block. @@ -79,7 +88,7 @@ func (p *EventPredictionProvider) emitPredictEvent(ctx context.Context, input st // Ensure middleware sees the cancellable context used by PredictionState p.interp.SetContext(ctx) defer func() { - p.interp.SetContext(context.Background()) + p.interp.ClearContext() p.mu.Unlock() }() diff --git a/internal/repl/repl.go b/internal/repl/repl.go index 470aa66d..c0152add 100644 --- a/internal/repl/repl.go +++ b/internal/repl/repl.go @@ -437,7 +437,7 @@ func (r *REPL) processCommand(ctx context.Context, command string) error { // Set the cancellable context on the interpreter so agent execution can use it interp := r.executor.Interpreter() interp.SetContext(cmdCtx) - defer interp.SetContext(context.Background()) // Clear context after command completes + defer interp.ClearContext() // Clear context after command completes // Record ALL user input in history (including agent commands like "#...") // This is done before middleware so all user input is captured diff --git a/internal/script/interpreter/context_local.go b/internal/script/interpreter/context_local.go new file mode 100644 index 00000000..856c2c49 --- /dev/null +++ b/internal/script/interpreter/context_local.go @@ -0,0 +1,51 @@ +package interpreter + +import ( + "context" + "sync" +) + +// goroutineContexts stores execution contexts per goroutine so concurrent REPL +// work like prediction and foreground command execution do not overwrite each +// other's cancellation state. +type goroutineContexts struct { + mu sync.RWMutex + contexts map[int64]context.Context +} + +func newGoroutineContexts() *goroutineContexts { + return &goroutineContexts{ + contexts: make(map[int64]context.Context), + } +} + +func (g *goroutineContexts) set(ctx context.Context) { + gid := getGoroutineID() + + g.mu.Lock() + defer g.mu.Unlock() + + g.contexts[gid] = ctx +} + +func (g *goroutineContexts) clear() { + gid := getGoroutineID() + + g.mu.Lock() + defer g.mu.Unlock() + + delete(g.contexts, gid) +} + +func (g *goroutineContexts) get() context.Context { + gid := getGoroutineID() + + g.mu.RLock() + defer g.mu.RUnlock() + + ctx := g.contexts[gid] + if ctx == nil { + return context.Background() + } + return ctx +} diff --git a/internal/script/interpreter/interpreter.go b/internal/script/interpreter/interpreter.go index 07f30182..ea0e0dc1 100644 --- a/internal/script/interpreter/interpreter.go +++ b/internal/script/interpreter/interpreter.go @@ -34,15 +34,12 @@ type Interpreter struct { mcpManager *mcp.Manager providerRegistry *ProviderRegistry callStacks *goroutineCallStacks // Per-goroutine call stacks for error reporting + contexts *goroutineContexts // Per-goroutine execution contexts for cancellation logger *zap.Logger // Optional zap logger for log.* functions stdin io.Reader // Reader for input() function, defaults to os.Stdin runner *interp.Runner // Shared sh runner for env vars, working dir, and exec() runnerMu sync.RWMutex // Protects runner access - // Context for cancellation (e.g., Ctrl+C handling) - ctx context.Context // Current execution context - ctxMu sync.RWMutex // Protects ctx access - // SDK infrastructure eventManager *EventManager sdkConfig *SDKConfig @@ -158,6 +155,7 @@ func New(opts *Options) *Interpreter { stdin: os.Stdin, runner: runner, callStacks: newGoroutineCallStacks(), + contexts: newGoroutineContexts(), eventManager: NewEventManager(), sdkConfig: NewSDKConfig(opts.Logger, atomicLevel), version: version, @@ -216,21 +214,19 @@ func (i *Interpreter) SetStdin(r io.Reader) { // The REPL sets this before executing commands so that long-running operations // (like agent execution or shell commands) can be cancelled. func (i *Interpreter) SetContext(ctx context.Context) { - i.ctxMu.Lock() - defer i.ctxMu.Unlock() - i.ctx = ctx + i.contexts.set(ctx) +} + +// ClearContext removes the execution context for the current goroutine. +func (i *Interpreter) ClearContext() { + i.contexts.clear() } // Context returns the current execution context. // If no context has been set, returns context.Background(). // This should be used by operations that support cancellation. func (i *Interpreter) Context() context.Context { - i.ctxMu.RLock() - defer i.ctxMu.RUnlock() - if i.ctx == nil { - return context.Background() - } - return i.ctx + return i.contexts.get() } // Runner returns the underlying sh runner diff --git a/internal/script/interpreter/interpreter_test.go b/internal/script/interpreter/interpreter_test.go index 61856747..0ad78480 100644 --- a/internal/script/interpreter/interpreter_test.go +++ b/internal/script/interpreter/interpreter_test.go @@ -721,17 +721,17 @@ func TestInterpreter_SetContext_NilReturnsBackground(t *testing.T) { } // Now set nil - interp.SetContext(nil) + interp.ClearContext() // Verify we get a non-cancelled context back ctx := interp.Context() if ctx == nil { - t.Fatal("Context() returned nil after SetContext(nil)") + t.Fatal("Context() returned nil after ClearContext()") } select { case <-ctx.Done(): - t.Error("context after SetContext(nil) should not be cancelled") + t.Error("context after ClearContext() should not be cancelled") default: // Expected - context is not cancelled } @@ -772,6 +772,51 @@ func TestInterpreter_Context_ThreadSafety(t *testing.T) { // If we get here without a race condition, the test passes } +func TestInterpreter_Context_IsolatedPerGoroutine(t *testing.T) { + interp := New(nil) + + mainCtx, mainCancel := context.WithCancel(context.Background()) + defer mainCancel() + otherCtx, otherCancel := context.WithCancel(context.Background()) + defer otherCancel() + + interp.SetContext(mainCtx) + + done := make(chan struct{}) + go func() { + interp.SetContext(otherCtx) + close(done) + }() + <-done + + if interp.Context() != mainCtx { + t.Fatal("main goroutine context should not be overwritten by another goroutine") + } +} + +func TestInterpreter_Context_ClearDoesNotAffectOtherGoroutines(t *testing.T) { + interp := New(nil) + + mainCtx, mainCancel := context.WithCancel(context.Background()) + defer mainCancel() + otherCtx, otherCancel := context.WithCancel(context.Background()) + defer otherCancel() + + interp.SetContext(mainCtx) + + done := make(chan struct{}) + go func() { + interp.SetContext(otherCtx) + interp.ClearContext() + close(done) + }() + <-done + + if interp.Context() != mainCtx { + t.Fatal("clearing context in another goroutine should not clear this goroutine's context") + } +} + // TestConcurrentEmitEvent verifies that concurrent EmitEvent calls from multiple // goroutines do not cause data races or scope corruption. This is the core // correctness property of the env-as-parameter refactor: each EmitEvent call diff --git a/internal/script/interpreter/provider_openai.go b/internal/script/interpreter/provider_openai.go index 6838c691..53cdb626 100644 --- a/internal/script/interpreter/provider_openai.go +++ b/internal/script/interpreter/provider_openai.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "time" ) // OpenAIProvider implements the ModelProvider interface for OpenAI @@ -85,6 +86,12 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, request ChatRequest return nil, fmt.Errorf("OpenAI provider requires a model") } + ctx, cancel, err := withModelTimeout(ctx, request.Model) + if err != nil { + return nil, err + } + defer cancel() + // Get API key from model config apiKeyVal, ok := request.Model.Config["apiKey"] if !ok { @@ -313,6 +320,12 @@ func (p *OpenAIProvider) StreamingChatCompletion(ctx context.Context, request Ch return nil, fmt.Errorf("OpenAI provider requires a model") } + ctx, cancel, err := withModelTimeout(ctx, request.Model) + if err != nil { + return nil, err + } + defer cancel() + // Get API key from model config apiKeyVal, ok := request.Model.Config["apiKey"] if !ok { @@ -620,6 +633,25 @@ func (p *OpenAIProvider) StreamingChatCompletion(ctx context.Context, request Ch return response, nil } +func withModelTimeout(ctx context.Context, model *ModelValue) (context.Context, context.CancelFunc, error) { + timeoutVal, ok := model.Config["timeout"] + if !ok { + return ctx, func() {}, nil + } + + timeoutNum, ok := timeoutVal.(*NumberValue) + if !ok { + return nil, nil, fmt.Errorf("OpenAI provider requires 'timeout' to be a number (milliseconds)") + } + if timeoutNum.Value <= 0 { + return nil, nil, fmt.Errorf("OpenAI provider requires 'timeout' to be positive") + } + + timeout := time.Duration(timeoutNum.Value) * time.Millisecond + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + return timeoutCtx, cancel, nil +} + // objectValueToMap converts an ObjectValue to a map[string]interface{} for JSON serialization. func objectValueToMap(obj *ObjectValue) map[string]interface{} { result := make(map[string]interface{}) diff --git a/internal/script/interpreter/provider_openai_test.go b/internal/script/interpreter/provider_openai_test.go index ef0cd9be..529d54cc 100644 --- a/internal/script/interpreter/provider_openai_test.go +++ b/internal/script/interpreter/provider_openai_test.go @@ -3,10 +3,13 @@ package interpreter import ( "context" "encoding/json" + "errors" + "io" "net/http" "net/http/httptest" "strings" "testing" + "time" ) func TestOpenAIProviderChatCompletion(t *testing.T) { @@ -295,6 +298,63 @@ func TestOpenAIProviderChatCompletion(t *testing.T) { } } +func TestOpenAIProviderChatCompletion_UsesModelTimeout(t *testing.T) { + provider := NewOpenAIProvider() + provider.httpClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + select { + case <-time.After(50 * time.Millisecond): + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{ + "id": "chatcmpl-timeout", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "too slow" + }, + "finish_reason": "stop" + }] + }`)), + Header: make(http.Header), + }, nil + case <-req.Context().Done(): + return nil, req.Context().Err() + } + }), + } + req := ChatRequest{ + Model: &ModelValue{ + Name: "gpt4", + Config: map[string]Value{ + "provider": &StringValue{Value: "openai"}, + "apiKey": &StringValue{Value: "test-key"}, + "model": &StringValue{Value: "gpt-4"}, + "baseURL": &StringValue{Value: "http://example.test/v1"}, + "timeout": &NumberValue{Value: 10}, + }, + }, + Messages: []ChatMessage{ + {Role: "user", Content: "Hello"}, + }, + } + + _, err := provider.ChatCompletion(context.Background(), req) + if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("expected context deadline exceeded error, got %v", err) + } +} + +type roundTripFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + func TestOpenAIProviderToolCallMessageFields(t *testing.T) { // Test that tool_call_id is included in tool result messages // and tool_calls are included in assistant messages From c77835d8c2da5e5216b5559c7d361fad55fdcb97 Mon Sep 17 00:00:00 2001 From: kunchenguid Date: Wed, 1 Apr 2026 10:58:16 -0700 Subject: [PATCH 2/2] Airlock: auto-fixes from Documentation Updates --- docs/tutorial/03-command-prediction.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/tutorial/03-command-prediction.md b/docs/tutorial/03-command-prediction.md index 38d4b31a..794a5d2a 100644 --- a/docs/tutorial/03-command-prediction.md +++ b/docs/tutorial/03-command-prediction.md @@ -51,6 +51,7 @@ model myPredictModel { apiKey: "ollama", baseURL: "http://localhost:11434/v1", model: "gemma3:1b", + timeout: 15000, } # Tell gsh to use this model for predictions @@ -161,11 +162,14 @@ model myPredictModel { apiKey: "ollama", # Magic value for local Ollama baseURL: "http://localhost:11434/v1", model: "gemma3:1b", + timeout: 15000, } gsh.models.lite = myPredictModel ``` +`timeout` is in milliseconds. For prediction models, setting it is recommended so a stuck model backend does not leave REPL suggestions hanging indefinitely. + **Advantages:** - Free and private @@ -190,6 +194,7 @@ model OpenAIPredictor { provider: "openai", apiKey: env.OPENAI_API_KEY, model: "gpt-5-mini", + timeout: 15000, } gsh.models.lite = OpenAIPredictor @@ -244,6 +249,8 @@ Try switching to a smaller model: model: "gemma3:270m", # Smaller than gemma3:1b ``` +If the backend sometimes stalls entirely, set `timeout` on the model declaration so prediction requests fail fast instead of hanging. + ### Predictions Are Wrong Try using a more capable model: