Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/gsh/defaults/models.gsh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ model lite {
apiKey: "ollama",
model: "gemma3:1b",
baseURL: "http://localhost:11434/v1",
timeout: 15000,
}

# Use for agent interactions (more capable)
Expand Down
3 changes: 3 additions & 0 deletions docs/script/17-model-declarations.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ model localLlama {
apiKey: "ollama",
baseURL: "http://localhost:11434/v1",
model: "devstral-small-2",
timeout: 15000,
}
```

Expand All @@ -95,6 +96,7 @@ The key points here:
- `apiKey` is literally the string `"ollama"` (not a real key)
- `baseURL` points to your local Ollama server
- `model` is the name of the model you've pulled into Ollama
- `timeout` is in milliseconds and protects callers from hanging forever if the backend stops responding

---

Expand Down Expand Up @@ -262,6 +264,7 @@ Every model declaration needs:

- **`temperature`** (default: 0.7) - Controls randomness in responses (0.0-1.0)
- **`baseURL`** - For Ollama or self-hosted services, the URL to the API endpoint
- **`timeout`** - Request timeout in milliseconds for model API calls

### Practical Example: Choosing the Right Parameters

Expand Down
9 changes: 6 additions & 3 deletions docs/sdk/02-models.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ model myModel {

### Optional Fields

| Field     | Type     | Description                                         |
| --------- | -------- | --------------------------------------------------- |
| `baseURL` | `string` | API endpoint URL (defaults to OpenAI's API)         |
| `timeout` | `number` | Request timeout in milliseconds for model API calls |

## Provider Examples

Expand All @@ -89,6 +90,7 @@ model gemma {
apiKey: "ollama",
baseURL: "http://localhost:11434/v1",
model: "gemma3:1b",
timeout: 15000,
}

gsh.models.lite = gemma
Expand All @@ -100,6 +102,7 @@ gsh.models.lite = gemma
- Set `apiKey: "ollama"` (required placeholder)
- Use `baseURL: "http://localhost:11434/v1"`
- Model name should match output from `ollama list`
- Set `timeout` to bound requests so hung model backends do not stall REPL features indefinitely

### OpenRouter

Expand Down
7 changes: 7 additions & 0 deletions docs/tutorial/03-command-prediction.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ model myPredictModel {
apiKey: "ollama",
baseURL: "http://localhost:11434/v1",
model: "gemma3:1b",
timeout: 15000,
}

# Tell gsh to use this model for predictions
Expand Down Expand Up @@ -161,11 +162,14 @@ model myPredictModel {
apiKey: "ollama", # Magic value for local Ollama
baseURL: "http://localhost:11434/v1",
model: "gemma3:1b",
timeout: 15000,
}

gsh.models.lite = myPredictModel
```

`timeout` is in milliseconds. For prediction models, setting it is recommended so a stuck model backend does not leave REPL suggestions hanging indefinitely.

**Advantages:**

- Free and private
Expand All @@ -190,6 +194,7 @@ model OpenAIPredictor {
provider: "openai",
apiKey: env.OPENAI_API_KEY,
model: "gpt-5-mini",
timeout: 15000,
}

gsh.models.lite = OpenAIPredictor
Expand Down Expand Up @@ -244,6 +249,8 @@ Try switching to a smaller model:
model: "gemma3:270m", # Smaller than gemma3:1b
```

If the backend sometimes stalls entirely, set `timeout` on the model declaration so prediction requests fail fast instead of hanging.

### Predictions Are Wrong

Try using a more capable model:
Expand Down
51 changes: 51 additions & 0 deletions internal/repl/input/multiline_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,38 @@
package input

import (
"context"
"strings"
"sync"
"testing"
"time"

tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kunchenguid/gsh/internal/script/interpreter"
"github.com/muesli/termenv"
)

// countingPredictionProvider is a test double for the prediction provider:
// it never produces a prediction, but records every input it was asked about
// so tests can assert which predictions were (or were not) started.
type countingPredictionProvider struct {
	mu     sync.Mutex
	inputs []string
}

// Predict records the input and returns an empty prediction with no error.
func (c *countingPredictionProvider) Predict(ctx context.Context, input string, trigger interpreter.PredictTrigger, existingPrediction string) (string, error) {
	c.mu.Lock()
	c.inputs = append(c.inputs, input)
	c.mu.Unlock()
	return "", nil
}

// Inputs returns a copy of every input passed to Predict so far, safe to
// read while other goroutines keep calling Predict.
func (c *countingPredictionProvider) Inputs() []string {
	c.mu.Lock()
	defer c.mu.Unlock()
	snapshot := make([]string, len(c.inputs))
	copy(snapshot, c.inputs)
	return snapshot
}

func TestIsInputComplete(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -259,6 +283,33 @@ func TestRenderMultiLineNoPrediction(t *testing.T) {
}
}

// TestMultiLineInputDoesNotStartHiddenPrediction verifies that switching to
// multiline editing cancels pending prediction work instead of kicking off a
// prediction for the empty string.
func TestMultiLineInputDoesNotStartHiddenPrediction(t *testing.T) {
	spy := &countingPredictionProvider{}
	state := NewPredictionState(PredictionStateConfig{
		DebounceDelay: 10 * time.Millisecond,
		Provider:      spy,
	})

	// Simulate normal single-line editing that marks prediction state dirty.
	state.OnInputChanged(`echo "test`)

	model := New(Config{
		PredictionState: state,
	})
	model.SetValue(`echo "test`)

	// Pressing Enter on an incomplete line switches the input to multiline mode.
	updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter})
	model = updated.(Model)

	// Wait past the debounce window so any wrongly-scheduled prediction fires.
	time.Sleep(30 * time.Millisecond)

	for _, recorded := range spy.Inputs() {
		if recorded == "" {
			t.Fatalf("multiline editing should cancel prediction without starting an empty-input prediction")
		}
	}
}

func TestSplitHighlightedByNewlines(t *testing.T) {
// Force ANSI output so highlighting produces escape codes in tests
oldProfile := lipgloss.DefaultRenderer().ColorProfile()
Expand Down
11 changes: 11 additions & 0 deletions internal/repl/input/prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ func (ps *PredictionState) Reset() {
ps.cancelPendingLocked()
}

// Cancel clears any in-flight prediction work without scheduling a replacement.
// It increments the state ID so stale async results are rejected.
//
// Unlike Reset, Cancel also bumps the state ID and wipes the cached
// prediction, so any goroutine still running against the old state ID will
// have its result discarded when it completes.
func (ps *PredictionState) Cancel() {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	// Stop the pending debounce timer / in-flight request (mu is held, as
	// cancelPendingLocked requires).
	ps.cancelPendingLocked()
	// Invalidate any async result that was computed for the previous state.
	ps.stateID.Add(1)
	// Drop the cached prediction so nothing stale is rendered.
	ps.prediction = ""
	ps.inputForPrediction = ""
}

// cancelPendingLocked cancels any pending prediction request.
// Must be called with mu held.
func (ps *PredictionState) cancelPendingLocked() {
Expand Down
2 changes: 1 addition & 1 deletion internal/repl/input/prediction_history.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (m Model) onTextChanged() (tea.Model, tea.Cmd) {
if strings.Contains(text, "\n") {
m.currentPrediction = ""
if m.prediction != nil {
m.prediction.OnInputChanged("")
m.prediction.Cancel()
}
return m, nil
}
Expand Down
11 changes: 10 additions & 1 deletion internal/repl/predict/event_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package predict
import (
"context"
"sync"
"time"

"github.com/kunchenguid/gsh/internal/script/interpreter"
"go.uber.org/zap"
)

const defaultPredictionRequestTimeout = 30 * time.Second

// EventPredictionProvider calls the repl.predict middleware chain to obtain predictions.
// If no middleware is registered or middleware returns null, no prediction is returned.
type EventPredictionProvider struct {
Expand Down Expand Up @@ -65,6 +68,12 @@ func (p *EventPredictionProvider) emitPredictEvent(ctx context.Context, input st
return "", nil
}

if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, defaultPredictionRequestTimeout)
defer cancel()
}

// For instant predictions, use TryLock to avoid blocking the UI thread.
// If a debounced prediction is running (holding the mutex for expensive
// operations like git diff or LLM calls), we skip rather than block.
Expand All @@ -79,7 +88,7 @@ func (p *EventPredictionProvider) emitPredictEvent(ctx context.Context, input st
// Ensure middleware sees the cancellable context used by PredictionState
p.interp.SetContext(ctx)
defer func() {
p.interp.SetContext(context.Background())
p.interp.ClearContext()
p.mu.Unlock()
}()

Expand Down
2 changes: 1 addition & 1 deletion internal/repl/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ func (r *REPL) processCommand(ctx context.Context, command string) error {
// Set the cancellable context on the interpreter so agent execution can use it
interp := r.executor.Interpreter()
interp.SetContext(cmdCtx)
defer interp.SetContext(context.Background()) // Clear context after command completes
defer interp.ClearContext() // Clear context after command completes

// Record ALL user input in history (including agent commands like "#...")
// This is done before middleware so all user input is captured
Expand Down
51 changes: 51 additions & 0 deletions internal/script/interpreter/context_local.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package interpreter

import (
"context"
"sync"
)

// goroutineContexts stores execution contexts per goroutine so concurrent REPL
// work like prediction and foreground command execution do not overwrite each
// other's cancellation state.
type goroutineContexts struct {
	mu       sync.RWMutex                  // guards contexts
	contexts map[int64]context.Context // keyed by goroutine ID (see getGoroutineID)
}

// newGoroutineContexts returns an empty, ready-to-use per-goroutine context
// registry.
func newGoroutineContexts() *goroutineContexts {
	g := &goroutineContexts{}
	g.contexts = make(map[int64]context.Context)
	return g
}

// set associates ctx with the calling goroutine, replacing any previous entry.
func (g *goroutineContexts) set(ctx context.Context) {
	id := getGoroutineID()

	g.mu.Lock()
	g.contexts[id] = ctx
	g.mu.Unlock()
}

// clear removes the calling goroutine's context entry, if any.
func (g *goroutineContexts) clear() {
	id := getGoroutineID()

	g.mu.Lock()
	delete(g.contexts, id)
	g.mu.Unlock()
}

// get returns the context registered for the calling goroutine, falling back
// to context.Background() when none (or a nil context) was set.
func (g *goroutineContexts) get() context.Context {
	id := getGoroutineID()

	g.mu.RLock()
	ctx := g.contexts[id]
	g.mu.RUnlock()

	if ctx == nil {
		return context.Background()
	}
	return ctx
}
22 changes: 9 additions & 13 deletions internal/script/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,12 @@ type Interpreter struct {
mcpManager *mcp.Manager
providerRegistry *ProviderRegistry
callStacks *goroutineCallStacks // Per-goroutine call stacks for error reporting
contexts *goroutineContexts // Per-goroutine execution contexts for cancellation
logger *zap.Logger // Optional zap logger for log.* functions
stdin io.Reader // Reader for input() function, defaults to os.Stdin
runner *interp.Runner // Shared sh runner for env vars, working dir, and exec()
runnerMu sync.RWMutex // Protects runner access

// Context for cancellation (e.g., Ctrl+C handling)
ctx context.Context // Current execution context
ctxMu sync.RWMutex // Protects ctx access

// SDK infrastructure
eventManager *EventManager
sdkConfig *SDKConfig
Expand Down Expand Up @@ -158,6 +155,7 @@ func New(opts *Options) *Interpreter {
stdin: os.Stdin,
runner: runner,
callStacks: newGoroutineCallStacks(),
contexts: newGoroutineContexts(),
eventManager: NewEventManager(),
sdkConfig: NewSDKConfig(opts.Logger, atomicLevel),
version: version,
Expand Down Expand Up @@ -216,21 +214,19 @@ func (i *Interpreter) SetStdin(r io.Reader) {
// The REPL sets this before executing commands so that long-running operations
// (like agent execution or shell commands) can be cancelled.
func (i *Interpreter) SetContext(ctx context.Context) {
i.ctxMu.Lock()
defer i.ctxMu.Unlock()
i.ctx = ctx
i.contexts.set(ctx)
}

// ClearContext removes the execution context for the current goroutine.
// Callers pair this with SetContext, typically via defer, so one goroutine's
// cancellation state does not leak into unrelated later work.
func (i *Interpreter) ClearContext() {
	i.contexts.clear()
}

// Context returns the current execution context for the calling goroutine.
// If no context has been set, returns context.Background().
// This should be used by operations that support cancellation.
func (i *Interpreter) Context() context.Context {
	// Per-goroutine lookup; contexts.get falls back to context.Background()
	// when no context is registered for this goroutine.
	return i.contexts.get()
}

// Runner returns the underlying sh runner
Expand Down
Loading
Loading