diff --git a/cmd/dj/main.go b/cmd/dj/main.go index 1cc16ac..b95ce2a 100644 --- a/cmd/dj/main.go +++ b/cmd/dj/main.go @@ -7,6 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/spf13/cobra" + "github.com/robinojw/dj/internal/appserver" "github.com/robinojw/dj/internal/config" "github.com/robinojw/dj/internal/state" "github.com/robinojw/dj/internal/tui" @@ -37,12 +38,18 @@ func runApp(cmd *cobra.Command, args []string) error { return fmt.Errorf("load config: %w", err) } - _ = cfg + client := appserver.NewClient(cfg.AppServer.Command, cfg.AppServer.Args...) + defer client.Stop() + + router := appserver.NewEventRouter() + client.Router = router store := state.NewThreadStore() - app := tui.NewAppModel(store) + app := tui.NewAppModel(store, client) program := tea.NewProgram(app, tea.WithAltScreen()) + tui.WireEventBridge(router, program) + _, err = program.Run() return err } diff --git a/internal/appserver/client.go b/internal/appserver/client.go index 9357eb2..770bd40 100644 --- a/internal/appserver/client.go +++ b/internal/appserver/client.go @@ -11,7 +11,7 @@ import ( "sync/atomic" ) -// Client manages a child app-server process and bidirectional JSON-RPC communication. +// Client manages a child codex proto process and bidirectional JSONL communication. type Client struct { command string args []string @@ -21,27 +21,17 @@ type Client struct { stdout io.ReadCloser scanner *bufio.Scanner - mu sync.Mutex // protects writes to stdin - nextID atomic.Int64 - pending sync.Map // id → chan *Message + mu sync.Mutex + nextID atomic.Int64 running atomic.Bool - // OnNotification is called for each server notification (no id). - // Set this before calling Start. - OnNotification func(method string, params json.RawMessage) - - // OnServerRequest is called for server-to-client requests (has id). - // Set this before calling Start. - OnServerRequest func(id int, method string, params json.RawMessage) - - // Router dispatches typed notifications by method name. - // Falls back to OnNotification for unregistered methods. - Router *NotificationRouter + Router *EventRouter } +const scannerBufferSize = 1024 * 1024 + // NewClient creates a client that will spawn the given command. -// Additional arguments can be passed after the command. func NewClient(command string, args ...string) *Client { return &Client{ command: command, @@ -49,7 +39,7 @@ func NewClient(command string, args ...string) *Client { } } -// Start spawns the child process and begins reading stdout. +// Start spawns the child process. func (c *Client) Start(ctx context.Context) error { c.cmd = exec.CommandContext(ctx, c.command, c.args...) @@ -65,7 +55,7 @@ func (c *Client) Start(ctx context.Context) error { } c.scanner = bufio.NewScanner(c.stdout) - c.scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB max line + c.scanner.Buffer(make([]byte, scannerBufferSize), scannerBufferSize) if err := c.cmd.Start(); err != nil { return fmt.Errorf("start process: %w", err) @@ -87,7 +77,6 @@ func (c *Client) Stop() error { } c.running.Store(false) - // Close stdin to signal EOF to the child if c.stdin != nil { c.stdin.Close() } @@ -98,11 +87,11 @@ func (c *Client) Stop() error { return nil } -// Send writes a JSON-RPC request to the child's stdin as a JSONL line. -func (c *Client) Send(req *Request) error { - data, err := json.Marshal(req) +// Send writes a Submission to the child's stdin as a JSONL line. +func (c *Client) Send(sub *Submission) error { + data, err := json.Marshal(sub) if err != nil { - return fmt.Errorf("marshal request: %w", err) + return fmt.Errorf("marshal submission: %w", err) } c.mu.Lock() @@ -113,138 +102,28 @@ func (c *Client) Send(req *Request) error { return err } -// ReadLoop reads JSONL from stdout and dispatches each message to the callback. -// It blocks until the scanner is exhausted (stdout closed) or an error occurs. -func (c *Client) ReadLoop(handler func(Message)) { +// NextID generates a unique string ID for a submission. +func (c *Client) NextID() string { + return fmt.Sprintf("sub-%d", c.nextID.Add(1)) +} + +// ReadLoop reads JSONL events from stdout and dispatches to the router. +// Blocks until stdout is closed or an error occurs. +func (c *Client) ReadLoop() { for c.scanner.Scan() { line := c.scanner.Bytes() if len(line) == 0 { continue } - var msg Message - if err := json.Unmarshal(line, &msg); err != nil { - continue // skip malformed lines - } - - handler(msg) - } -} - -// Call sends a request and blocks until the response with the matching ID arrives. -func (c *Client) Call(ctx context.Context, method string, params json.RawMessage) (*Message, error) { - id := int(c.nextID.Add(1)) - - ch := make(chan *Message, 1) - c.pending.Store(id, ch) - defer c.pending.Delete(id) - - req := &Request{ - JSONRPC: "2.0", - ID: &id, - Method: method, - Params: params, - } - - if err := c.Send(req); err != nil { - return nil, err - } - - select { - case msg := <-ch: - return msg, nil - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// Dispatch routes an incoming message to the appropriate handler: -// - Messages with an ID matching a pending request -> resolve the pending Call -// - Messages with an ID but no pending request -> server-to-client request (OnServerRequest) -// - Messages without an ID -> notification (OnNotification) -func (c *Client) Dispatch(msg Message) { - if msg.ID != nil { - // Check if this resolves a pending call - if ch, ok := c.pending.LoadAndDelete(*msg.ID); ok { - ch.(chan *Message) <- &msg - return - } - - // Server-to-client request - if c.OnServerRequest != nil && msg.Method != "" { - c.OnServerRequest(*msg.ID, msg.Method, msg.Params) + var event Event + if err := json.Unmarshal(line, &event); err != nil { + continue } - return - } - if msg.Method == "" { - return - } - - if c.Router != nil { - c.Router.Handle(msg.Method, msg.Params) - } - - if c.OnNotification != nil { - c.OnNotification(msg.Method, msg.Params) - } -} - -// InitializeParams is sent as the first request to the app-server. -type InitializeParams struct { - ClientInfo ClientInfo `json:"clientInfo"` -} - -// ClientInfo identifies this client to the app-server. -type ClientInfo struct { - Name string `json:"name"` - Title string `json:"title"` - Version string `json:"version"` -} - -// ServerCapabilities is the result of the initialize request. -type ServerCapabilities struct { - ServerInfo struct { - Name string `json:"name"` - Version string `json:"version"` - } `json:"serverInfo"` -} - -// Initialize performs the required handshake with the app-server. -// Sends initialize request, receives capabilities, then sends initialized notification. -func (c *Client) Initialize(ctx context.Context) (*ServerCapabilities, error) { - params, _ := json.Marshal(InitializeParams{ - ClientInfo: ClientInfo{ - Name: "dj", - Title: "DJ — Codex TUI Visualizer", - Version: "0.1.0", - }, - }) - - resp, err := c.Call(ctx, "initialize", params) - if err != nil { - return nil, fmt.Errorf("initialize request: %w", err) - } - - if resp.Error != nil { - return nil, fmt.Errorf("initialize error: %s", resp.Error.Message) - } - - var caps ServerCapabilities - if resp.Result != nil { - if err := json.Unmarshal(resp.Result, &caps); err != nil { - return nil, fmt.Errorf("unmarshal capabilities: %w", err) + if c.Router != nil { + c.Router.HandleEvent(event) } } - - // Send the initialized notification (no id, no response expected) - notif := &Request{ - JSONRPC: "2.0", - Method: "initialized", - } - if err := c.Send(notif); err != nil { - return nil, fmt.Errorf("send initialized: %w", err) - } - - return &caps, nil + c.running.Store(false) } diff --git a/internal/appserver/client_test.go b/internal/appserver/client_test.go index 404f63a..082ea33 100644 --- a/internal/appserver/client_test.go +++ b/internal/appserver/client_test.go @@ -10,8 +10,6 @@ import ( ) func TestClientStartStop(t *testing.T) { - // Use 'cat' as a mock app-server: it reads stdin and echoes to stdout. - // This verifies process lifecycle without a real codex binary. client := NewClient("cat") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -35,7 +33,6 @@ func TestClientStartStop(t *testing.T) { } func TestClientSendAndRead(t *testing.T) { - // 'cat' echoes back what we write — simulates a response client := NewClient("cat") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -46,137 +43,85 @@ func TestClientSendAndRead(t *testing.T) { } defer client.Stop() - // Start the read loop - msgs := make(chan Message, 10) - go client.ReadLoop(func(msg Message) { - msgs <- msg + events := make(chan Event, 10) + client.Router = NewEventRouter() + client.Router.OnSessionConfigured(func(event SessionConfigured) { + events <- Event{Msg: json.RawMessage(`{"type":"session_configured"}`)} }) - // Send a JSON-RPC request — cat will echo it back - req := &Request{ - JSONRPC: "2.0", - ID: intPtr(1), - Method: "test/echo", - Params: json.RawMessage(`{"hello":"world"}`), + go client.ReadLoop() + + sub := &Submission{ + ID: "test-1", + Op: json.RawMessage(`{"type":"session_configured","session_id":"s1","model":"test"}`), } - if err := client.Send(req); err != nil { - t.Fatalf("Send failed: %v", err) + + wrappedEvent := Event{ + ID: "", + Msg: json.RawMessage(`{"type":"session_configured","session_id":"s1","model":"test"}`), } + data, _ := json.Marshal(wrappedEvent) + client.mu.Lock() + data = append(data, '\n') + client.stdin.Write(data) + client.mu.Unlock() + _ = sub select { - case msg := <-msgs: - if msg.Method != "test/echo" { - t.Errorf("expected method test/echo, got %s", msg.Method) - } + case <-events: case <-time.After(3 * time.Second): - t.Fatal("timeout waiting for message") + t.Fatal("timeout waiting for event") } } -func TestClientCall(t *testing.T) { - // Use 'cat' — it echoes the request back as-is. - // The Call method will see the echoed message has a matching ID and treat it as a response. +func TestClientNextID(t *testing.T) { client := NewClient("cat") - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + id1 := client.NextID() + id2 := client.NextID() - if err := client.Start(ctx); err != nil { - t.Fatalf("Start failed: %v", err) + if id1 == id2 { + t.Errorf("expected unique IDs, got %s and %s", id1, id2) } - defer client.Stop() - - // Start dispatching - go client.ReadLoop(client.Dispatch) - - // Call — cat echoes back the request, which has an id, so it resolves as a response - resp, err := client.Call(ctx, "test/method", json.RawMessage(`{"key":"val"}`)) - if err != nil { - t.Fatalf("Call failed: %v", err) + if id1 != "sub-1" { + t.Errorf("expected sub-1, got %s", id1) } - - // The echo will have params (not result), but the call resolved because ID matched - if resp == nil { - t.Fatal("expected non-nil response") + if id2 != "sub-2" { + t.Errorf("expected sub-2, got %s", id2) } } -func TestInitializeHandshake(t *testing.T) { - // Set up bidirectional pipes to simulate app-server - // client writes -> serverRead, serverWrite -> clientRead +func TestClientReadLoopParsesEvents(t *testing.T) { clientRead, serverWrite := io.Pipe() - serverRead, clientWrite := io.Pipe() - - // Mock server: reads initialize request, writes back capabilities response, - // then reads the initialized notification - go func() { - scanner := bufio.NewScanner(serverRead) - scanner.Buffer(make([]byte, 1024*1024), 1024*1024) - - // Read initialize request - if !scanner.Scan() { - t.Error("mock server: failed to read initialize request") - return - } - var req Message - if err := json.Unmarshal(scanner.Bytes(), &req); err != nil { - t.Errorf("mock server: unmarshal request: %v", err) - return - } - if req.Method != "initialize" { - t.Errorf("mock server: expected method initialize, got %s", req.Method) - return - } - // Write capabilities response - resp := Message{ - JSONRPC: "2.0", - ID: req.ID, - Result: json.RawMessage(`{"serverInfo":{"name":"codex-app-server","version":"0.1.0"}}`), - } - data, _ := json.Marshal(resp) - data = append(data, '\n') - serverWrite.Write(data) - - // Read initialized notification - if !scanner.Scan() { - t.Error("mock server: failed to read initialized notification") - return - } - var notif Message - if err := json.Unmarshal(scanner.Bytes(), ¬if); err != nil { - t.Errorf("mock server: unmarshal notification: %v", err) - return - } - if notif.Method != "initialized" { - t.Errorf("mock server: expected method initialized, got %s", notif.Method) - } - }() - - // Set up client with our pipes instead of a real process client := &Client{} - client.stdin = clientWrite client.stdout = clientRead client.scanner = bufio.NewScanner(clientRead) - client.scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + client.scanner.Buffer(make([]byte, scannerBufferSize), scannerBufferSize) client.running.Store(true) - go client.ReadLoop(client.Dispatch) + received := make(chan SessionConfigured, 1) + client.Router = NewEventRouter() + client.Router.OnSessionConfigured(func(event SessionConfigured) { + received <- event + }) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + go client.ReadLoop() - caps, err := client.Initialize(ctx) - if err != nil { - t.Fatalf("Initialize failed: %v", err) - } - if caps == nil { - t.Fatal("expected non-nil capabilities") - } - if caps.ServerInfo.Name != "codex-app-server" { - t.Errorf("expected server name codex-app-server, got %s", caps.ServerInfo.Name) - } - if caps.ServerInfo.Version != "0.1.0" { - t.Errorf("expected server version 0.1.0, got %s", caps.ServerInfo.Version) + eventJSON := `{"id":"","msg":{"type":"session_configured","session_id":"sess-123","model":"gpt-4o"}}` + "\n" + serverWrite.Write([]byte(eventJSON)) + + select { + case event := <-received: + if event.SessionID != "sess-123" { + t.Errorf("expected sess-123, got %s", event.SessionID) + } + if event.Model != "gpt-4o" { + t.Errorf("expected gpt-4o, got %s", event.Model) + } + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for session_configured event") } + + serverWrite.Close() } diff --git a/internal/appserver/client_thread.go b/internal/appserver/client_thread.go index f4a6111..634cd65 100644 --- a/internal/appserver/client_thread.go +++ b/internal/appserver/client_thread.go @@ -1,58 +1,53 @@ package appserver import ( - "context" "encoding/json" "fmt" ) -func (c *Client) CreateThread(ctx context.Context, instructions string) (*ThreadCreateResult, error) { - params, _ := json.Marshal(ThreadCreateParams{ - Instructions: instructions, - }) +// SendUserTurn sends a user_turn submission with the given text content. +func (c *Client) SendUserTurn(text string, cwd string, model string) error { + op := UserTurnOp{ + Type: OpUserTurn, + Items: []UserInput{NewTextInput(text)}, + Cwd: cwd, + ApprovalPolicy: "on-request", + SandboxPolicy: SandboxPolicyReadOnly(), + Model: model, + } - resp, err := c.Call(ctx, MethodThreadCreate, params) + opData, err := json.Marshal(op) if err != nil { - return nil, fmt.Errorf("thread/create: %w", err) - } - if resp.Error != nil { - return nil, fmt.Errorf("thread/create: %w", resp.Error) + return fmt.Errorf("marshal user_turn op: %w", err) } - var result ThreadCreateResult - if err := json.Unmarshal(resp.Result, &result); err != nil { - return nil, fmt.Errorf("unmarshal thread/create result: %w", err) + sub := &Submission{ + ID: c.NextID(), + Op: opData, } - return &result, nil + return c.Send(sub) } -func (c *Client) ListThreads(ctx context.Context) (*ThreadListResult, error) { - resp, err := c.Call(ctx, MethodThreadList, json.RawMessage(`{}`)) - if err != nil { - return nil, fmt.Errorf("thread/list: %w", err) - } - if resp.Error != nil { - return nil, fmt.Errorf("thread/list: %w", resp.Error) - } +// SendInterrupt sends an interrupt submission to stop the current turn. +func (c *Client) SendInterrupt() error { + op := InterruptOp{Type: OpInterrupt} + opData, _ := json.Marshal(op) - var result ThreadListResult - if err := json.Unmarshal(resp.Result, &result); err != nil { - return nil, fmt.Errorf("unmarshal thread/list result: %w", err) + sub := &Submission{ + ID: c.NextID(), + Op: opData, } - return &result, nil + return c.Send(sub) } -func (c *Client) DeleteThread(ctx context.Context, threadID string) error { - params, _ := json.Marshal(ThreadDeleteParams{ - ThreadID: threadID, - }) +// SendShutdown sends a shutdown submission. +func (c *Client) SendShutdown() error { + op := ShutdownOp{Type: OpShutdown} + opData, _ := json.Marshal(op) - resp, err := c.Call(ctx, MethodThreadDelete, params) - if err != nil { - return fmt.Errorf("thread/delete: %w", err) - } - if resp.Error != nil { - return fmt.Errorf("thread/delete: %w", resp.Error) + sub := &Submission{ + ID: c.NextID(), + Op: opData, } - return nil + return c.Send(sub) } diff --git a/internal/appserver/client_thread_test.go b/internal/appserver/client_thread_test.go index 8ac0dd2..c88d017 100644 --- a/internal/appserver/client_thread_test.go +++ b/internal/appserver/client_thread_test.go @@ -2,114 +2,97 @@ package appserver import ( "bufio" - "context" "encoding/json" "io" "testing" "time" ) -func TestClientCreateThread(t *testing.T) { - clientRead, serverWrite := io.Pipe() +func TestSendUserTurn(t *testing.T) { + clientRead, _ := io.Pipe() serverRead, clientWrite := io.Pipe() - go mockThreadCreateServer(t, serverRead, serverWrite) - client := &Client{} client.stdin = clientWrite client.stdout = clientRead client.scanner = bufio.NewScanner(clientRead) - client.scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + client.scanner.Buffer(make([]byte, scannerBufferSize), scannerBufferSize) client.running.Store(true) - go client.ReadLoop(client.Dispatch) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - result, err := client.CreateThread(ctx, "Build a web app") + received := make(chan map[string]any, 1) + go func() { + scanner := bufio.NewScanner(serverRead) + scanner.Buffer(make([]byte, scannerBufferSize), scannerBufferSize) + if scanner.Scan() { + var parsed map[string]any + json.Unmarshal(scanner.Bytes(), &parsed) + received <- parsed + } + }() + + err := client.SendUserTurn("Hello world", "/tmp", "o4-mini") if err != nil { - t.Fatalf("CreateThread failed: %v", err) - } - if result.ThreadID != "t-new-123" { - t.Errorf("expected t-new-123, got %s", result.ThreadID) - } -} - -func mockThreadCreateServer(t *testing.T, reader *io.PipeReader, writer *io.PipeWriter) { - t.Helper() - scanner := bufio.NewScanner(reader) - scanner.Buffer(make([]byte, 1024*1024), 1024*1024) - - if !scanner.Scan() { - t.Error("mock: failed to read request") - return - } - var req Message - if err := json.Unmarshal(scanner.Bytes(), &req); err != nil { - t.Errorf("mock: unmarshal: %v", err) - return + t.Fatalf("SendUserTurn failed: %v", err) } - resp := Message{ - JSONRPC: "2.0", - ID: req.ID, - Result: json.RawMessage(`{"threadId":"t-new-123"}`), + select { + case msg := <-received: + if msg["id"] == nil || msg["id"] == "" { + t.Error("expected non-empty id") + } + opRaw, _ := json.Marshal(msg["op"]) + var op map[string]any + json.Unmarshal(opRaw, &op) + if op["type"] != OpUserTurn { + t.Errorf("expected user_turn, got %v", op["type"]) + } + if op["model"] != "o4-mini" { + t.Errorf("expected model o4-mini, got %v", op["model"]) + } + if op["cwd"] != "/tmp" { + t.Errorf("expected cwd /tmp, got %v", op["cwd"]) + } + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for submission") } - data, _ := json.Marshal(resp) - data = append(data, '\n') - writer.Write(data) } -func TestClientListThreads(t *testing.T) { - clientRead, serverWrite := io.Pipe() +func TestSendInterrupt(t *testing.T) { + clientRead, _ := io.Pipe() serverRead, clientWrite := io.Pipe() - go mockThreadListServer(t, serverRead, serverWrite) - client := &Client{} client.stdin = clientWrite client.stdout = clientRead client.scanner = bufio.NewScanner(clientRead) - client.scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + client.scanner.Buffer(make([]byte, scannerBufferSize), scannerBufferSize) client.running.Store(true) - go client.ReadLoop(client.Dispatch) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - result, err := client.ListThreads(ctx) + received := make(chan map[string]any, 1) + go func() { + scanner := bufio.NewScanner(serverRead) + scanner.Buffer(make([]byte, scannerBufferSize), scannerBufferSize) + if scanner.Scan() { + var parsed map[string]any + json.Unmarshal(scanner.Bytes(), &parsed) + received <- parsed + } + }() + + err := client.SendInterrupt() if err != nil { - t.Fatalf("ListThreads failed: %v", err) - } - if len(result.Threads) != 2 { - t.Fatalf("expected 2 threads, got %d", len(result.Threads)) - } -} - -func mockThreadListServer(t *testing.T, reader *io.PipeReader, writer *io.PipeWriter) { - t.Helper() - scanner := bufio.NewScanner(reader) - scanner.Buffer(make([]byte, 1024*1024), 1024*1024) - - if !scanner.Scan() { - t.Error("mock: failed to read request") - return - } - var req Message - if err := json.Unmarshal(scanner.Bytes(), &req); err != nil { - t.Errorf("mock: unmarshal: %v", err) - return + t.Fatalf("SendInterrupt failed: %v", err) } - threadList := `{"threads":[{"id":"t-1","status":"active","title":"A"},{"id":"t-2","status":"idle","title":"B"}]}` - resp := Message{ - JSONRPC: "2.0", - ID: req.ID, - Result: json.RawMessage(threadList), + select { + case msg := <-received: + opRaw, _ := json.Marshal(msg["op"]) + var op map[string]any + json.Unmarshal(opRaw, &op) + if op["type"] != OpInterrupt { + t.Errorf("expected interrupt, got %v", op["type"]) + } + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for submission") } - data, _ := json.Marshal(resp) - data = append(data, '\n') - writer.Write(data) } diff --git a/internal/appserver/dispatch_test.go b/internal/appserver/dispatch_test.go index 2490b64..e8c64c7 100644 --- a/internal/appserver/dispatch_test.go +++ b/internal/appserver/dispatch_test.go @@ -6,48 +6,52 @@ import ( "testing" ) -func TestDispatchRoutesNotificationToRouter(t *testing.T) { - client := &Client{} - router := NewNotificationRouter() +func TestEventRouterDispatchesSessionConfigured(t *testing.T) { + router := NewEventRouter() var called atomic.Bool - router.OnThreadStatusChanged(func(params ThreadStatusChanged) { + router.OnSessionConfigured(func(event SessionConfigured) { called.Store(true) - if params.ThreadID != "t-1" { - t.Errorf("expected t-1, got %s", params.ThreadID) + if event.SessionID != "sess-1" { + t.Errorf("expected sess-1, got %s", event.SessionID) } }) - client.Router = router - - msg := Message{ - JSONRPC: "2.0", - Method: NotifyThreadStatusChanged, - Params: json.RawMessage(`{"threadId":"t-1","status":"active","title":"Test"}`), + event := Event{ + ID: "", + Msg: json.RawMessage(`{"type":"session_configured","session_id":"sess-1","model":"gpt-4o"}`), } - client.Dispatch(msg) + router.HandleEvent(event) if !called.Load() { - t.Error("router handler was not called") + t.Error("handler was not called") } } -func TestDispatchFallsBackToOnNotification(t *testing.T) { - client := &Client{} - - var called atomic.Bool - client.OnNotification = func(method string, params json.RawMessage) { - called.Store(true) +func TestEventRouterIgnoresUnregisteredType(t *testing.T) { + router := NewEventRouter() + event := Event{ + ID: "", + Msg: json.RawMessage(`{"type":"unknown_event"}`), } + router.HandleEvent(event) +} - msg := Message{ - JSONRPC: "2.0", - Method: "custom/notification", - Params: json.RawMessage(`{}`), +func TestEventRouterDispatchesError(t *testing.T) { + router := NewEventRouter() + + var receivedMsg string + router.OnError(func(event ServerError) { + receivedMsg = event.Message + }) + + event := Event{ + ID: "", + Msg: json.RawMessage(`{"type":"error","message":"test error"}`), } - client.Dispatch(msg) + router.HandleEvent(event) - if !called.Load() { - t.Error("OnNotification was not called") + if receivedMsg != "test error" { + t.Errorf("expected 'test error', got %s", receivedMsg) } } diff --git a/internal/appserver/integration_test.go b/internal/appserver/integration_test.go index 1ae343d..556eebf 100644 --- a/internal/appserver/integration_test.go +++ b/internal/appserver/integration_test.go @@ -8,23 +8,29 @@ import ( "time" ) -func TestIntegrationAppServerConnect(t *testing.T) { - client := NewClient("codex", "app-server", "--listen", "stdio://") +func TestIntegrationProtoConnect(t *testing.T) { + client := NewClient("codex", "proto") ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() if err := client.Start(ctx); err != nil { - t.Fatalf("Failed to start app-server: %v", err) + t.Fatalf("Failed to start codex proto: %v", err) } defer client.Stop() - go client.ReadLoop(client.Dispatch) + received := make(chan SessionConfigured, 1) + client.Router = NewEventRouter() + client.Router.OnSessionConfigured(func(event SessionConfigured) { + received <- event + }) - caps, err := client.Initialize(ctx) - if err != nil { - t.Fatalf("Initialize failed: %v", err) - } + go client.ReadLoop() - t.Logf("Connected to: %s %s", caps.ServerInfo.Name, caps.ServerInfo.Version) + select { + case event := <-received: + t.Logf("Connected: session=%s model=%s", event.SessionID, event.Model) + case <-ctx.Done(): + t.Fatal("timeout waiting for session_configured") + } } diff --git a/internal/appserver/methods.go b/internal/appserver/methods.go index ff6dafb..9ffe97d 100644 --- a/internal/appserver/methods.go +++ b/internal/appserver/methods.go @@ -1,17 +1,28 @@ package appserver const ( - MethodThreadCreate = "thread/create" - MethodThreadList = "thread/list" - MethodThreadDelete = "thread/delete" - MethodThreadSendMessage = "thread/sendMessage" - MethodCommandExec = "command/exec" + OpUserTurn = "user_turn" + OpInterrupt = "interrupt" + OpExecApproval = "exec_approval" + OpPatchApproval = "patch_approval" + OpShutdown = "shutdown" + OpGetHistoryEntry = "get_history_entry_request" ) const ( - NotifyThreadStatusChanged = "thread/status/changed" - NotifyThreadMessageCreated = "thread/message/created" - NotifyThreadMessageDelta = "thread/message/delta" - NotifyCommandOutput = "command/output" - NotifyCommandFinished = "command/finished" + EventSessionConfigured = "session_configured" + EventTaskStarted = "task_started" + EventTaskComplete = "task_complete" + EventTurnAborted = "turn_aborted" + EventAgentMessage = "agent_message" + EventAgentMessageDelta = "agent_message_delta" + EventExecCommandBegin = "exec_command_begin" + EventExecCommandOutputDelta = "exec_command_output_delta" + EventExecCommandEnd = "exec_command_end" + EventExecApprovalRequest = "exec_approval_request" + EventPatchApplyBegin = "patch_apply_begin" + EventPatchApplyEnd = "patch_apply_end" + EventError = "error" + EventWarning = "warning" + EventShutdownComplete = "shutdown_complete" ) diff --git a/internal/appserver/methods_test.go b/internal/appserver/methods_test.go index db05e13..2c4cf93 100644 --- a/internal/appserver/methods_test.go +++ b/internal/appserver/methods_test.go @@ -2,22 +2,42 @@ package appserver import "testing" -func TestMethodConstants(t *testing.T) { +func TestOpConstants(t *testing.T) { tests := []struct { name string constant string expected string }{ - {"ThreadCreate", MethodThreadCreate, "thread/create"}, - {"ThreadList", MethodThreadList, "thread/list"}, - {"ThreadDelete", MethodThreadDelete, "thread/delete"}, - {"ThreadSendMessage", MethodThreadSendMessage, "thread/sendMessage"}, - {"CommandExec", MethodCommandExec, "command/exec"}, - {"NotifyThreadStatus", NotifyThreadStatusChanged, "thread/status/changed"}, - {"NotifyThreadMessage", NotifyThreadMessageCreated, "thread/message/created"}, - {"NotifyMessageDelta", NotifyThreadMessageDelta, "thread/message/delta"}, - {"NotifyCommandOutput", NotifyCommandOutput, "command/output"}, - {"NotifyCommandFinished", NotifyCommandFinished, "command/finished"}, + {"UserTurn", OpUserTurn, "user_turn"}, + {"Interrupt", OpInterrupt, "interrupt"}, + {"ExecApproval", OpExecApproval, "exec_approval"}, + {"Shutdown", OpShutdown, "shutdown"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.constant != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, tt.constant) + } + }) + } +} + +func TestEventConstants(t *testing.T) { + tests := []struct { + name string + constant string + expected string + }{ + {"SessionConfigured", EventSessionConfigured, "session_configured"}, + {"TaskStarted", EventTaskStarted, "task_started"}, + {"TaskComplete", EventTaskComplete, "task_complete"}, + {"AgentMessage", EventAgentMessage, "agent_message"}, + {"AgentMessageDelta", EventAgentMessageDelta, "agent_message_delta"}, + {"ExecCommandBegin", EventExecCommandBegin, "exec_command_begin"}, + {"ExecCommandOutputDelta", EventExecCommandOutputDelta, "exec_command_output_delta"}, + {"ExecCommandEnd", EventExecCommandEnd, "exec_command_end"}, + {"ExecApprovalRequest", EventExecApprovalRequest, "exec_approval_request"}, + {"Error", EventError, "error"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/appserver/protocol.go b/internal/appserver/protocol.go index 6c697be..201f29e 100644 --- a/internal/appserver/protocol.go +++ b/internal/appserver/protocol.go @@ -2,35 +2,24 @@ package appserver import "encoding/json" -// Message is the generic JSON-RPC 2.0 envelope. -// It covers requests (id + method), responses (id + result/error), -// and notifications (method, no id). -type Message struct { - JSONRPC string `json:"jsonrpc"` - ID *int `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Params json.RawMessage `json:"params,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *RPCError `json:"error,omitempty"` +// Submission is a client-to-server message in the Codex proto protocol. +type Submission struct { + ID string `json:"id"` + Op json.RawMessage `json:"op"` } -// Request is an outbound JSON-RPC request. -type Request struct { - JSONRPC string `json:"jsonrpc"` - ID *int `json:"id,omitempty"` - Method string `json:"method"` - Params json.RawMessage `json:"params,omitempty"` +// Event is a server-to-client message in the Codex proto protocol. +type Event struct { + ID string `json:"id"` + Msg json.RawMessage `json:"msg"` } -// Response is an inbound JSON-RPC response. -type Response struct { - JSONRPC string `json:"jsonrpc"` - ID *int `json:"id,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *RPCError `json:"error,omitempty"` +// EventHeader extracts just the type discriminator from an event message. +type EventHeader struct { + Type string `json:"type"` } -// RPCError is the JSON-RPC error object. +// RPCError is an error returned by the server. type RPCError struct { Code int `json:"code"` Message string `json:"message"` diff --git a/internal/appserver/protocol_test.go b/internal/appserver/protocol_test.go index d288013..dd9c158 100644 --- a/internal/appserver/protocol_test.go +++ b/internal/appserver/protocol_test.go @@ -5,68 +5,65 @@ import ( "testing" ) -func TestRequestMarshal(t *testing.T) { - req := &Request{ - JSONRPC: "2.0", - ID: intPtr(1), - Method: "thread/list", - Params: json.RawMessage(`{}`), +func TestSubmissionMarshal(t *testing.T) { + op := UserTurnOp{ + Type: OpUserTurn, + Items: []UserInput{NewTextInput("hello")}, + Cwd: "/tmp", + ApprovalPolicy: "never", + SandboxPolicy: SandboxPolicyReadOnly(), + Model: "o4-mini", } - data, err := json.Marshal(req) + opData, _ := json.Marshal(op) + + sub := &Submission{ + ID: "sub-1", + Op: opData, + } + data, err := json.Marshal(sub) if err != nil { t.Fatal(err) } var parsed map[string]any json.Unmarshal(data, &parsed) - if parsed["jsonrpc"] != "2.0" { - t.Errorf("expected jsonrpc 2.0, got %v", parsed["jsonrpc"]) - } - if parsed["method"] != "thread/list" { - t.Errorf("expected method thread/list, got %v", parsed["method"]) + if parsed["id"] != "sub-1" { + t.Errorf("expected id sub-1, got %v", parsed["id"]) } } -func TestResponseUnmarshal(t *testing.T) { - raw := `{"jsonrpc":"2.0","id":1,"result":{"threads":[]}}` - var resp Response - if err := json.Unmarshal([]byte(raw), &resp); err != nil { +func TestEventUnmarshal(t *testing.T) { + raw := `{"id":"","msg":{"type":"session_configured","session_id":"sess-1","model":"gpt-4o"}}` + var event Event + if err := json.Unmarshal([]byte(raw), &event); err != nil { t.Fatal(err) } - if resp.ID == nil || *resp.ID != 1 { - t.Errorf("expected id 1, got %v", resp.ID) + if event.ID != "" { + t.Errorf("expected empty id, got %s", event.ID) } - if resp.Error != nil { - t.Error("expected no error") - } -} -func TestNotificationUnmarshal(t *testing.T) { - raw := `{"jsonrpc":"2.0","method":"thread/status/changed","params":{"threadId":"t1","status":"active"}}` - var msg Message - if err := json.Unmarshal([]byte(raw), &msg); err != nil { + var header EventHeader + if err := json.Unmarshal(event.Msg, &header); err != nil { t.Fatal(err) } - if msg.Method != "thread/status/changed" { - t.Errorf("expected thread/status/changed, got %s", msg.Method) - } - // Notification: no ID - if msg.ID != nil { - t.Error("notification should have no id") + if header.Type != EventSessionConfigured { + t.Errorf("expected session_configured, got %s", header.Type) } } -func TestErrorResponseUnmarshal(t *testing.T) { - raw := `{"jsonrpc":"2.0","id":2,"error":{"code":-32600,"message":"Invalid Request"}}` - var resp Response - if err := json.Unmarshal([]byte(raw), &resp); err != nil { +func TestEventHeaderExtraction(t *testing.T) { + raw := `{"type":"agent_message_delta","delta":"hello"}` + var header EventHeader + if err := json.Unmarshal([]byte(raw), &header); err != nil { t.Fatal(err) } - if resp.Error == nil { - t.Fatal("expected error") - } - if resp.Error.Code != -32600 { - t.Errorf("expected code -32600, got %d", resp.Error.Code) + if header.Type != EventAgentMessageDelta { + t.Errorf("expected agent_message_delta, got %s", header.Type) } } -func intPtr(i int) *int { return &i } +func TestRPCErrorMessage(t *testing.T) { + err := &RPCError{Code: -1, Message: "something broke"} + if err.Error() != "something broke" { + t.Errorf("expected 'something broke', got %s", err.Error()) + } +} diff --git a/internal/appserver/router.go b/internal/appserver/router.go index f9398fa..d9b996e 100644 --- a/internal/appserver/router.go +++ b/internal/appserver/router.go @@ -2,70 +2,128 @@ package appserver import "encoding/json" -type NotificationRouter struct { - handlers map[string]func(json.RawMessage) +// EventRouter dispatches incoming events by their type field. +type EventRouter struct { + handlers map[string]func(string, json.RawMessage) } -func NewNotificationRouter() *NotificationRouter { - return &NotificationRouter{ - handlers: make(map[string]func(json.RawMessage)), +// NewEventRouter creates a new event router. +func NewEventRouter() *EventRouter { + return &EventRouter{ + handlers: make(map[string]func(string, json.RawMessage)), } } -func (router *NotificationRouter) Handle(method string, params json.RawMessage) { - handler, exists := router.handlers[method] +// HandleEvent extracts the event type and dispatches to the registered handler. +func (router *EventRouter) HandleEvent(event Event) { + var header EventHeader + if err := json.Unmarshal(event.Msg, &header); err != nil { + return + } + + handler, exists := router.handlers[header.Type] if !exists { return } - handler(params) + handler(event.ID, event.Msg) +} + +func (router *EventRouter) OnSessionConfigured(fn func(SessionConfigured)) { + router.handlers[EventSessionConfigured] = func(_ string, raw json.RawMessage) { + var event SessionConfigured + if err := json.Unmarshal(raw, &event); err != nil { + return + } + fn(event) + } +} + +func (router *EventRouter) OnTaskStarted(fn func(TaskStarted)) { + router.handlers[EventTaskStarted] = func(_ string, raw json.RawMessage) { + var event TaskStarted + if err := json.Unmarshal(raw, &event); err != nil { + return + } + fn(event) + } +} + +func (router *EventRouter) OnTaskComplete(fn func(TaskComplete)) { + router.handlers[EventTaskComplete] = func(_ string, raw json.RawMessage) { + var event TaskComplete + if err := json.Unmarshal(raw, &event); err != nil { + return + } + fn(event) + } +} + +func (router *EventRouter) OnAgentMessage(fn func(AgentMessage)) { + router.handlers[EventAgentMessage] = func(_ string, raw json.RawMessage) { + var event AgentMessage + if err := json.Unmarshal(raw, &event); err != nil { + return + } + fn(event) + } +} + +func (router *EventRouter) OnAgentMessageDelta(fn func(AgentMessageDelta)) { + router.handlers[EventAgentMessageDelta] = func(_ string, raw json.RawMessage) { + var event AgentMessageDelta + if err := json.Unmarshal(raw, &event); err != nil { + return + } + fn(event) + } } -func (router *NotificationRouter) OnThreadStatusChanged(fn func(ThreadStatusChanged)) { - router.handlers[NotifyThreadStatusChanged] = func(raw json.RawMessage) { - var params ThreadStatusChanged - if err := json.Unmarshal(raw, ¶ms); err != nil { +func (router *EventRouter) OnExecCommandBegin(fn func(ExecCommandBegin)) { + router.handlers[EventExecCommandBegin] = func(_ string, raw json.RawMessage) { + var event ExecCommandBegin + if err := json.Unmarshal(raw, &event); err != nil { return } - fn(params) + fn(event) } } -func (router *NotificationRouter) OnThreadMessageCreated(fn func(ThreadMessageCreated)) { - router.handlers[NotifyThreadMessageCreated] = func(raw json.RawMessage) { - var params ThreadMessageCreated - if err := json.Unmarshal(raw, ¶ms); err != nil { +func (router *EventRouter) OnExecCommandOutputDelta(fn func(ExecCommandOutputDelta)) { + router.handlers[EventExecCommandOutputDelta] = func(_ string, raw json.RawMessage) { + var event ExecCommandOutputDelta + if err := json.Unmarshal(raw, &event); err != nil { return } - fn(params) + fn(event) } } -func (router *NotificationRouter) OnThreadMessageDelta(fn func(ThreadMessageDelta)) { - router.handlers[NotifyThreadMessageDelta] = func(raw json.RawMessage) { - var params ThreadMessageDelta - if err := json.Unmarshal(raw, ¶ms); err != nil { +func (router *EventRouter) OnExecCommandEnd(fn func(ExecCommandEnd)) { + router.handlers[EventExecCommandEnd] = func(_ string, raw json.RawMessage) { + var event ExecCommandEnd + if err := json.Unmarshal(raw, &event); err != nil { return } - fn(params) + fn(event) } } -func (router *NotificationRouter) OnCommandOutput(fn func(CommandOutput)) { - router.handlers[NotifyCommandOutput] = func(raw json.RawMessage) { - var params CommandOutput - if err := json.Unmarshal(raw, ¶ms); err != nil { +func (router *EventRouter) OnExecApprovalRequest(fn func(ExecApprovalRequest)) { + router.handlers[EventExecApprovalRequest] = func(_ string, raw json.RawMessage) { + var event ExecApprovalRequest + if err := json.Unmarshal(raw, &event); err != nil { return } - fn(params) + fn(event) } } -func (router *NotificationRouter) OnCommandFinished(fn func(CommandFinished)) { - router.handlers[NotifyCommandFinished] = func(raw json.RawMessage) { - var params CommandFinished - if err := json.Unmarshal(raw, ¶ms); err != nil { +func (router *EventRouter) OnError(fn func(ServerError)) { + router.handlers[EventError] = func(_ string, raw json.RawMessage) { + var event ServerError + if err := json.Unmarshal(raw, &event); err != nil { return } - fn(params) + fn(event) } } diff --git a/internal/appserver/router_test.go b/internal/appserver/router_test.go index d719547..9ddab41 100644 --- a/internal/appserver/router_test.go +++ b/internal/appserver/router_test.go @@ -6,58 +6,97 @@ import ( "testing" ) -func TestRouterDispatchesNotification(t *testing.T) { - router := NewNotificationRouter() +func TestRouterDispatchesAgentMessageDelta(t *testing.T) { + router := NewEventRouter() - var called atomic.Bool - router.OnThreadStatusChanged(func(params ThreadStatusChanged) { - called.Store(true) - if params.ThreadID != "t-1" { - t.Errorf("expected t-1, got %s", params.ThreadID) - } + var receivedDelta string + router.OnAgentMessageDelta(func(event AgentMessageDelta) { + receivedDelta = event.Delta }) - raw := json.RawMessage(`{"threadId":"t-1","status":"active","title":"Test"}`) - router.Handle(NotifyThreadStatusChanged, raw) - - if !called.Load() { - t.Error("handler was not called") + event := Event{ + ID: "sub-1", + Msg: json.RawMessage(`{"type":"agent_message_delta","delta":"hello"}`), } -} + router.HandleEvent(event) -func TestRouterIgnoresUnregisteredMethod(t *testing.T) { - router := NewNotificationRouter() - router.Handle("unknown/method", json.RawMessage(`{}`)) + if receivedDelta != "hello" { + t.Errorf("expected hello, got %s", receivedDelta) + } } -func TestRouterDispatchesMessageDelta(t *testing.T) { - router := NewNotificationRouter() +func TestRouterDispatchesExecCommandBegin(t *testing.T) { + router := NewEventRouter() - var receivedDelta string - router.OnThreadMessageDelta(func(params ThreadMessageDelta) { - receivedDelta = params.Delta + var receivedCommand string + router.OnExecCommandBegin(func(event ExecCommandBegin) { + receivedCommand = event.Command }) - raw := json.RawMessage(`{"threadId":"t-1","messageId":"m-1","delta":"hello"}`) - router.Handle(NotifyThreadMessageDelta, raw) + event := Event{ + ID: "sub-1", + Msg: json.RawMessage(`{"type":"exec_command_begin","call_id":"cmd-1","command":"ls"}`), + } + router.HandleEvent(event) - if receivedDelta != "hello" { - t.Errorf("expected hello, got %s", receivedDelta) + if receivedCommand != "ls" { + t.Errorf("expected ls, got %s", receivedCommand) } } -func TestRouterDispatchesCommandOutput(t *testing.T) { - router := NewNotificationRouter() +func TestRouterDispatchesExecCommandOutputDelta(t *testing.T) { + router := NewEventRouter() var receivedData string - router.OnCommandOutput(func(params CommandOutput) { - receivedData = params.Data + router.OnExecCommandOutputDelta(func(event ExecCommandOutputDelta) { + receivedData = event.Delta }) - raw := json.RawMessage(`{"threadId":"t-1","execId":"e-1","data":"output line\n"}`) - router.Handle(NotifyCommandOutput, raw) + event := Event{ + ID: "sub-1", + Msg: json.RawMessage(`{"type":"exec_command_output_delta","call_id":"cmd-1","delta":"output\n"}`), + } + router.HandleEvent(event) - if receivedData != "output line\n" { + if receivedData != "output\n" { t.Errorf("expected output, got %s", receivedData) } } + +func TestRouterDispatchesExecApprovalRequest(t *testing.T) { + router := NewEventRouter() + + var called atomic.Bool + router.OnExecApprovalRequest(func(event ExecApprovalRequest) { + called.Store(true) + }) + + event := Event{ + ID: "", + Msg: json.RawMessage(`{"type":"exec_approval_request","call_id":"cmd-1","command":"rm file"}`), + } + router.HandleEvent(event) + + if !called.Load() { + t.Error("handler was not called") + } +} + +func TestRouterDispatchesTaskComplete(t *testing.T) { + router := NewEventRouter() + + var called atomic.Bool + router.OnTaskComplete(func(event TaskComplete) { + called.Store(true) + }) + + event := Event{ + ID: "sub-1", + Msg: json.RawMessage(`{"type":"task_complete"}`), + } + router.HandleEvent(event) + + if !called.Load() { + t.Error("handler was not called") + } +} diff --git a/internal/appserver/types_command.go b/internal/appserver/types_command.go index 8e34704..89aed6f 100644 --- a/internal/appserver/types_command.go +++ b/internal/appserver/types_command.go @@ -1,20 +1,24 @@ package appserver +// CommandExecParams is used when the TUI wants to execute a command. type CommandExecParams struct { ThreadID string `json:"threadId"` Command string `json:"command"` TTY bool `json:"tty"` } +// CommandExecResult is the result of a command execution. type CommandExecResult struct { ExecID string `json:"execId"` } +// ConfirmExecParams is used for exec approval responses. type ConfirmExecParams struct { ThreadID string `json:"threadId"` Command string `json:"command"` } +// ConfirmExecResult is the approval response. type ConfirmExecResult struct { Approved bool `json:"approved"` } diff --git a/internal/appserver/types_notify.go b/internal/appserver/types_notify.go index 3c103e1..755cab0 100644 --- a/internal/appserver/types_notify.go +++ b/internal/appserver/types_notify.go @@ -1,32 +1,68 @@ package appserver -type ThreadStatusChanged struct { - ThreadID string `json:"threadId"` - Status string `json:"status"` - Title string `json:"title"` +// SessionConfigured is the initial event sent by the server on startup. +type SessionConfigured struct { + Type string `json:"type"` + SessionID string `json:"session_id"` + Model string `json:"model"` + ReasoningEffort string `json:"reasoning_effort"` + HistoryLogID int `json:"history_log_id"` + HistoryEntryCount int `json:"history_entry_count"` + RolloutPath string `json:"rollout_path"` } -type ThreadMessageCreated struct { - ThreadID string `json:"threadId"` - MessageID string `json:"messageId"` - Role string `json:"role"` - Content string `json:"content"` +// TaskStarted signals the beginning of an agent turn. +type TaskStarted struct { + Type string `json:"type"` } -type ThreadMessageDelta struct { - ThreadID string `json:"threadId"` - MessageID string `json:"messageId"` - Delta string `json:"delta"` +// TaskComplete signals the end of an agent turn. +type TaskComplete struct { + Type string `json:"type"` } -type CommandOutput struct { - ThreadID string `json:"threadId"` - ExecID string `json:"execId"` - Data string `json:"data"` +// AgentMessage is a complete agent message. +type AgentMessage struct { + Type string `json:"type"` + Content string `json:"content"` } -type CommandFinished struct { - ThreadID string `json:"threadId"` - ExecID string `json:"execId"` - ExitCode int `json:"exitCode"` +// AgentMessageDelta is a streaming text delta from the agent. +type AgentMessageDelta struct { + Type string `json:"type"` + Delta string `json:"delta"` +} + +// ExecCommandBegin signals the start of a command execution. +type ExecCommandBegin struct { + Type string `json:"type"` + ExecID string `json:"call_id"` + Command string `json:"command"` +} + +// ExecCommandOutputDelta is a chunk of command output. +type ExecCommandOutputDelta struct { + Type string `json:"type"` + ExecID string `json:"call_id"` + Delta string `json:"delta"` +} + +// ExecCommandEnd signals the end of a command execution. +type ExecCommandEnd struct { + Type string `json:"type"` + ExecID string `json:"call_id"` + ExitCode int `json:"exit_code"` +} + +// ExecApprovalRequest asks the client to approve a command. +type ExecApprovalRequest struct { + Type string `json:"type"` + ExecID string `json:"call_id"` + Command string `json:"command"` +} + +// ServerError is an error event from the server. +type ServerError struct { + Type string `json:"type"` + Message string `json:"message"` } diff --git a/internal/appserver/types_notify_test.go b/internal/appserver/types_notify_test.go index 7afce24..44ada14 100644 --- a/internal/appserver/types_notify_test.go +++ b/internal/appserver/types_notify_test.go @@ -5,63 +5,85 @@ import ( "testing" ) -func TestThreadStatusChangedUnmarshal(t *testing.T) { - raw := `{"threadId":"t-1","status":"completed","title":"Done"}` - var params ThreadStatusChanged - if err := json.Unmarshal([]byte(raw), ¶ms); err != nil { +func TestSessionConfiguredUnmarshal(t *testing.T) { + raw := `{"type":"session_configured","session_id":"sess-1","model":"gpt-4o","reasoning_effort":"high","history_log_id":123,"history_entry_count":5,"rollout_path":"/tmp/rollout.jsonl"}` + var event SessionConfigured + if err := json.Unmarshal([]byte(raw), &event); err != nil { t.Fatal(err) } - if params.ThreadID != "t-1" { - t.Errorf("expected t-1, got %s", params.ThreadID) + if event.SessionID != "sess-1" { + t.Errorf("expected sess-1, got %s", event.SessionID) } - if params.Status != ThreadStatusCompleted { - t.Errorf("expected completed, got %s", params.Status) + if event.Model != "gpt-4o" { + t.Errorf("expected gpt-4o, got %s", event.Model) } } -func TestThreadMessageCreatedUnmarshal(t *testing.T) { - raw := `{"threadId":"t-1","messageId":"m-1","role":"assistant","content":"Hello"}` - var params ThreadMessageCreated - if err := json.Unmarshal([]byte(raw), ¶ms); err != nil { +func TestAgentMessageDeltaUnmarshal(t *testing.T) { + raw := `{"type":"agent_message_delta","delta":"hello world"}` + var event AgentMessageDelta + if err := json.Unmarshal([]byte(raw), &event); err != nil { t.Fatal(err) } - if params.Role != "assistant" { - t.Errorf("expected assistant, got %s", params.Role) + if event.Delta != "hello world" { + t.Errorf("expected 'hello world', got %s", event.Delta) } - if params.Content != "Hello" { - t.Errorf("expected Hello, got %s", params.Content) +} + +func TestExecCommandBeginUnmarshal(t *testing.T) { + raw := `{"type":"exec_command_begin","call_id":"cmd-1","command":"ls -la"}` + var event ExecCommandBegin + if err := json.Unmarshal([]byte(raw), &event); err != nil { + t.Fatal(err) + } + if event.ExecID != "cmd-1" { + t.Errorf("expected cmd-1, got %s", event.ExecID) + } + if event.Command != "ls -la" { + t.Errorf("expected ls -la, got %s", event.Command) + } +} + +func TestExecCommandOutputDeltaUnmarshal(t *testing.T) { + raw := `{"type":"exec_command_output_delta","call_id":"cmd-1","delta":"output line\n"}` + var event ExecCommandOutputDelta + if err := json.Unmarshal([]byte(raw), &event); err != nil { + t.Fatal(err) + } + if event.Delta != "output line\n" { + t.Errorf("expected output, got %s", event.Delta) } } -func TestThreadMessageDeltaUnmarshal(t *testing.T) { - raw := `{"threadId":"t-1","messageId":"m-1","delta":"more text"}` - var params ThreadMessageDelta - if err := json.Unmarshal([]byte(raw), ¶ms); err != nil { +func TestExecCommandEndUnmarshal(t *testing.T) { + raw := `{"type":"exec_command_end","call_id":"cmd-1","exit_code":0}` + var event ExecCommandEnd + if err := json.Unmarshal([]byte(raw), &event); err != nil { t.Fatal(err) } - if params.Delta != "more text" { - t.Errorf("expected 'more text', got %s", params.Delta) + if event.ExitCode != 0 { + t.Errorf("expected exit code 0, got %d", event.ExitCode) } } -func TestCommandOutputUnmarshal(t *testing.T) { - raw := `{"threadId":"t-1","execId":"e-1","data":"line of output\n"}` - var params CommandOutput - if err := json.Unmarshal([]byte(raw), ¶ms); err != nil { +func TestServerErrorUnmarshal(t *testing.T) { + raw := `{"type":"error","message":"something went wrong"}` + var event ServerError + if err := json.Unmarshal([]byte(raw), &event); err != nil { t.Fatal(err) } - if params.ExecID != "e-1" { - t.Errorf("expected e-1, got %s", params.ExecID) + if event.Message != "something went wrong" { + t.Errorf("expected error message, got %s", event.Message) } } -func TestCommandFinishedUnmarshal(t *testing.T) { - raw := `{"threadId":"t-1","execId":"e-1","exitCode":0}` - var params CommandFinished - if err := json.Unmarshal([]byte(raw), ¶ms); err != nil { +func TestExecApprovalRequestUnmarshal(t *testing.T) { + raw := `{"type":"exec_approval_request","call_id":"cmd-1","command":"rm -rf /tmp/test"}` + var event ExecApprovalRequest + if err := json.Unmarshal([]byte(raw), &event); err != nil { t.Fatal(err) } - if params.ExitCode != 0 { - t.Errorf("expected exit code 0, got %d", params.ExitCode) + if event.Command != "rm -rf /tmp/test" { + t.Errorf("expected command, got %s", event.Command) } } diff --git a/internal/appserver/types_thread.go b/internal/appserver/types_thread.go index 1c5ef3c..63a6b27 100644 --- a/internal/appserver/types_thread.go +++ b/internal/appserver/types_thread.go @@ -1,39 +1,61 @@ package appserver -const ( - ThreadStatusActive = "active" - ThreadStatusIdle = "idle" - ThreadStatusCompleted = "completed" - ThreadStatusError = "error" -) +import "encoding/json" + +// UserTurnOp is the "user_turn" submission operation. +type UserTurnOp struct { + Type string `json:"type"` + Items []UserInput `json:"items"` + Cwd string `json:"cwd"` + ApprovalPolicy string `json:"approval_policy"` + SandboxPolicy json.RawMessage `json:"sandbox_policy"` + Model string `json:"model"` +} -type ThreadCreateParams struct { - Instructions string `json:"instructions"` +// UserInput is a content item in a user turn. +type UserInput struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + TextElements []any `json:"text_elements,omitempty"` } -type ThreadCreateResult struct { - ThreadID string `json:"threadId"` +// NewTextInput creates a text user input item. +func NewTextInput(text string) UserInput { + return UserInput{ + Type: "text", + Text: text, + TextElements: []any{}, + } } -type ThreadDeleteParams struct { - ThreadID string `json:"threadId"` +// SandboxPolicyReadOnly creates a read-only sandbox policy. +func SandboxPolicyReadOnly() json.RawMessage { + return json.RawMessage(`{"type":"read-only","network_access":false}`) } -type ThreadListResult struct { - Threads []ThreadSummary `json:"threads"` +// SandboxPolicyWorkspaceWrite creates a workspace-write sandbox policy. +func SandboxPolicyWorkspaceWrite(roots []string) json.RawMessage { + policy := map[string]any{ + "type": "workspace-write", + "writable_roots": roots, + "network_access": false, + } + data, _ := json.Marshal(policy) + return data } -type ThreadSummary struct { - ID string `json:"id"` - Status string `json:"status"` - Title string `json:"title"` +// InterruptOp is the "interrupt" submission operation. +type InterruptOp struct { + Type string `json:"type"` } -type ThreadSendMessageParams struct { - ThreadID string `json:"threadId"` - Content string `json:"content"` +// ShutdownOp is the "shutdown" submission operation. +type ShutdownOp struct { + Type string `json:"type"` } -type ThreadSendMessageResult struct { - MessageID string `json:"messageId"` +// ExecApprovalOp is the "exec_approval" submission operation. +type ExecApprovalOp struct { + Type string `json:"type"` + Approved bool `json:"approved"` } diff --git a/internal/appserver/types_thread_test.go b/internal/appserver/types_thread_test.go index e6c298c..6067610 100644 --- a/internal/appserver/types_thread_test.go +++ b/internal/appserver/types_thread_test.go @@ -5,57 +5,70 @@ import ( "testing" ) -func TestThreadCreateParamsMarshal(t *testing.T) { - params := ThreadCreateParams{ - Instructions: "Build a web server", +func TestUserTurnOpMarshal(t *testing.T) { + op := UserTurnOp{ + Type: OpUserTurn, + Items: []UserInput{NewTextInput("hello")}, + Cwd: "/home/user", + ApprovalPolicy: "on-request", + SandboxPolicy: SandboxPolicyReadOnly(), + Model: "gpt-4o", } - data, err := json.Marshal(params) + data, err := json.Marshal(op) if err != nil { t.Fatal(err) } var parsed map[string]any json.Unmarshal(data, &parsed) - if parsed["instructions"] != "Build a web server" { - t.Errorf("expected instructions, got %v", parsed["instructions"]) + if parsed["type"] != "user_turn" { + t.Errorf("expected user_turn, got %v", parsed["type"]) } -} - -func TestThreadCreateResultUnmarshal(t *testing.T) { - raw := `{"threadId":"t-abc123"}` - var result ThreadCreateResult - if err := json.Unmarshal([]byte(raw), &result); err != nil { - t.Fatal(err) - } - if result.ThreadID != "t-abc123" { - t.Errorf("expected t-abc123, got %s", result.ThreadID) + if parsed["model"] != "gpt-4o" { + t.Errorf("expected gpt-4o, got %v", parsed["model"]) } } -func TestThreadListResultUnmarshal(t *testing.T) { - raw := `{"threads":[{"id":"t-1","status":"active","title":"Test"}]}` - var result ThreadListResult - if err := json.Unmarshal([]byte(raw), &result); err != nil { - t.Fatal(err) +func TestNewTextInput(t *testing.T) { + input := NewTextInput("hello world") + if input.Type != "text" { + t.Errorf("expected text, got %s", input.Type) } - if len(result.Threads) != 1 { - t.Fatalf("expected 1 thread, got %d", len(result.Threads)) + if input.Text != "hello world" { + t.Errorf("expected hello world, got %s", input.Text) } - if result.Threads[0].ID != "t-1" { - t.Errorf("expected id t-1, got %s", result.Threads[0].ID) + if input.TextElements == nil { + t.Error("expected non-nil text_elements") } - if result.Threads[0].Status != "active" { - t.Errorf("expected status active, got %s", result.Threads[0].Status) +} + +func TestSandboxPolicyReadOnly(t *testing.T) { + policy := SandboxPolicyReadOnly() + var parsed map[string]any + json.Unmarshal(policy, &parsed) + if parsed["type"] != "read-only" { + t.Errorf("expected read-only, got %v", parsed["type"]) } } -func TestThreadStatusValues(t *testing.T) { - if ThreadStatusActive != "active" { - t.Errorf("expected active, got %s", ThreadStatusActive) +func TestSandboxPolicyWorkspaceWrite(t *testing.T) { + policy := SandboxPolicyWorkspaceWrite([]string{"/home/user/project"}) + var parsed map[string]any + json.Unmarshal(policy, &parsed) + if parsed["type"] != "workspace-write" { + t.Errorf("expected workspace-write, got %v", parsed["type"]) } - if ThreadStatusCompleted != "completed" { - t.Errorf("expected completed, got %s", ThreadStatusCompleted) + roots := parsed["writable_roots"].([]any) + if len(roots) != 1 { + t.Fatalf("expected 1 root, got %d", len(roots)) } - if ThreadStatusError != "error" { - t.Errorf("expected error, got %s", ThreadStatusError) +} + +func TestInterruptOpMarshal(t *testing.T) { + op := InterruptOp{Type: OpInterrupt} + data, _ := json.Marshal(op) + var parsed map[string]any + json.Unmarshal(data, &parsed) + if parsed["type"] != "interrupt" { + t.Errorf("expected interrupt, got %v", parsed["type"]) } } diff --git a/internal/config/config.go b/internal/config/config.go index 5b8d71d..6d12d85 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -28,7 +28,7 @@ func Load(path string) (*Config, error) { viperInstance.SetConfigType("toml") viperInstance.SetDefault("appserver.command", DefaultAppServerCommand) - viperInstance.SetDefault("appserver.args", []string{"app-server", "--listen", "stdio://"}) + viperInstance.SetDefault("appserver.args", []string{"proto"}) viperInstance.SetDefault("ui.theme", DefaultTheme) if path != "" { diff --git a/internal/tui/actions_test.go b/internal/tui/actions_test.go index 37a6c93..a4e2255 100644 --- a/internal/tui/actions_test.go +++ b/internal/tui/actions_test.go @@ -28,7 +28,7 @@ func TestMenuEnterDispatchesFork(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test") - app := NewAppModel(store) + app := NewAppModel(store, nil) ctrlB := tea.KeyMsg{Type: tea.KeyCtrlB} updated, _ := app.Update(ctrlB) @@ -64,7 +64,7 @@ func TestMenuEnterDispatchesDelete(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test") - app := NewAppModel(store) + app := NewAppModel(store, nil) ctrlB := tea.KeyMsg{Type: tea.KeyCtrlB} updated, _ := app.Update(ctrlB) diff --git a/internal/tui/app.go b/internal/tui/app.go index bd2a41f..0c8ecdd 100644 --- a/internal/tui/app.go +++ b/internal/tui/app.go @@ -1,8 +1,11 @@ package tui import ( + "context" + tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/robinojw/dj/internal/appserver" "github.com/robinojw/dj/internal/state" ) @@ -18,13 +21,17 @@ var titleStyle = lipgloss.NewStyle(). MarginBottom(1) type AppModel struct { - store *state.ThreadStore - canvas CanvasModel - tree TreeModel - session *SessionModel - prefix *PrefixHandler - menu MenuModel - help HelpModel + store *state.ThreadStore + client *appserver.Client + canvas CanvasModel + tree TreeModel + session *SessionModel + prefix *PrefixHandler + menu MenuModel + help HelpModel + statusBar *StatusBar + + connected bool menuVisible bool helpVisible bool focus int @@ -32,13 +39,20 @@ type AppModel struct { height int } -func NewAppModel(store *state.ThreadStore) AppModel { +func NewAppModel(store *state.ThreadStore, client *appserver.Client) AppModel { + bar := NewStatusBar() + if client != nil { + bar.SetConnecting() + } + return AppModel{ - store: store, - canvas: NewCanvasModel(store), - tree: NewTreeModel(store), - prefix: NewPrefixHandler(), - help: NewHelpModel(), + store: store, + client: client, + canvas: NewCanvasModel(store), + tree: NewTreeModel(store), + prefix: NewPrefixHandler(), + help: NewHelpModel(), + statusBar: bar, } } @@ -51,7 +65,19 @@ func (app AppModel) HelpVisible() bool { } func (app AppModel) Init() tea.Cmd { - return nil + if app.client == nil { + return nil + } + + return func() tea.Msg { + if err := app.client.Start(context.Background()); err != nil { + return AppServerErrorMsg{Err: err} + } + + go app.client.ReadLoop() + + return nil + } } func (app AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -61,6 +87,7 @@ func (app AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.WindowSizeMsg: app.width = msg.Width app.height = msg.Height + app.statusBar.SetWidth(msg.Width) if app.session != nil { app.session.SetSize(msg.Width, msg.Height) } @@ -74,6 +101,20 @@ func (app AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return app.handleThreadDelta(msg) case CommandOutputMsg: return app.handleCommandOutput(msg) + case ThreadCreatedMsg: + app.store.Add(msg.ThreadID, msg.Title) + app.statusBar.SetThreadCount(len(app.store.All())) + return app, nil + case AppServerConnectedMsg: + app.connected = true + app.statusBar.SetConnected(true) + app.store.Add(msg.SessionID, msg.Model) + app.statusBar.SetThreadCount(len(app.store.All())) + app.statusBar.SetSelectedThread(msg.Model) + return app, nil + case AppServerErrorMsg: + app.statusBar.SetError(msg.Error()) + return app, nil } return app, nil } @@ -119,10 +160,18 @@ func (app AppModel) handleRune(msg tea.KeyMsg) (tea.Model, tea.Cmd) { app.toggleFocus() case "?": app.helpVisible = !app.helpVisible + case "n": + if !app.connected { + app.statusBar.SetError("waiting for codex — is codex CLI installed?") + return app, nil + } + app.statusBar.SetError("single session mode — session auto-created on connect") + return app, nil } return app, nil } + func (app AppModel) handleHelpKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { isToggle := msg.Type == tea.KeyRunes && msg.String() == "?" isEsc := msg.Type == tea.KeyEsc @@ -171,123 +220,40 @@ func (app *AppModel) handleCanvasArrow(msg tea.KeyMsg) { } } -func (app AppModel) handleSessionKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - switch msg.Type { - case tea.KeyCtrlC: - return app, tea.Quit - case tea.KeyEsc: - app.closeSession() - return app, nil - case tea.KeyUp, tea.KeyDown, tea.KeyPgUp, tea.KeyPgDown: - app.scrollSession(msg) - return app, nil +func (app AppModel) canvasView() string { + threads := app.store.All() + if len(threads) > 0 { + return app.canvas.View() } - return app, nil -} - -func (app AppModel) openSession() (tea.Model, tea.Cmd) { - threadID := app.canvas.SelectedThreadID() - if threadID == "" { - return app, nil + if !app.connected { + return "Waiting for app-server connection..." } - - thread, exists := app.store.Get(threadID) - if !exists { - return app, nil - } - - session := NewSessionModel(thread) - session.SetSize(app.width, app.height) - app.session = &session - app.focus = FocusSession - return app, nil -} - -func (app *AppModel) closeSession() { - app.session = nil - app.focus = FocusCanvas -} - -func (app *AppModel) scrollSession(msg tea.KeyMsg) { - if app.session == nil { - return - } - - switch msg.Type { - case tea.KeyUp: - app.session.viewport.ScrollUp(1) - case tea.KeyDown: - app.session.viewport.ScrollDown(1) - case tea.KeyPgUp: - app.session.viewport.HalfPageUp() - case tea.KeyPgDown: - app.session.viewport.HalfPageDown() - } -} - -func (app *AppModel) refreshSession() { - if app.session == nil { - return - } - app.session.Refresh() -} - -func (app AppModel) handleThreadMessage(msg ThreadMessageMsg) (tea.Model, tea.Cmd) { - thread, exists := app.store.Get(msg.ThreadID) - if !exists { - return app, nil - } - thread.AppendMessage(state.ChatMessage{ - ID: msg.MessageID, - Role: msg.Role, - Content: msg.Content, - }) - app.refreshSession() - return app, nil -} - -func (app AppModel) handleThreadDelta(msg ThreadDeltaMsg) (tea.Model, tea.Cmd) { - thread, exists := app.store.Get(msg.ThreadID) - if !exists { - return app, nil - } - thread.AppendDelta(msg.MessageID, msg.Delta) - app.refreshSession() - return app, nil -} - -func (app AppModel) handleCommandOutput(msg CommandOutputMsg) (tea.Model, tea.Cmd) { - thread, exists := app.store.Get(msg.ThreadID) - if !exists { - return app, nil - } - thread.AppendOutput(msg.ExecID, msg.Data) - app.refreshSession() - return app, nil + return "No active threads. Press 'n' to create one." } func (app AppModel) View() string { title := titleStyle.Render("DJ — Codex TUI Visualizer") + status := app.statusBar.View() if app.helpVisible { - return title + "\n" + app.help.View() + "\n" + return title + "\n" + app.help.View() + "\n" + status } if app.menuVisible { - return title + "\n" + app.menu.View() + "\n" + return title + "\n" + app.menu.View() + "\n" + status } if app.focus == FocusSession && app.session != nil { - return title + "\n" + app.session.View() + "\n" + return title + "\n" + app.session.View() + "\n" + status } - canvas := app.canvas.View() + canvas := app.canvasView() if app.focus == FocusTree { treeView := app.tree.View() body := lipgloss.JoinHorizontal(lipgloss.Top, treeView+" ", canvas) - return title + "\n" + body + "\n" + return title + "\n" + body + "\n" + status } - return title + "\n" + canvas + "\n" + return title + "\n" + canvas + "\n" + status } diff --git a/internal/tui/app_session.go b/internal/tui/app_session.go new file mode 100644 index 0000000..1a8bcff --- /dev/null +++ b/internal/tui/app_session.go @@ -0,0 +1,101 @@ +package tui + +import ( + tea "github.com/charmbracelet/bubbletea" + "github.com/robinojw/dj/internal/state" +) + +func (app AppModel) handleSessionKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.Type { + case tea.KeyCtrlC: + return app, tea.Quit + case tea.KeyEsc: + app.closeSession() + return app, nil + case tea.KeyUp, tea.KeyDown, tea.KeyPgUp, tea.KeyPgDown: + app.scrollSession(msg) + return app, nil + } + return app, nil +} + +func (app AppModel) openSession() (tea.Model, tea.Cmd) { + threadID := app.canvas.SelectedThreadID() + if threadID == "" { + return app, nil + } + + thread, exists := app.store.Get(threadID) + if !exists { + return app, nil + } + + session := NewSessionModel(thread) + session.SetSize(app.width, app.height) + app.session = &session + app.focus = FocusSession + return app, nil +} + +func (app *AppModel) closeSession() { + app.session = nil + app.focus = FocusCanvas +} + +func (app *AppModel) scrollSession(msg tea.KeyMsg) { + if app.session == nil { + return + } + + switch msg.Type { + case tea.KeyUp: + app.session.viewport.ScrollUp(1) + case tea.KeyDown: + app.session.viewport.ScrollDown(1) + case tea.KeyPgUp: + app.session.viewport.HalfPageUp() + case tea.KeyPgDown: + app.session.viewport.HalfPageDown() + } +} + +func (app *AppModel) refreshSession() { + if app.session == nil { + return + } + app.session.Refresh() +} + +func (app AppModel) handleThreadMessage(msg ThreadMessageMsg) (tea.Model, tea.Cmd) { + thread, exists := app.store.Get(msg.ThreadID) + if !exists { + return app, nil + } + thread.AppendMessage(state.ChatMessage{ + ID: msg.MessageID, + Role: msg.Role, + Content: msg.Content, + }) + app.refreshSession() + return app, nil +} + +func (app AppModel) handleThreadDelta(msg ThreadDeltaMsg) (tea.Model, tea.Cmd) { + thread, exists := app.store.Get(msg.ThreadID) + if !exists { + return app, nil + } + thread.AppendDelta(msg.MessageID, msg.Delta) + app.refreshSession() + return app, nil +} + +func (app AppModel) handleCommandOutput(msg CommandOutputMsg) (tea.Model, tea.Cmd) { + thread, exists := app.store.Get(msg.ThreadID) + if !exists { + return app, nil + } + thread.AppendOutput(msg.ExecID, msg.Data) + app.refreshSession() + return app, nil +} diff --git a/internal/tui/app_test.go b/internal/tui/app_test.go index bac2a3d..6675141 100644 --- a/internal/tui/app_test.go +++ b/internal/tui/app_test.go @@ -12,7 +12,7 @@ func TestAppHandlesArrowKeys(t *testing.T) { store.Add("t-1", "First") store.Add("t-2", "Second") - app := NewAppModel(store) + app := NewAppModel(store, nil) rightKey := tea.KeyMsg{Type: tea.KeyRight} updated, _ := app.Update(rightKey) @@ -27,7 +27,7 @@ func TestAppHandlesThreadStatusMsg(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Initial") - app := NewAppModel(store) + app := NewAppModel(store, nil) msg := ThreadStatusMsg{ ThreadID: "t-1", @@ -49,7 +49,7 @@ func TestAppToggleFocus(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test") - app := NewAppModel(store) + app := NewAppModel(store, nil) if app.Focus() != FocusCanvas { t.Errorf("expected canvas focus, got %d", app.Focus()) @@ -69,7 +69,7 @@ func TestAppTreeNavigationWhenFocused(t *testing.T) { store.Add("t-1", "First") store.Add("t-2", "Second") - app := NewAppModel(store) + app := NewAppModel(store, nil) toggleKey := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'t'}} updated, _ := app.Update(toggleKey) @@ -86,7 +86,7 @@ func TestAppTreeNavigationWhenFocused(t *testing.T) { func TestAppHandlesQuit(t *testing.T) { store := state.NewThreadStore() - app := NewAppModel(store) + app := NewAppModel(store, nil) quitKey := tea.KeyMsg{Type: tea.KeyCtrlC} _, cmd := app.Update(quitKey) @@ -100,7 +100,7 @@ func TestAppEnterOpensSession(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test Task") - app := NewAppModel(store) + app := NewAppModel(store, nil) enterKey := tea.KeyMsg{Type: tea.KeyEnter} updated, _ := app.Update(enterKey) @@ -115,7 +115,7 @@ func TestAppEscClosesSession(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test Task") - app := NewAppModel(store) + app := NewAppModel(store, nil) enterKey := tea.KeyMsg{Type: tea.KeyEnter} updated, _ := app.Update(enterKey) @@ -132,7 +132,7 @@ func TestAppEscClosesSession(t *testing.T) { func TestAppEnterWithNoThreadsDoesNothing(t *testing.T) { store := state.NewThreadStore() - app := NewAppModel(store) + app := NewAppModel(store, nil) enterKey := tea.KeyMsg{Type: tea.KeyEnter} updated, _ := app.Update(enterKey) @@ -147,7 +147,7 @@ func TestAppCtrlBMOpensMenu(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test") - app := NewAppModel(store) + app := NewAppModel(store, nil) ctrlB := tea.KeyMsg{Type: tea.KeyCtrlB} updated, _ := app.Update(ctrlB) @@ -166,7 +166,7 @@ func TestAppMenuEscCloses(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test") - app := NewAppModel(store) + app := NewAppModel(store, nil) ctrlB := tea.KeyMsg{Type: tea.KeyCtrlB} updated, _ := app.Update(ctrlB) @@ -189,7 +189,7 @@ func TestAppCtrlBEscCancelsPrefix(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test") - app := NewAppModel(store) + app := NewAppModel(store, nil) ctrlB := tea.KeyMsg{Type: tea.KeyCtrlB} updated, _ := app.Update(ctrlB) @@ -208,7 +208,7 @@ func TestAppMenuNavigation(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test") - app := NewAppModel(store) + app := NewAppModel(store, nil) ctrlB := tea.KeyMsg{Type: tea.KeyCtrlB} updated, _ := app.Update(ctrlB) @@ -229,7 +229,7 @@ func TestAppMenuNavigation(t *testing.T) { func TestAppHelpToggle(t *testing.T) { store := state.NewThreadStore() - app := NewAppModel(store) + app := NewAppModel(store, nil) helpKey := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'?'}} updated, _ := app.Update(helpKey) @@ -249,7 +249,7 @@ func TestAppHelpToggle(t *testing.T) { func TestAppHelpEscCloses(t *testing.T) { store := state.NewThreadStore() - app := NewAppModel(store) + app := NewAppModel(store, nil) helpKey := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'?'}} updated, _ := app.Update(helpKey) @@ -268,7 +268,7 @@ func TestAppSessionRefreshesOnMessage(t *testing.T) { store := state.NewThreadStore() store.Add("t-1", "Test") - app := NewAppModel(store) + app := NewAppModel(store, nil) enterKey := tea.KeyMsg{Type: tea.KeyEnter} updated, _ := app.Update(enterKey) @@ -287,3 +287,48 @@ func TestAppSessionRefreshesOnMessage(t *testing.T) { t.Errorf("expected session focus maintained, got %d", app.Focus()) } } + +func TestAppNewThreadBlockedWhenDisconnected(t *testing.T) { + store := state.NewThreadStore() + app := NewAppModel(store, nil) + + nKey := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}} + _, cmd := app.Update(nKey) + + if cmd != nil { + t.Error("expected no command when disconnected") + } +} + +func TestAppConnectedAutoCreatesThread(t *testing.T) { + store := state.NewThreadStore() + app := NewAppModel(store, nil) + + connectedMsg := AppServerConnectedMsg{SessionID: "sess-1", Model: "gpt-4o"} + updated, _ := app.Update(connectedMsg) + app = updated.(AppModel) + + threads := store.All() + if len(threads) != 1 { + t.Fatalf("expected 1 thread after connect, got %d", len(threads)) + } + if threads[0].Title != "gpt-4o" { + t.Errorf("expected thread title gpt-4o, got %s", threads[0].Title) + } +} + +func TestAppNKeyShowsMessageWhenConnected(t *testing.T) { + store := state.NewThreadStore() + app := NewAppModel(store, nil) + + connectedMsg := AppServerConnectedMsg{SessionID: "sess-1", Model: "gpt-4o"} + updated, _ := app.Update(connectedMsg) + app = updated.(AppModel) + + nKey := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}} + _, cmd := app.Update(nKey) + + if cmd != nil { + t.Error("expected no command in single session mode") + } +} diff --git a/internal/tui/bridge.go b/internal/tui/bridge.go index df64794..a2de8d6 100644 --- a/internal/tui/bridge.go +++ b/internal/tui/bridge.go @@ -1,50 +1,52 @@ package tui -import "github.com/robinojw/dj/internal/appserver" +import ( + tea "github.com/charmbracelet/bubbletea" + "github.com/robinojw/dj/internal/appserver" +) type MessageSender interface { - Send(msg any) + Send(msg tea.Msg) } -func WireEventBridge(router *appserver.NotificationRouter, sender MessageSender) { - router.OnThreadStatusChanged(func(params appserver.ThreadStatusChanged) { - sender.Send(ThreadStatusMsg{ - ThreadID: params.ThreadID, - Status: params.Status, - Title: params.Title, +func WireEventBridge(router *appserver.EventRouter, sender MessageSender) { + router.OnSessionConfigured(func(event appserver.SessionConfigured) { + sender.Send(AppServerConnectedMsg{ + SessionID: event.SessionID, + Model: event.Model, }) }) - router.OnThreadMessageCreated(func(params appserver.ThreadMessageCreated) { - sender.Send(ThreadMessageMsg{ - ThreadID: params.ThreadID, - MessageID: params.MessageID, - Role: params.Role, - Content: params.Content, + router.OnAgentMessageDelta(func(event appserver.AgentMessageDelta) { + sender.Send(ThreadDeltaMsg{ + Delta: event.Delta, }) }) - router.OnThreadMessageDelta(func(params appserver.ThreadMessageDelta) { - sender.Send(ThreadDeltaMsg{ - ThreadID: params.ThreadID, - MessageID: params.MessageID, - Delta: params.Delta, + router.OnExecCommandBegin(func(event appserver.ExecCommandBegin) { + sender.Send(CommandOutputMsg{ + ExecID: event.ExecID, + Data: "$ " + event.Command + "\n", }) }) - router.OnCommandOutput(func(params appserver.CommandOutput) { + router.OnExecCommandOutputDelta(func(event appserver.ExecCommandOutputDelta) { sender.Send(CommandOutputMsg{ - ThreadID: params.ThreadID, - ExecID: params.ExecID, - Data: params.Data, + ExecID: event.ExecID, + Data: event.Delta, }) }) - router.OnCommandFinished(func(params appserver.CommandFinished) { + router.OnExecCommandEnd(func(event appserver.ExecCommandEnd) { sender.Send(CommandFinishedMsg{ - ThreadID: params.ThreadID, - ExecID: params.ExecID, - ExitCode: params.ExitCode, + ExecID: event.ExecID, + ExitCode: event.ExitCode, + }) + }) + + router.OnError(func(event appserver.ServerError) { + sender.Send(AppServerErrorMsg{ + Err: &appserver.RPCError{Message: event.Message}, }) }) } diff --git a/internal/tui/bridge_test.go b/internal/tui/bridge_test.go index 8ea6ef2..cb8369a 100644 --- a/internal/tui/bridge_test.go +++ b/internal/tui/bridge_test.go @@ -1,46 +1,80 @@ package tui import ( + "encoding/json" "testing" + tea "github.com/charmbracelet/bubbletea" "github.com/robinojw/dj/internal/appserver" ) type mockSender struct { - messages []any + messages []tea.Msg } -func (mock *mockSender) Send(msg any) { +func (mock *mockSender) Send(msg tea.Msg) { mock.messages = append(mock.messages, msg) } -func TestBridgeThreadStatusChanged(t *testing.T) { +func TestBridgeSessionConfigured(t *testing.T) { sender := &mockSender{} - router := appserver.NewNotificationRouter() + router := appserver.NewEventRouter() WireEventBridge(router, sender) - router.Handle(appserver.NotifyThreadStatusChanged, - []byte(`{"threadId":"t-1","status":"active","title":"Running"}`)) + event := appserver.Event{ + ID: "", + Msg: json.RawMessage(`{"type":"session_configured","session_id":"sess-1","model":"gpt-4o"}`), + } + router.HandleEvent(event) if len(sender.messages) != 1 { t.Fatalf("expected 1 message, got %d", len(sender.messages)) } - msg, ok := sender.messages[0].(ThreadStatusMsg) + msg, ok := sender.messages[0].(AppServerConnectedMsg) if !ok { - t.Fatalf("expected ThreadStatusMsg, got %T", sender.messages[0]) + t.Fatalf("expected AppServerConnectedMsg, got %T", sender.messages[0]) + } + if msg.SessionID != "sess-1" { + t.Errorf("expected sess-1, got %s", msg.SessionID) } - if msg.ThreadID != "t-1" { - t.Errorf("expected t-1, got %s", msg.ThreadID) + if msg.Model != "gpt-4o" { + t.Errorf("expected gpt-4o, got %s", msg.Model) } } -func TestBridgeCommandOutput(t *testing.T) { +func TestBridgeAgentMessageDelta(t *testing.T) { sender := &mockSender{} - router := appserver.NewNotificationRouter() + router := appserver.NewEventRouter() WireEventBridge(router, sender) - router.Handle(appserver.NotifyCommandOutput, - []byte(`{"threadId":"t-1","execId":"e-1","data":"hello\n"}`)) + event := appserver.Event{ + ID: "sub-1", + Msg: json.RawMessage(`{"type":"agent_message_delta","delta":"hello"}`), + } + router.HandleEvent(event) + + if len(sender.messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(sender.messages)) + } + msg, ok := sender.messages[0].(ThreadDeltaMsg) + if !ok { + t.Fatalf("expected ThreadDeltaMsg, got %T", sender.messages[0]) + } + if msg.Delta != "hello" { + t.Errorf("expected hello, got %s", msg.Delta) + } +} + +func TestBridgeExecCommandBegin(t *testing.T) { + sender := &mockSender{} + router := appserver.NewEventRouter() + WireEventBridge(router, sender) + + event := appserver.Event{ + ID: "sub-1", + Msg: json.RawMessage(`{"type":"exec_command_begin","call_id":"cmd-1","command":"ls -la"}`), + } + router.HandleEvent(event) if len(sender.messages) != 1 { t.Fatalf("expected 1 message, got %d", len(sender.messages)) @@ -49,7 +83,30 @@ func TestBridgeCommandOutput(t *testing.T) { if !ok { t.Fatalf("expected CommandOutputMsg, got %T", sender.messages[0]) } - if msg.Data != "hello\n" { - t.Errorf("expected hello, got %s", msg.Data) + if msg.ExecID != "cmd-1" { + t.Errorf("expected cmd-1, got %s", msg.ExecID) + } +} + +func TestBridgeServerError(t *testing.T) { + sender := &mockSender{} + router := appserver.NewEventRouter() + WireEventBridge(router, sender) + + event := appserver.Event{ + ID: "", + Msg: json.RawMessage(`{"type":"error","message":"something broke"}`), + } + router.HandleEvent(event) + + if len(sender.messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(sender.messages)) + } + msg, ok := sender.messages[0].(AppServerErrorMsg) + if !ok { + t.Fatalf("expected AppServerErrorMsg, got %T", sender.messages[0]) + } + if msg.Error() != "something broke" { + t.Errorf("expected 'something broke', got %s", msg.Error()) } } diff --git a/internal/tui/integration_test.go b/internal/tui/integration_test.go new file mode 100644 index 0000000..bf1570b --- /dev/null +++ b/internal/tui/integration_test.go @@ -0,0 +1,48 @@ +//go:build integration + +package tui + +import ( + "context" + "testing" + "time" + + "github.com/robinojw/dj/internal/appserver" + "github.com/robinojw/dj/internal/state" +) + +func TestIntegrationEndToEnd(t *testing.T) { + client := appserver.NewClient("codex", "proto") + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Start failed: %v", err) + } + defer client.Stop() + + received := make(chan appserver.SessionConfigured, 1) + router := appserver.NewEventRouter() + router.OnSessionConfigured(func(event appserver.SessionConfigured) { + received <- event + }) + client.Router = router + + go client.ReadLoop() + + select { + case event := <-received: + t.Logf("Connected: session=%s model=%s", event.SessionID, event.Model) + + store := state.NewThreadStore() + store.Add(event.SessionID, event.Model) + + threads := store.All() + if len(threads) != 1 { + t.Fatalf("expected 1 thread, got %d", len(threads)) + } + case <-ctx.Done(): + t.Fatal("timeout waiting for session_configured") + } +} diff --git a/internal/tui/msgs.go b/internal/tui/msgs.go index b5d1949..a68673f 100644 --- a/internal/tui/msgs.go +++ b/internal/tui/msgs.go @@ -40,6 +40,11 @@ type ThreadDeletedMsg struct { ThreadID string } +type AppServerConnectedMsg struct { + SessionID string + Model string +} + type AppServerErrorMsg struct { Err error } diff --git a/internal/tui/statusbar.go b/internal/tui/statusbar.go new file mode 100644 index 0000000..ebaebae --- /dev/null +++ b/internal/tui/statusbar.go @@ -0,0 +1,98 @@ +package tui + +import ( + "fmt" + + "github.com/charmbracelet/lipgloss" +) + +var ( + statusBarStyle = lipgloss.NewStyle(). + Background(lipgloss.Color("236")). + Foreground(lipgloss.Color("252")). + Padding(0, 1) + statusConnectedStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("42")) + statusConnectingStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("214")) + statusDisconnectedStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("196")) + statusErrorStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("196")). + Bold(true) +) + +type StatusBar struct { + connected bool + connecting bool + threadCount int + selectedThread string + errorMessage string + width int +} + +func NewStatusBar() *StatusBar { + return &StatusBar{} +} + +func (statusBar *StatusBar) SetConnecting() { + statusBar.connecting = true + statusBar.connected = false + statusBar.errorMessage = "" +} + +func (statusBar *StatusBar) SetConnected(connected bool) { + statusBar.connecting = false + statusBar.connected = connected + if connected { + statusBar.errorMessage = "" + } +} + +func (statusBar *StatusBar) SetThreadCount(count int) { + statusBar.threadCount = count +} + +func (statusBar *StatusBar) SetSelectedThread(name string) { + statusBar.selectedThread = name +} + +func (statusBar StatusBar) renderConnectionState() string { + if statusBar.connected { + return statusConnectedStyle.Render("● Connected") + } + if statusBar.connecting { + return statusConnectingStyle.Render("◌ Connecting to app-server...") + } + return statusDisconnectedStyle.Render("○ Disconnected — requires codex CLI (codex app-server)") +} + +func (statusBar *StatusBar) SetError(message string) { + statusBar.errorMessage = message +} + +func (statusBar *StatusBar) SetWidth(width int) { + statusBar.width = width +} + +func (statusBar StatusBar) View() string { + left := statusBar.renderConnectionState() + + if statusBar.errorMessage != "" { + left += " " + statusErrorStyle.Render(statusBar.errorMessage) + } + + middle := "" + if statusBar.threadCount > 0 { + middle = fmt.Sprintf(" | %d threads", statusBar.threadCount) + } + + right := "" + if statusBar.selectedThread != "" { + right = fmt.Sprintf(" | %s", statusBar.selectedThread) + } + + content := left + middle + right + style := statusBarStyle.Width(statusBar.width) + return style.Render(content) +} diff --git a/internal/tui/statusbar_test.go b/internal/tui/statusbar_test.go new file mode 100644 index 0000000..988ca31 --- /dev/null +++ b/internal/tui/statusbar_test.go @@ -0,0 +1,76 @@ +package tui + +import ( + "strings" + "testing" +) + +func TestStatusBarConnected(t *testing.T) { + bar := NewStatusBar() + bar.SetConnected(true) + bar.SetThreadCount(3) + bar.SetSelectedThread("Build web app") + + output := bar.View() + + if !strings.Contains(output, "Connected") { + t.Errorf("expected Connected in output:\n%s", output) + } + if !strings.Contains(output, "3 threads") { + t.Errorf("expected thread count in output:\n%s", output) + } + if !strings.Contains(output, "Build web app") { + t.Errorf("expected selected thread in output:\n%s", output) + } +} + +func TestStatusBarDisconnected(t *testing.T) { + bar := NewStatusBar() + bar.SetConnected(false) + + output := bar.View() + + if !strings.Contains(output, "Disconnected") { + t.Errorf("expected Disconnected in output:\n%s", output) + } + if !strings.Contains(output, "codex") { + t.Errorf("expected codex hint in disconnected output:\n%s", output) + } +} + +func TestStatusBarConnecting(t *testing.T) { + bar := NewStatusBar() + bar.SetConnecting() + + output := bar.View() + + if !strings.Contains(output, "Connecting") { + t.Errorf("expected Connecting in output:\n%s", output) + } +} + +func TestStatusBarConnectingClearedOnConnect(t *testing.T) { + bar := NewStatusBar() + bar.SetConnecting() + bar.SetConnected(true) + + output := bar.View() + + if strings.Contains(output, "Connecting") { + t.Errorf("expected Connecting cleared after connect:\n%s", output) + } + if !strings.Contains(output, "Connected") { + t.Errorf("expected Connected in output:\n%s", output) + } +} + +func TestStatusBarError(t *testing.T) { + bar := NewStatusBar() + bar.SetError("connection lost") + + output := bar.View() + + if !strings.Contains(output, "connection lost") { + t.Errorf("expected error in output:\n%s", output) + } +}