diff --git a/.version/VERSION b/.version/VERSION index b82608c..1474d00 100644 --- a/.version/VERSION +++ b/.version/VERSION @@ -1 +1 @@ -v0.1.0 +v0.2.0 diff --git a/.version/changelog/README.md b/.version/changelog/README.md index f54bbc1..faecdf1 100644 --- a/.version/changelog/README.md +++ b/.version/changelog/README.md @@ -10,6 +10,7 @@ ## 当前版本文件 - [`Unreleased.md`](Unreleased.md) +- [`v0.2.0.md`](v0.2.0.md) - [`v0.1.0.md`](v0.1.0.md) - [`v0.0.6.md`](v0.0.6.md) - [`v0.0.5.md`](v0.0.5.md) diff --git a/.version/changelog/Unreleased.md b/.version/changelog/Unreleased.md index 5d441a6..eedc4d7 100644 --- a/.version/changelog/Unreleased.md +++ b/.version/changelog/Unreleased.md @@ -7,17 +7,16 @@ ## 新增 -- 暂无 +暂无 ## 修复 -- 暂无 +暂无 ## 变更 -- 暂无 +暂无 ## 文档 -- 暂无 - +暂无 \ No newline at end of file diff --git a/.version/changelog/v0.2.0.md b/.version/changelog/v0.2.0.md new file mode 100644 index 0000000..9655b1b --- /dev/null +++ b/.version/changelog/v0.2.0.md @@ -0,0 +1,37 @@ +# [v0.2.0] - 2026-04-01 + +## 新增 + +- 增加 `InvocationStream`,支持命令在执行期间进行响应流输出。 +- 增加 `Invocation.ResponseStream()` 返回 `<-chan any`,由命令内部创建并暴露响应流供调用方消费。 +- 增加 `ResponseHandler` / `ResponseStreamHandler` 接口化处理器,并提供 `Unary[T]` / `Stream[T]` 泛型适配器以携带运行时类型信息。 +- 增加 `TypedWriter[T]` 泛型写入器,通过 `Send(v T)` 直接发送泛型数据到流通道。 +- 增加 `RunCallback[T]` 泛型回调执行入口,统一支持 Unary 响应与 Stream 输出的类型化回调分发。 +- 增加 `StreamError` 结构化错误类型,支持 `code/message/details`。 +- 增加 `StreamEnvelope` NDJSON 信封类型(`{"$":"resp","type":"T","data":...}`),stdout/stderr 镜像输出使用信封包装,消费者可通过 `$` 键区分响应数据与普通日志。 +- `InvocationStream.Send` 自动镜像输出:响应数据以 NDJSON 信封写入 stdout,`StreamError` 以信封写入 stderr;channel 消费者仍接收原始值。 +- 新增 `example/fastcommit stream` 流式输出测试命令,每秒发送一条带时间戳的消息。 + +## 修复 + +- 修复 `InvocationStream.Send` 与 `closeResponseStream` 之间的并发竞态:channel 引用在创建时捕获,不再动态读取。 + +## 变更 + +- 流通道类型从 `chan map[string]any` 简化为 `chan any`,直接传递泛型数据,不再包装事件结构。 +- 三类处理器互斥校验前移到 `init` 阶段:`Handler`、`ResponseHandler`、`ResponseStreamHandler` 同时配置时报错。 +- 适配器函数 `adaptResponseHandler` / `adaptResponseStreamHandler` 改为包内私有。 +- `handler.go` / `response_handlers.go` / `execution_typed.go` 合并为单一 `handler.go`。 +- 对应测试文件 `execution_typed_test.go` / `stream_test.go` 合并为 `handler_test.go`。 +- 移除 `Command.ResponseBuffer` 字段,流通道缓冲统一使用内部默认值。 +- `webcmd/webui` 增加 `/api/run/stream/ws` 流式 WebSocket 执行通道。 +- `webcmd/webui` 命令元数据 `/api/commands` 增加 `supportsStream` 字段。 + +## 文档 + +- 更新 `docs/DESIGN.md` 第 9 节,反映三类处理器互斥与泛型适配器架构,补充 Unary 执行路径与执行上下文兼容矩阵。 +- 重写 `docs/INTERACTIVE_STREAMING.md`,更新为 `TypedWriter[T].Send` + `chan any` 模型;新增 Unary 执行路径、`ResponseTypeInfo` 说明、执行上下文兼容性表、MCP 集成章节。 +- 更新 `docs/USAGE_AT_A_GLANCE.md`,新增第 8 节(Unary ResponseHandler)、第 9 节(Stream ResponseStreamHandler)、第 10 节(三类处理器对比表)。 +- 更新 `docs/DOCS_CATALOG.md`,补充 `example/unary` 与 `example/stream-interactive` 示例引用。 +- 新增 `example/unary` 示例及说明文档,展示 `Unary[T]` + `RunCallback[T]` 用法。 +- 新增 `example/stream-interactive` 示例及说明文档。 diff --git a/README.md b/README.md index 9adb571..155f998 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,7 @@ MCP 输入/输出协议、Schema 规则与排查建议见:[`docs/MCP.md`](docs - `example/env-test`:环境变量示例 - `example/globalflags`:全局标志示例 - `example/args-test`:参数解析示例 +- `example/stream-interactive`:交互流(结构化输出 + 内建响应流)完整示例 ## 开发与维护 diff --git a/command.go b/command.go index e35fac9..c922131 100644 --- a/command.go +++ b/command.go @@ -60,8 +60,10 @@ type Command struct { // Middleware is called before the Handler. // Use Chain() to combine multiple middlewares. - Middleware MiddlewareFunc - Handler HandlerFunc + Middleware MiddlewareFunc + Handler HandlerFunc + ResponseHandler ResponseHandler + ResponseStreamHandler ResponseStreamHandler } func ascendingSortFn[T cmp.Ordered](a, b T) int { @@ -121,6 +123,10 @@ func (c *Command) init() error { } } + if _, err := c.resolveConfiguredHandler(); err != nil { + merr = errors.Join(merr, err) + } + slices.SortFunc(c.Options, func(a, b Option) int { // Use Flag for sorting, fallback to Env if Flag is empty nameA := a.Flag @@ -272,6 +278,9 @@ type Invocation struct { Stderr io.Writer Stdin io.Reader + responseStream chan any + responseValue any + // Annotations is a map of arbitrary annotations to attach to the invocation. Annotations map[string]any @@ -329,6 +338,31 @@ func (inv *Invocation) WithTestParsedFlags( }) } +// ResponseStream returns the invocation response stream channel. +// +// The stream is internally owned by Invocation and is closed automatically when +// ResponseStreamHandler execution finishes. +func (inv *Invocation) ResponseStream() <-chan any { + return inv.ensureResponseStream() +} + +func (inv *Invocation) ensureResponseStream() chan any { + if inv.responseStream != nil { + return inv.responseStream + } + inv.responseStream = make(chan any, defaultStreamResponseBuffer) + return inv.responseStream +} + +func (inv *Invocation) closeResponseStream() { + if inv.responseStream == nil { + return + } + ch := inv.responseStream + inv.responseStream = nil + close(ch) +} + func (inv *Invocation) Context() context.Context { if inv.ctx == nil { return context.Background() @@ -336,6 +370,28 @@ func (inv *Invocation) Context() context.Context { return inv.ctx } +// Response returns the unary response produced by ResponseHandler in current run. +func (inv *Invocation) Response() (any, bool) { + if inv == nil || inv.responseValue == nil { + return nil, false + } + return inv.responseValue, true +} + +func (inv *Invocation) setResponse(v any) { + if inv == nil { + return + } + inv.responseValue = v +} + +func (inv *Invocation) clearResponse() { + if inv == nil { + return + } + inv.responseValue = nil +} + func (inv *Invocation) ParsedFlags() *pflag.FlagSet { if inv.Flags == nil { panic("flags not parsed, has Run() been called?") @@ -846,7 +902,7 @@ func (inv *Invocation) run(state *runState) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - inv = inv.WithContext(ctx) + inv.ctx = ctx // Check for help flag if inv.Flags != nil { @@ -855,11 +911,16 @@ func (inv *Invocation) run(state *runState) error { } } - if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) { + handler, resolveErr := inv.Command.resolveConfiguredHandler() + if resolveErr != nil { + return &RunCommandError{Cmd: inv.Command, Err: resolveErr} + } + + if handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) { return DefaultHelpFn()(ctx, inv) } - err := mw(inv.Command.Handler)(ctx, inv) + err := mw(handler)(ctx, inv) if err != nil { return &RunCommandError{ Cmd: inv.Command, @@ -1060,6 +1121,9 @@ func parseAndSetArgs(argsDef ArgSet, args []string) error { // //nolint:revive func (inv *Invocation) Run() (err error) { + defer inv.closeResponseStream() + inv.clearResponse() + restoreEnv, preloadErr := preloadEnvFromArgs(inv.Args) if preloadErr != nil { return fmt.Errorf("preloading environment variables: %w", preloadErr) @@ -1213,3 +1277,42 @@ func (c *Command) children() map[string]*Command { } return childrenMap } + +func (c *Command) resolveConfiguredHandler() (HandlerFunc, error) { + if c == nil { + return nil, nil + } + + hasRequest := c.Handler != nil + hasUnary := c.ResponseHandler != nil + hasStream := c.ResponseStreamHandler != nil + + configuredCount := 0 + if hasRequest { + configuredCount++ + } + if hasUnary { + configuredCount++ + } + if hasStream { + configuredCount++ + } + + if configuredCount > 1 { + return nil, fmt.Errorf("command %q configures multiple handler models", c.FullName()) + } + + if hasRequest { + return c.Handler, nil + } + + if hasUnary { + return adaptResponseHandler(c.ResponseHandler), nil + } + + if hasStream { + return adaptResponseStreamHandler(c.ResponseStreamHandler), nil + } + + return nil, nil +} diff --git a/command_test.go b/command_test.go index bc37d4a..0032ffd 100644 --- a/command_test.go +++ b/command_test.go @@ -833,3 +833,47 @@ func TestInternalArgsFlagOverridesParsedArgs(t *testing.T) { t.Fatalf("second arg value = %q, want %q", gotSecond, "from-flag-2") } } + +func TestCommandInitHandlerValidation(t *testing.T) { + tests := []struct { + name string + command *Command + wantErr bool + }{ + { + name: "invalid multiple handler models configured", + command: &Command{ + Use: "echo", + Handler: func(ctx context.Context, inv *Invocation) error { + return nil + }, + ResponseHandler: Unary(func(ctx context.Context, inv *Invocation) (string, error) { + return "ok", nil + }), + }, + wantErr: true, + }, + { + name: "valid response stream handler", + command: &Command{ + Use: "chat", + ResponseStreamHandler: Stream(func(ctx context.Context, inv *Invocation, out *TypedWriter[string]) error { + return out.Send("hello") + }), + }, + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.command.init() + if tc.wantErr && err == nil { + t.Fatalf("expected init error, got nil") + } + if !tc.wantErr && err != nil { + t.Fatalf("unexpected init error: %v", err) + } + }) + } +} diff --git a/docs/DESIGN.md b/docs/DESIGN.md index 608054b..2d9f721 100644 --- a/docs/DESIGN.md +++ b/docs/DESIGN.md @@ -19,10 +19,10 @@ flowchart LR R --> O[选项与参数解析层] O --> M[中间件编排层] M --> H[处理器执行层] - H --> X[输出/错误/退出码] + H --> X["输出/错误/退出码"] R --> C1[命令树] - O --> C2[Option 与 Arg] + O --> C2["Option 与 Arg"] M --> C3[Middleware] H --> C4[Handler] ``` @@ -152,13 +152,13 @@ stateDiagram-v2 ```mermaid flowchart LR - UI[浏览器 xterm.js] -->|WebSocket| WS[/ws] - WS --> PTY[PTY 桥接] - PTY --> SH[本地 shell] + UI["浏览器 xterm.js"] -->|WebSocket| WS["/ws"] + WS --> PTY["PTY 桥接"] + PTY --> SH["本地 shell"] - UI -->|multipart| UP[/upload] - UI -->|GET| LS[/api/files] - UI -->|GET| DL[/download] + UI -->|multipart| UP["/upload"] + UI -->|GET| LS["/api/files"] + UI -->|GET| DL["/download"] UP --> FS[(工作目录)] LS --> FS @@ -232,3 +232,50 @@ sequenceDiagram - 同级:[`MCP.md`](MCP.md) 提供 MCP 子命令、Schema 与调用协议说明。 - 同级:[`WEBTTY.md`](WEBTTY.md) 提供 WebTTY 能力说明与分阶段迭代路线。 - 下游:[`../example/args-test/README.md`](../example/args-test/README.md) 提供参数解析落地示例。 +- 下游:[`../example/unary/README.md`](../example/unary/README.md) 提供 Unary 响应处理器示例。 +- 下游:[`../example/stream-interactive/README.md`](../example/stream-interactive/README.md) 提供流式响应处理器示例。 + +## 9. 交互式命令与流式处理 + +为兼容传统一次性命令执行,同时支持类 RPC 的结构化响应输出,Redant 新增了 Unary/Stream 处理能力。 + +### 处理器类型与适配 + +```mermaid +flowchart LR + A[Invocation.Run] --> B{命令处理器类型} + B -- Handler --> D[原始 HandlerFunc] + B -- ResponseHandler --> E[Unary 适配] + B -- ResponseStreamHandler --> C[Stream 适配] + E --> E1[Handle → 返回 T] + E1 --> E2[setResponse + JSON stdout] + C --> F[InvocationStream] + F --> G[TypedWriter.Send] + G --> H[ResponseStream 通道输出] + H --> I[Run 结束自动 close] +``` + +### 执行上下文兼容矩阵 + +```mermaid +flowchart TD + CMD[命令] --> CTX{执行上下文} + CTX -- CLI 直接调用 --> CLI[stdout 自动输出] + CTX -- RunCallback --> RCB[泛型回调分发] + CTX -- MCP callTool --> MCP[RunCallback → structuredContent] + CTX -- WebUI Stream WS --> WS[WebSocket 事件推送] + CTX -- readline / richline --> REPL[降级: stdout 纯文本] +``` + +设计要点: + +1. 保留 `HandlerFunc` 与 `MiddlewareFunc`,不破坏现有执行链。 +2. 三类处理器互斥:`Handler`(无响应)、`ResponseHandler`(Unary 单响应)、`ResponseStreamHandler`(流式响应),初始化阶段校验冲突。 +3. `ResponseHandler` 通过 `Unary[T]` 泛型适配器构造,返回值自动 JSON 序列化写入 stdout,可通过 `Response()` 或 `RunCallback[T]` 获取。 +4. `ResponseStreamHandler` 通过 `Stream[T]` 泛型适配器构造,`TypedWriter[T].Send(v)` 直接发送泛型数据。 +5. 响应流通道类型为 `chan any`,通过 `Invocation.ResponseStream()` 消费。 +6. `InvocationStream.Send` 自动镜像文本到 stdout、`StreamError` 到 stderr、struct 类型 JSON 序列化到 stdout。 +7. `RunCallback[T]` 提供泛型回调消费入口,统一支持 Unary 与 Stream 两种模型的类型化分发。 +8. `ResponseTypeInfo` 暴露运行时类型元数据(`TypeName`、`Schema`),供 MCP 等集成层生成输出 Schema。 + +详见:[`INTERACTIVE_STREAMING.md`](INTERACTIVE_STREAMING.md)。 diff --git a/docs/DOCS_CATALOG.md b/docs/DOCS_CATALOG.md index 0cff931..7cd3531 100644 --- a/docs/DOCS_CATALOG.md +++ b/docs/DOCS_CATALOG.md @@ -11,6 +11,7 @@ ## 2) 架构与质量 - [`DESIGN.md`](DESIGN.md):核心模型、解析流程、扩展点。 +- [`INTERACTIVE_STREAMING.md`](INTERACTIVE_STREAMING.md):交互式命令与结构化响应流方案。 - [`EVALUATION.md`](EVALUATION.md):质量评估、风险与优化建议。 ## 3) PR 审查体系(聚合) @@ -27,7 +28,9 @@ ## 5) 仓库外延入口(按需) - 项目总览:`README.md` -- 示例:`example/args-test/README.md` +- 示例:`example/args-test/README.md`(参数解析示例) +- 示例:`example/unary/README.md`(Unary 响应处理器示例) +- 示例:`example/stream-interactive/README.md`(流式响应处理器示例) - 内部样式维护:`internal/pretty/README.md` --- diff --git a/docs/INDEX.md b/docs/INDEX.md index 21d5094..13158cf 100644 --- a/docs/INDEX.md +++ b/docs/INDEX.md @@ -19,6 +19,7 @@ flowchart TD 2. [`USAGE_AT_A_GLANCE.md`](USAGE_AT_A_GLANCE.md):新同学快速建立 CLI 使用心智模型。 3. [`review/PR_REVIEW_RUBRIC.md`](review/PR_REVIEW_RUBRIC.md):PR 审查流程基线与轮次规则。 4. [`DESIGN.md`](DESIGN.md):涉及实现变更时优先查阅。 +5. [`INTERACTIVE_STREAMING.md`](INTERACTIVE_STREAMING.md):交互式命令与结构化响应流设计。 > 说明:为保持主仓聚焦,`agentline` 与 `copilot-demo` 相关模块/示例已迁移到独立项目维护;本索引仅覆盖 `redant` 主仓当前内容。 diff --git a/docs/INTERACTIVE_STREAMING.md b/docs/INTERACTIVE_STREAMING.md new file mode 100644 index 0000000..ade5114 --- /dev/null +++ b/docs/INTERACTIVE_STREAMING.md @@ -0,0 +1,278 @@ +# 交互式命令与流式处理(Stream) + +本文档说明 Redant 的交互式命令与响应流设计。 + +## 目标与原则 + +- 保持命令分发与中间件主链不变。 +- 通过 `ResponseHandler` 提供 Unary 单响应能力。 +- 通过 `ResponseStreamHandler` 提供结构化响应流输出能力。 +- 响应流由 Invocation 内部创建并管理,`Run()` 结束后自动关闭。 +- 流通道类型为 `chan any`,直接传递泛型数据,不再包装事件结构。 +- `TypedWriter[T]` 提供类型安全的发送接口。 + +## 核心类型 + +### Command 扩展 + +三类处理器互斥,初始化阶段校验冲突: + +| 字段 | 类型 | 说明 | +| ----------------------- | ----------------------- | ---------------------------------- | +| `Handler` | `HandlerFunc` | 传统处理器,无结构化响应 | +| `ResponseHandler` | `ResponseHandler` | Unary 单响应,通过 `Unary[T]` 构造 | +| `ResponseStreamHandler` | `ResponseStreamHandler` | 流式响应,通过 `Stream[T]` 构造 | + +### Invocation 扩展 + +- `ResponseStream() <-chan any`:消费响应流通道。 +- `Response() (any, bool)`:获取 Unary 响应值。 + +### ResponseTypeInfo + +运行时输出类型元数据,由 `ResponseHandler` 和 `ResponseStreamHandler` 通过 `TypeInfo()` 方法暴露: + +```go +type ResponseTypeInfo struct { + TypeName string // 例如 "VersionInfo"、"string" + Schema string // 可选 JSON Schema +} +``` + +### InvocationStream + +`Send(data any)` 统一发送接口,行为: + +1. 将数据推入内部 `chan any` 通道(供 `ResponseStream()` 消费)。 +2. 自动镜像到 stdio: + - `string` / `[]byte` → stdout + - `StreamError` → stderr + - 其他类型 → JSON 序列化后写 stdout + +### TypedWriter[T] + +泛型写入器,由 `Stream[T]` 适配器自动注入: + +- `Send(v T) error`:发送泛型数据。 +- `Raw() *InvocationStream`:获取底层流(高级场景)。 + +## 执行路径 + +### Stream(流式响应) + +```mermaid +sequenceDiagram + participant I as Invocation + participant M as Middleware Chain + participant A as adaptResponseStreamHandler + participant S as StreamHandler[T] + participant W as TypedWriter[T] + + I->>I: 解析命令与 flags/args + I->>I: resolveConfiguredHandler + I->>M: 构建并执行中间件 + M->>A: 调用适配后的 Handler + A->>A: 创建 InvocationStream + A->>S: 执行 HandleStream + S->>W: 使用 TypedWriter.Send 发送数据 + W->>A: 数据推入 chan any + 镜像 stdout + A-->>M: 透传执行结果 + I->>I: Run 结束,closeResponseStream +``` + +### Unary(单响应) + +```mermaid +sequenceDiagram + participant I as Invocation + participant M as Middleware Chain + participant A as adaptResponseHandler + participant U as UnaryHandler[T] + + I->>I: 解析命令与 flags/args + I->>I: resolveConfiguredHandler + I->>M: 构建并执行中间件 + M->>A: 调用适配后的 Handler + A->>U: 执行 Handle → 返回 (T, error) + U-->>A: 返回 resp + A->>I: setResponse(resp) + A->>I: writeUnaryResponse → JSON 写 stdout + A-->>M: 透传执行结果 +``` + +## 泛型回调消费(RunCallback) + +```mermaid +flowchart TD + A[RunCallback] --> B[创建 ResponseStream] + B --> C[goroutine 消费 chan any] + C --> D{类型断言 T} + D -- 成功 --> E[调用 callback] + D -- 失败 --> F[返回类型不匹配错误] + A --> G[调用 inv.Run] + G --> H{Run 结束} + H --> I{有 Unary Response?} + I -- 是 --> J[类型断言并调用 callback] + I -- 否 --> K[返回] +``` + +## 开发任务同步 + +- [x] 增加 `InvocationStream` 与 `TypedWriter[T]`。 +- [x] 增加 `ResponseHandler` / `ResponseStreamHandler` 接口与 `Unary[T]` / `Stream[T]` 泛型适配器。 +- [x] 增加 `Invocation.ResponseStream()`。 +- [x] 三类处理器互斥校验(`resolveConfiguredHandler`)。 +- [x] `InvocationStream.Send` 直接发送 `any`,自动镜像 stdio。 +- [x] `RunCallback[T]` 泛型回调消费。 +- [x] 回归测试:stdio 回退 + channel 消费 + 类型不匹配。 +- [ ] 后续任务:补充流式中间件(按事件级拦截)。 + +## 使用示例 + +### Unary 命令定义 + +```go +type VersionInfo struct { + Version string `json:"version"` + BuildDate string `json:"buildDate"` +} + +versionCmd := &redant.Command{ + Use: "version", + ResponseHandler: redant.Unary(func(ctx context.Context, inv *redant.Invocation) (VersionInfo, error) { + return VersionInfo{Version: "1.0.0", BuildDate: "2026-04-01"}, nil + }), +} +``` + +### 流式命令定义 + +```go +chat := &redant.Command{ + Use: "chat", + ResponseStreamHandler: redant.Stream(func(ctx context.Context, inv *redant.Invocation, out *redant.TypedWriter[string]) error { + if err := out.Send("hello"); err != nil { + return err + } + return out.Send("world") + }), +} +``` + +### stdio 回退 + +直接 `Run()` 即可,文本数据自动写入 stdout: + +```go +chat.Invoke().WithOS().Run() +``` + +### Unary 泛型回调 + +```go +err := redant.RunCallback[VersionInfo](versionCmd.Invoke(), func(v VersionInfo) error { + fmt.Printf("version=%s build=%s\n", v.Version, v.BuildDate) + return nil +}) +``` + +### 流式泛型回调消费 + +```go +err := redant.RunCallback[string](chat.Invoke(), func(chunk string) error { + fmt.Println(chunk) + return nil +}) +``` + +### 通道消费 + +```go +inv := chat.Invoke() +out := inv.ResponseStream() +inv.Run() +for data := range out { + fmt.Println(data) +} +``` + +## 执行上下文兼容性 + +不同执行上下文对三类处理器的支持程度: + +| 执行上下文 | Handler | ResponseHandler (Unary) | ResponseStreamHandler (Stream) | +| -------------------------- | ------------------ | -------------------------------------------- | ---------------------------------------------- | +| **直接 `inv.Run()`** | 正常执行 | JSON 写 stdout | 文本镜像 stdout | +| **`RunCallback[T]`** | 正常执行(无回调) | 回调调用一次 | 回调调用多次 | +| **`inv.ResponseStream()`** | 通道无数据 | 通道无数据,用 `Response()` | 通道接收流事件 | +| **MCP (`callTool`)** | stdout 捕获 | `RunCallback` → `structuredContent.response` | `RunCallback` → `structuredContent.response[]` | +| **WebUI Stream WS** | stdout 捕获 | stdout 捕获 | WebSocket 推送事件 | +| **WebUI HTTP POST** | stdout 捕获 | JSON stdout 捕获 | 降级:文本 stdout 捕获 | +| **readline / richline** | stdout 捕获 | JSON stdout 捕获 | 降级:文本 stdout 捕获 | + +```mermaid +flowchart TD + CMD[命令执行] --> H{处理器类型} + H -- Handler --> O1[写 stdout/stderr] + H -- ResponseHandler --> O2[JSON 写 stdout] + H -- ResponseStreamHandler --> O3[每次 Send 镜像 stdout] + O2 --> C{调用方式} + O3 --> C + C -- RunCallback --> CB[泛型回调] + C -- ResponseStream --> CH[通道消费] + C -- 无消费者 --> FD[降级为 stdout 文本] +``` + +**降级语义**:当流命令在不支持结构化流消费的上下文(如 readline、HTTP POST)中执行时,`InvocationStream.Send` 仅走 stdout/stderr 镜像路径。由于没有调用 `ResponseStream()`,内部通道不会创建,不会有死锁或内存泄漏。 + +## MCP 集成 + +MCP server(`internal/mcpserver`)完整支持三类处理器: + +### 工具发现 + +`collectTools` 遍历命令树时检测三种处理器,并通过 `TypeInfo()` 获取运行时类型元数据: + +```mermaid +flowchart LR + CT[collectTools] --> W{遍历命令树} + W --> H[Handler] --> T1[toolDef] + W --> RH[ResponseHandler] --> TI1[TypeInfo] --> T2[toolDef + ResponseType] + W --> RS[ResponseStreamHandler] --> TI2[TypeInfo] --> T3[toolDef + ResponseType + SupportsStream] +``` + +### 输出 Schema + +有类型元数据的命令,MCP `outputSchema` 会包含 `response` 字段: + +- **Unary**:`response` 为单值,带 `x-redant-type` 标注类型名。 +- **Stream**:`response` 为 `type: "array"`,带 `x-redant-type` 标注元素类型名。 + +### 工具调用 + +```mermaid +flowchart TD + CT[callTool] --> C{有 ResponseType?} + C -- 是 --> RC[RunCallback 收集响应] + RC --> S{SupportsStream?} + S -- 是 --> ARR[response = array] + S -- 否 --> SINGLE[response = single] + C -- 否 --> RUN[普通 inv.Run] + RUN --> BT[buildToolResult stdout/stderr] + ARR --> BT + SINGLE --> BT +``` + +## 开发任务同步 + +- [x] 增加 `InvocationStream` 与 `TypedWriter[T]`。 +- [x] 增加 `ResponseHandler` / `ResponseStreamHandler` 接口与 `Unary[T]` / `Stream[T]` 泛型适配器。 +- [x] 增加 `ResponseTypeInfo` 运行时类型元数据。 +- [x] 增加 `Invocation.ResponseStream()`。 +- [x] 三类处理器互斥校验(`resolveConfiguredHandler`)。 +- [x] `InvocationStream.Send` 直接发送 `any`,自动镜像 stdio。 +- [x] `RunCallback[T]` 泛型回调消费。 +- [x] MCP 工具发现、输出 Schema 与 `RunCallback` 调用集成。 +- [x] WebUI 流式 WebSocket 端点(`/api/run/stream/ws`)。 +- [x] 回归测试:stdio 回退 + channel 消费 + 类型不匹配。 +- [ ] 后续任务:补充流式中间件(按事件级拦截)。 \ No newline at end of file diff --git a/docs/USAGE_AT_A_GLANCE.md b/docs/USAGE_AT_A_GLANCE.md index 16aee82..b5fb1d6 100644 --- a/docs/USAGE_AT_A_GLANCE.md +++ b/docs/USAGE_AT_A_GLANCE.md @@ -162,3 +162,62 @@ commit := &redant.Command{ repo.Children = append(repo.Children, commit) root.Children = append(root.Children, repo) ``` + +## 8) Unary 响应命令(ResponseHandler) + +当命令需要返回结构化单响应时,使用 `ResponseHandler` 配合 `Unary[T]` 泛型适配器: + +```go +type VersionInfo struct { + Version string `json:"version"` + BuildDate string `json:"buildDate"` +} + +versionCmd := &redant.Command{ + Use: "version", + ResponseHandler: redant.Unary(func(ctx context.Context, inv *redant.Invocation) (VersionInfo, error) { + return VersionInfo{Version: "1.0.0", BuildDate: "2026-04-01"}, nil + }), +} +``` + +说明: + +- `Unary[T]` 将类型化函数适配为 `ResponseHandler` 接口。 +- 响应自动 JSON 序列化写入 stdout。 +- 泛型回调消费:`redant.RunCallback[VersionInfo](inv, callback)`,回调调用一次。 +- 可通过 `inv.Response()` 获取 Unary 响应值。 +- 运行时类型信息通过 `TypeInfo()` 暴露,供 MCP 等集成层使用。 + +## 9) 流式响应命令(ResponseStreamHandler) + +当命令需要"结构化流式输出"时,使用 `ResponseStreamHandler` 配合 `Stream[T]` 泛型适配器: + +```go +chat := &redant.Command{ + Use: "chat", + ResponseStreamHandler: redant.Stream(func(ctx context.Context, inv *redant.Invocation, out *redant.TypedWriter[string]) error { + if err := out.Send("hello"); err != nil { + return err + } + return out.Send("world") + }), +} +``` + +说明: + +- `TypedWriter[T].Send(v)` 直接发送泛型数据,自动镜像文本到 stdout。 +- 响应流通过 `inv.ResponseStream()` 消费(`<-chan any`),`Run()` 结束后自动关闭。 +- 泛型回调消费:`redant.RunCallback[string](inv, callback)`,回调调用多次。 +- 现有 `Handler` 与 `Middleware` 仍可继续使用,三类处理器互斥。 + +## 10) 三类处理器对比 + +| 维度 | Handler | ResponseHandler (Unary) | ResponseStreamHandler (Stream) | +| ----------- | -------------- | ---------------------------- | ------------------------------ | +| 适配器 | 无(直接赋值) | `Unary[T]` | `Stream[T]` | +| 响应次数 | 无结构化响应 | 一次 | 多次 | +| stdout 行为 | 手动写入 | JSON 自动序列化 | 每次 Send 自动镜像 | +| RunCallback | 无回调 | 回调一次 | 回调多次 | +| MCP 集成 | stdout 捕获 | `structuredContent.response` | `structuredContent.response[]` | diff --git a/example/fastcommit/main.go b/example/fastcommit/main.go index 53e96ca..fa3e5f3 100644 --- a/example/fastcommit/main.go +++ b/example/fastcommit/main.go @@ -356,11 +356,39 @@ func main() { commitCmd.Children = append(commitCmd.Children, detailedCmd) + // 流式交互测试命令:每秒发送一条数据 + var streamCount int64 + streamCmd := &redant.Command{ + Use: "stream", + Short: "流式输出测试(每秒一条数据)", + Options: redant.OptionSet{ + {Flag: "count", Shorthand: "n", Description: "发送条数(0 表示无限)", Value: redant.Int64Of(&streamCount), Default: "10"}, + }, + ResponseStreamHandler: redant.Stream(func(ctx context.Context, inv *redant.Invocation, out *redant.TypedWriter[string]) error { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for i := int64(1); streamCount == 0 || i <= streamCount; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + msg := fmt.Sprintf("[%s] #%d hello stream\n", time.Now().Format("15:04:05"), i) + if err := out.Send(msg); err != nil { + return err + } + } + } + return nil + }), + } + rootCmd.Children = append(rootCmd.Children, commitCmd, releaseShipCmd, projectCmd, profileCmd, + streamCmd, completioncmd.New(), readlinecmd.New(), richlinecmd.New(), diff --git a/example/stream-interactive/README.md b/example/stream-interactive/README.md new file mode 100644 index 0000000..23bb5bb --- /dev/null +++ b/example/stream-interactive/README.md @@ -0,0 +1,55 @@ +# stream-interactive 示例 + +这个示例展示 Redant 在"纯响应流"语义下的两种运行方式: + +1. `stdio`:直接执行 `inv.Run()`,输出落到终端。 +2. `callback`:通过 `RunCallback` 消费类型化输出数据,适合函数调用场景。 + +并且在 `callback` 模式下会实时打印类型化输出文本(`string`)。 + +## 调用流程 + +```mermaid +flowchart LR + A[调用 stream-interactive] --> B{mode} + B -- stdio --> C[inv.Run] + B -- callback --> D[RunCallback] + C --> E[ResponseStreamHandler] + D --> E + E --> F[TypedWriter.Send] + F -- stdio --> G[自动镜像 Stdout] + F -- callback --> H[typed callback] +``` + +## 运行方式 + +### 1) 终端交互模式(stdio) + +```bash +go run ./example/stream-interactive stdio +``` + +运行后可直接看到文本输出(`TypedWriter[string].Send` 自动镜像到 stdout)。 + +### 2) 回调模式(callback) + +```bash +go run ./example/stream-interactive callback +``` + +该模式会通过 `RunCallback[string]` 实时接收并打印文本输出。 + +> 兼容说明:`channel` 仍作为别名可用,便于旧脚本平滑迁移。 + +## 关键代码点 + +- 命令定义:`ResponseStreamHandler: redant.Stream(...)` +- 泛型写入:`out.Send("hello")` (`TypedWriter[string].Send`) +- 回调执行:`RunCallback[T](inv, callback)` + +## 阻塞语义说明 + +- `inv.Run()` 是阻塞调用。 +- 在 `RunCallback` 模式中,流通道中的数据会类型断言为 `T` 并分发到回调;类型不匹配时返回错误。 +- `Run()` 结束后流自动关闭。 +- 推荐始终通过 `context` 设置超时/取消,避免上游异常导致无限等待。 diff --git a/example/stream-interactive/main.go b/example/stream-interactive/main.go new file mode 100644 index 0000000..5b3eca6 --- /dev/null +++ b/example/stream-interactive/main.go @@ -0,0 +1,77 @@ +package main + +import ( + "context" + "fmt" + "io" + "os" + + "github.com/pubgo/redant" +) + +func main() { + var persona string + + chatCmd := &redant.Command{ + Use: "chat [topic]", + Short: "交互式聊天示例(ResponseStreamHandler)", + Options: redant.OptionSet{ + { + Flag: "persona", + Description: "机器人人设", + Default: "assistant", + Value: redant.StringOf(&persona), + }, + }, + ResponseStreamHandler: redant.Stream(func(ctx context.Context, inv *redant.Invocation, out *redant.TypedWriter[string]) error { + topic := "default-topic" + if inv != nil && len(inv.Args) > 0 { + topic = inv.Args[0] + } + + if err := out.Send(fmt.Sprintf("[%s] topic=%s\n", persona, topic)); err != nil { + return err + } + if err := out.Send("chunk-1: hello\n"); err != nil { + return err + } + return out.Send("chunk-2: stream\n") + }), + } + + if len(os.Args) < 2 { + fmt.Fprintln(os.Stderr, "Usage: stream-interactive ") + os.Exit(2) + } + + mode := os.Args[1] + switch mode { + case "stdio": + if err := chatCmd.Invoke().WithOS().Run(); err != nil { + fmt.Fprintf(os.Stderr, "run failed: %v\n", err) + os.Exit(1) + } + case "callback", "channel": + inv := chatCmd.Invoke("--persona", "planner", "stream-topic") + inv.Annotations = map[string]any{"request_id": "demo.channel"} + inv.Stdout = io.Discard + inv.Stderr = io.Discard + inv.Stdin = nil + + fmt.Println("=== callback 实时事件 ===") + eventCount := 0 + if err := redant.RunCallback[string](inv, func(chunk string) error { + eventCount++ + fmt.Println(chunk) + return nil + }); err != nil { + fmt.Fprintf(os.Stderr, "run failed: %v\n", err) + os.Exit(1) + } + fmt.Printf("=== callback 完成,事件总数: %d ===\n", eventCount) + default: + fmt.Fprintf(os.Stderr, "unknown mode: %s\n", mode) + fmt.Fprintln(os.Stderr, "Usage: stream-interactive ") + os.Exit(2) + } +} diff --git a/example/unary/README.md b/example/unary/README.md new file mode 100644 index 0000000..2286f1a --- /dev/null +++ b/example/unary/README.md @@ -0,0 +1,59 @@ +# unary 示例 + +展示 Redant `ResponseHandler` + `Unary[T]` 泛型适配器的两种运行方式。 + +## 核心概念 + +`ResponseHandler` 适用于"请求→单响应"场景(类似 RPC 的 Unary 调用): + +- 处理器返回一个结构体,框架自动 JSON 序列化写入 stdout。 +- 调用方可通过 `RunCallback[T]` 直接获取类型化响应对象。 +- 与 `ResponseStreamHandler`(流式)互斥,同一命令只能配置一种。 + +## 调用流程 + +```mermaid +flowchart LR + A[调用 unary-demo version] --> B{mode} + B -- stdio --> C[inv.Run] + B -- callback --> D[RunCallback] + C --> E[ResponseHandler.Handle] + D --> E + E --> F[返回 VersionInfo] + F -- stdio --> G[JSON 序列化到 stdout] + F -- callback --> H[typed callback] +``` + +## 运行方式 + +```bash +go run ./example/unary +``` + +输出: + +``` +=== stdio 模式(JSON 自动输出到 stdout)=== +{"version":"1.2.3","buildDate":"2026-04-01","goVersion":"go1.23"} +=== callback 模式(RunCallback 泛型回调)=== + Version: 1.2.3 + BuildDate: 2026-04-01 + GoVersion: go1.23 +=== 完成 === +``` + +## 关键代码点 + +- 命令定义:`ResponseHandler: redant.Unary(func(...) (VersionInfo, error) {...})` +- 泛型回调:`RunCallback[VersionInfo](inv, callback)` +- stdio 回退:`inv.Run()` 自动将响应 JSON 序列化到 stdout + +## 与 Stream 的区别 + +| 维度 | Unary (`ResponseHandler`) | Stream (`ResponseStreamHandler`) | +| ---------------- | ------------------------- | -------------------------------- | +| 响应次数 | 一次 | 多次 | +| 适配器 | `Unary[T]` | `Stream[T]` | +| 写入器 | 无(直接返回值) | `TypedWriter[T].Send` | +| RunCallback 行为 | 回调调用一次 | 回调调用多次 | +| stdout 回退 | JSON 序列化 | 每次 Send 镜像 | diff --git a/example/unary/main.go b/example/unary/main.go new file mode 100644 index 0000000..4669c6c --- /dev/null +++ b/example/unary/main.go @@ -0,0 +1,65 @@ +package main + +import ( + "context" + "fmt" + "io" + "os" + "time" + + "github.com/pubgo/redant" +) + +// VersionInfo 展示 Unary 处理器的结构化响应类型。 +type VersionInfo struct { + Version string `json:"version"` + BuildDate string `json:"buildDate"` + GoVersion string `json:"goVersion"` +} + +func main() { + versionCmd := &redant.Command{ + Use: "version", + Short: "返回版本信息(Unary ResponseHandler 示例)", + ResponseHandler: redant.Unary(func(ctx context.Context, inv *redant.Invocation) (VersionInfo, error) { + return VersionInfo{ + Version: "1.2.3", + BuildDate: time.Now().Format("2006-01-02"), + GoVersion: "go1.23", + }, nil + }), + } + + root := &redant.Command{ + Use: "unary-demo", + Short: "Unary 响应处理器示例", + Children: []*redant.Command{versionCmd}, + } + + // --- stdio 模式 --- + fmt.Println("=== stdio 模式(JSON 自动输出到 stdout)===") + inv := root.Invoke("version") + inv.Stdout = os.Stdout + inv.Stderr = os.Stderr + if err := inv.Run(); err != nil { + fmt.Fprintf(os.Stderr, "run failed: %v\n", err) + os.Exit(1) + } + fmt.Println() + + // --- callback 模式 --- + fmt.Println("=== callback 模式(RunCallback 泛型回调)===") + inv2 := root.Invoke("version") + inv2.Stdout = io.Discard + inv2.Stderr = io.Discard + if err := redant.RunCallback[VersionInfo](inv2, func(v VersionInfo) error { + fmt.Printf(" Version: %s\n", v.Version) + fmt.Printf(" BuildDate: %s\n", v.BuildDate) + fmt.Printf(" GoVersion: %s\n", v.GoVersion) + return nil + }); err != nil { + fmt.Fprintf(os.Stderr, "run failed: %v\n", err) + os.Exit(1) + } + fmt.Println("=== 完成 ===") +} diff --git a/handler.go b/handler.go index 18bf143..4d64ef9 100644 --- a/handler.go +++ b/handler.go @@ -1,6 +1,308 @@ package redant -import "context" +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "reflect" +) // HandlerFunc handles an Invocation of a command. type HandlerFunc func(ctx context.Context, inv *Invocation) error + +const defaultStreamResponseBuffer = 64 + +// StreamEnvelope is the NDJSON envelope written to stdout/stderr to distinguish +// structured response data from ordinary log output. +// +// Each envelope is a single JSON line: {"$":"resp","type":"T","data":...} +// Consumers can test for the "$" key to separate response payloads from +// plain-text log lines. +type StreamEnvelope struct { + Kind string `json:"$"` // "resp" or "error" + Type string `json:"type,omitempty"` // type name from TypeInfo + Data any `json:"data"` // payload +} + +// StreamError models structured stream errors. +type StreamError struct { + Code int `json:"code"` + Message string `json:"message"` + Details any `json:"details,omitempty"` +} + +// ResponseTypeInfo describes runtime-visible output type metadata. +type ResponseTypeInfo struct { + TypeName string `json:"typeName,omitempty"` + Schema string `json:"schema,omitempty"` +} + +// ResponseHandler models request-response unary handling. +type ResponseHandler interface { + Handle(context.Context, *Invocation) (any, error) + TypeInfo() ResponseTypeInfo +} + +// ResponseStreamHandler models request-response-stream handling. +type ResponseStreamHandler interface { + HandleStream(context.Context, *Invocation, *InvocationStream) error + TypeInfo() ResponseTypeInfo +} + +// InvocationStream provides response-stream communication. +// Response stream is internally created by invocation and automatically closed +// when response stream handling returns. +type InvocationStream struct { + ctx context.Context + inv *Invocation + ch chan any // captured at creation to avoid racing with closeResponseStream + typeName string // response type name for envelope output +} + +// NewInvocationStream creates a stream bound to invocation. +func NewInvocationStream(ctx context.Context, inv *Invocation) *InvocationStream { + var ch chan any + if inv != nil { + ch = inv.responseStream + } + return &InvocationStream{ + ctx: ctx, + inv: inv, + ch: ch, + } +} + +// Send emits data to the invocation-owned response stream (channel) and mirrors +// structured output to stdout/stderr as NDJSON envelopes. +// +// Channel consumers (RunCallback, ResponseStream, WebUI stream WS) receive +// raw data without envelopes. Only the stdout/stderr mirror uses envelopes +// so that consumers can distinguish response payloads from ordinary log lines. +func (s *InvocationStream) Send(data any) error { + if s.ch != nil { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + case s.ch <- data: + } + } + + if s.inv == nil { + return nil + } + + switch v := data.(type) { + case *StreamError: + if v != nil && s.inv.Stderr != nil { + return writeEnvelope(s.inv.Stderr, StreamEnvelope{Kind: "error", Data: v}) + } + case StreamError: + if s.inv.Stderr != nil { + return writeEnvelope(s.inv.Stderr, StreamEnvelope{Kind: "error", Data: v}) + } + default: + if s.inv.Stdout != nil { + return writeEnvelope(s.inv.Stdout, StreamEnvelope{ + Kind: "resp", + Type: s.typeName, + Data: v, + }) + } + } + return nil +} + +// TypedWriter writes typed payloads to InvocationStream. +type TypedWriter[T any] struct { + stream *InvocationStream +} + +// Send emits a typed value to the stream. +func (w *TypedWriter[T]) Send(v T) error { + return w.stream.Send(v) +} + +// Raw returns the underlying InvocationStream for advanced use. +func (w *TypedWriter[T]) Raw() *InvocationStream { + return w.stream +} + +// Unary adapts a typed unary function into a runtime ResponseHandler. +func Unary[T any](fn func(context.Context, *Invocation) (T, error)) ResponseHandler { + return unaryHandler[T]{fn: fn} +} + +type unaryHandler[T any] struct { + fn func(context.Context, *Invocation) (T, error) +} + +func (h unaryHandler[T]) Handle(ctx context.Context, inv *Invocation) (any, error) { + v, err := h.fn(ctx, inv) + if err != nil { + return nil, err + } + return v, nil +} + +func (h unaryHandler[T]) TypeInfo() ResponseTypeInfo { + return ResponseTypeInfo{TypeName: typeNameOf[T]()} +} + +// Stream adapts typed stream producer into runtime ResponseStreamHandler. +func Stream[T any](fn func(context.Context, *Invocation, *TypedWriter[T]) error) ResponseStreamHandler { + return streamHandler[T]{fn: fn} +} + +type streamHandler[T any] struct { + fn func(context.Context, *Invocation, *TypedWriter[T]) error +} + +func (h streamHandler[T]) HandleStream(ctx context.Context, inv *Invocation, stream *InvocationStream) error { + return h.fn(ctx, inv, &TypedWriter[T]{stream: stream}) +} + +func (h streamHandler[T]) TypeInfo() ResponseTypeInfo { + return ResponseTypeInfo{TypeName: typeNameOf[T]()} +} + +// adaptResponseHandler converts ResponseHandler to HandlerFunc. +func adaptResponseHandler(responseHandler ResponseHandler) HandlerFunc { + if responseHandler == nil { + return nil + } + + typeName := responseHandler.TypeInfo().TypeName + return func(ctx context.Context, inv *Invocation) error { + resp, err := responseHandler.Handle(ctx, inv) + if err != nil { + return fmt.Errorf("running response handler: %w", err) + } + inv.setResponse(resp) + return writeUnaryResponse(inv, resp, typeName) + } +} + +// adaptResponseStreamHandler converts ResponseStreamHandler to HandlerFunc. +func adaptResponseStreamHandler(responseStreamHandler ResponseStreamHandler) HandlerFunc { + if responseStreamHandler == nil { + return nil + } + + typeName := responseStreamHandler.TypeInfo().TypeName + return func(ctx context.Context, inv *Invocation) error { + stream := NewInvocationStream(ctx, inv) + stream.typeName = typeName + if err := responseStreamHandler.HandleStream(ctx, inv, stream); err != nil { + return fmt.Errorf("running response stream handler: %w", err) + } + return nil + } +} + +// RunCallback runs invocation via original Run and dispatches typed callback. +// +// Callback will be invoked in two cases: +// - unary response payload (from ResponseHandler) +// - stream data payload (from ResponseStreamHandler) +func RunCallback[T any](inv *Invocation, callback func(T) error) error { + if inv == nil { + return errors.New("nil invocation") + } + if callback == nil { + return errors.New("nil callback") + } + + runCtx, cancel := context.WithCancel(inv.Context()) + defer cancel() + runInv := inv.WithContext(runCtx) + + stream := runInv.ResponseStream() + consumeErrCh := make(chan error, 1) + go func() { + defer close(consumeErrCh) + for evt := range stream { + typed, ok := evt.(T) + if !ok { + consumeErrCh <- fmt.Errorf("typed stream data mismatch: got %T", evt) + cancel() + return + } + + if err := callback(typed); err != nil { + consumeErrCh <- err + cancel() + return + } + } + }() + + runErr := runInv.Run() + cancel() + + var consumeErr error + for err := range consumeErrCh { + if err != nil { + consumeErr = err + break + } + } + + if consumeErr != nil { + return errors.Join(runErr, consumeErr) + } + if runErr != nil { + return runErr + } + + resp, ok := runInv.Response() + if !ok { + return nil + } + + typed, ok := resp.(T) + if !ok { + return fmt.Errorf("typed response mismatch: got %T", resp) + } + + return callback(typed) +} + +func writeUnaryResponse(inv *Invocation, resp any, typeName string) error { + if inv == nil || inv.Stdout == nil || resp == nil { + return nil + } + + return writeEnvelope(inv.Stdout, StreamEnvelope{ + Kind: "resp", + Type: typeName, + Data: resp, + }) +} + +// writeEnvelope writes a single NDJSON envelope line to w. +func writeEnvelope(w io.Writer, env StreamEnvelope) error { + b, err := json.Marshal(env) + if err != nil { + _, werr := fmt.Fprintf(w, "%v", env.Data) + return errors.Join(err, werr) + } + b = append(b, '\n') + _, err = w.Write(b) + return err +} + +func typeNameOf[T any]() string { + t := reflect.TypeOf((*T)(nil)).Elem() + if t == nil { + return "unknown" + } + if t.Name() == "" { + return t.String() + } + if t.PkgPath() == "" { + return t.String() + } + return t.PkgPath() + "." + t.Name() +} diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..7c61157 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,293 @@ +package redant + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "slices" + "strings" + "testing" + "time" +) + +func TestRun_NoResponseHandler(t *testing.T) { + called := false + cmd := &Command{ + Use: "echo", + Handler: func(ctx context.Context, inv *Invocation) error { + called = true + _, _ = inv.Stdout.Write([]byte("ok")) + return nil + }, + } + + inv := cmd.Invoke() + inv.Stdout = io.Discard + inv.Stderr = io.Discard + + if err := inv.Run(); err != nil { + t.Fatalf("run failed: %v", err) + } + if !called { + t.Fatalf("expected handler called") + } +} + +func TestRunCallback_UnaryTyped(t *testing.T) { + type reply struct{ Message string } + + cmd := &Command{ + Use: "echo", + ResponseHandler: Unary(func(ctx context.Context, inv *Invocation) (reply, error) { + return reply{Message: "ok"}, nil + }), + } + + inv := cmd.Invoke() + inv.Stdout = io.Discard + inv.Stderr = io.Discard + + var got string + err := RunCallback[reply](inv, func(v reply) error { + got = v.Message + return nil + }) + if err != nil { + t.Fatalf("run callback failed: %v", err) + } + if got != "ok" { + t.Fatalf("got=%q, want=%q", got, "ok") + } +} + +func TestRunCallback_StreamTyped(t *testing.T) { + cmd := &Command{ + Use: "chat", + ResponseStreamHandler: Stream(func(ctx context.Context, inv *Invocation, out *TypedWriter[string]) error { + if err := out.Send("hello"); err != nil { + return err + } + return out.Send("world") + }), + } + + inv := cmd.Invoke() + inv.Stdout = io.Discard + inv.Stderr = io.Discard + + var got []string + err := RunCallback[string](inv, func(v string) error { + got = append(got, v) + return nil + }) + if err != nil { + t.Fatalf("run callback failed: %v", err) + } + + if !slices.Contains(got, "hello") { + t.Fatalf("missing hello: %v", got) + } + if !slices.Contains(got, "world") { + t.Fatalf("missing world: %v", got) + } +} + +func TestRunCallback_StreamTypeMismatch(t *testing.T) { + cmd := &Command{ + Use: "chat", + ResponseStreamHandler: Stream(func(ctx context.Context, inv *Invocation, out *TypedWriter[int]) error { + return out.Send(1) + }), + } + + inv := cmd.Invoke() + inv.Stdout = io.Discard + inv.Stderr = io.Discard + + err := RunCallback[string](inv, func(v string) error { + return nil + }) + if err == nil { + t.Fatalf("expected type mismatch error") + } +} + +func TestRunCallback_UnaryTypeMismatch(t *testing.T) { + cmd := &Command{ + Use: "echo", + ResponseHandler: Unary(func(ctx context.Context, inv *Invocation) (int, error) { + return 1, nil + }), + } + + inv := cmd.Invoke() + inv.Stdout = io.Discard + inv.Stderr = io.Discard + + err := RunCallback[string](inv, func(v string) error { + return nil + }) + if err == nil { + t.Fatalf("expected type mismatch error") + } +} + +func TestRunCallback_CallbackError(t *testing.T) { + wantErr := errors.New("stop") + cmd := &Command{ + Use: "chat", + ResponseStreamHandler: Stream(func(ctx context.Context, inv *Invocation, out *TypedWriter[string]) error { + return out.Send("hello") + }), + } + + inv := cmd.Invoke() + inv.Stdout = io.Discard + inv.Stderr = io.Discard + + err := RunCallback[string](inv, func(v string) error { + return wantErr + }) + if !errors.Is(err, wantErr) { + t.Fatalf("expected callback error, got %v", err) + } +} + +func TestResponseStreamHandlerFallsBackToStdIO(t *testing.T) { + cmd := &Command{ + Use: "chat", + ResponseStreamHandler: Stream(func(ctx context.Context, inv *Invocation, out *TypedWriter[string]) error { + if err := out.Send("phase:init"); err != nil { + return err + } + return out.Send("hello, redant") + }), + } + + var stdout bytes.Buffer + inv := cmd.Invoke() + inv.Stdin = bytes.NewBuffer(nil) + inv.Stdout = &stdout + inv.Stderr = io.Discard + + if err := inv.Run(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // stdout should contain NDJSON envelopes + lines := strings.Split(strings.TrimSpace(stdout.String()), "\n") + if len(lines) != 2 { + t.Fatalf("expected 2 NDJSON lines, got %d: %q", len(lines), stdout.String()) + } + + for i, line := range lines { + var env StreamEnvelope + if err := json.Unmarshal([]byte(line), &env); err != nil { + t.Fatalf("line %d: invalid NDJSON: %v", i, err) + } + if env.Kind != "resp" { + t.Fatalf("line %d: $.kind=%q, want \"resp\"", i, env.Kind) + } + } + + // verify first envelope data + var env0 StreamEnvelope + _ = json.Unmarshal([]byte(lines[0]), &env0) + if data, ok := env0.Data.(string); !ok || data != "phase:init" { + t.Fatalf("line 0: data=%v, want \"phase:init\"", env0.Data) + } +} + +func TestResponseStreamHandlerWithChannels(t *testing.T) { + cmd := &Command{ + Use: "chat", + ResponseStreamHandler: Stream(func(ctx context.Context, inv *Invocation, out *TypedWriter[string]) error { + return out.Send("echo:ping") + }), + } + + inv := cmd.Invoke() + inv.Stdout = io.Discard + inv.Stderr = io.Discard + inv.Stdin = bytes.NewBuffer(nil) + + out := inv.ResponseStream() + + if err := inv.Run(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var got string + for evt := range out { + if s, ok := evt.(string); ok { + got = s + break + } + } + + if got != "echo:ping" { + t.Fatalf("got = %q, want %q", got, "echo:ping") + } +} + +func TestInvocationRunClosesResponseStream(t *testing.T) { + cmd := &Command{ + Use: "chat", + ResponseStreamHandler: Stream(func(ctx context.Context, inv *Invocation, out *TypedWriter[string]) error { + return out.Send("done") + }), + } + + inv := cmd.Invoke() + inv.Stdout = io.Discard + inv.Stderr = io.Discard + inv.Stdin = bytes.NewBuffer(nil) + out := inv.ResponseStream() + + if err := inv.Run(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + count := 0 + for range out { + count++ + } + if count == 0 { + t.Fatalf("expected at least one stream response") + } +} + +func TestResponseStreamHandlerRunWithoutChannelConsumerDoesNotBlock(t *testing.T) { + cmd := &Command{ + Use: "chat", + ResponseStreamHandler: Stream(func(ctx context.Context, inv *Invocation, out *TypedWriter[string]) error { + for i := 0; i < defaultStreamResponseBuffer*4; i++ { + if err := out.Send("x"); err != nil { + return err + } + } + return nil + }), + } + + inv := cmd.Invoke() + inv.Stdout = io.Discard + inv.Stderr = io.Discard + inv.Stdin = bytes.NewBuffer(nil) + + done := make(chan error, 1) + go func() { + done <- inv.Run() + }() + + select { + case err := <-done: + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("run blocked without response stream consumer") + } +} diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go index 4580f7a..601894b 100644 --- a/internal/mcpserver/server.go +++ b/internal/mcpserver/server.go @@ -25,11 +25,13 @@ type Server struct { } type ToolInfo struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Path []string `json:"path"` - InputSchema map[string]any `json:"inputSchema"` - OutputSchema map[string]any `json:"outputSchema"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Path []string `json:"path"` + InputSchema map[string]any `json:"inputSchema"` + OutputSchema map[string]any `json:"outputSchema"` + SupportsStream bool `json:"supportsStream,omitempty"` + ResponseType *redant.ResponseTypeInfo `json:"responseType,omitempty"` } func ListToolInfos(root *redant.Command) []ToolInfo { @@ -37,11 +39,13 @@ func ListToolInfos(root *redant.Command) []ToolInfo { out := make([]ToolInfo, 0, len(defs)) for _, td := range defs { out = append(out, ToolInfo{ - Name: td.Name, - Description: td.Description, - Path: append([]string(nil), td.PathTokens...), - InputSchema: td.InputSchema, - OutputSchema: td.OutputSchema, + Name: td.Name, + Description: td.Description, + Path: append([]string(nil), td.PathTokens...), + InputSchema: td.InputSchema, + OutputSchema: td.OutputSchema, + SupportsStream: td.SupportsStream, + ResponseType: td.ResponseType, }) } return out diff --git a/internal/mcpserver/server_test.go b/internal/mcpserver/server_test.go index 0f92721..a74e30b 100644 --- a/internal/mcpserver/server_test.go +++ b/internal/mcpserver/server_test.go @@ -583,3 +583,238 @@ func prettyJSON(v any) string { } return string(b) } + +func TestCollectToolsIncludesStreamHandler(t *testing.T) { + root := &redant.Command{Use: "app"} + root.Children = append(root.Children, &redant.Command{ + Use: "chat", + Short: "stream chat", + ResponseStreamHandler: redant.Stream(func(ctx context.Context, inv *redant.Invocation, out *redant.TypedWriter[string]) error { + return out.Send("hello") + }), + }) + + s := New(root) + if len(s.tools) != 1 { + t.Fatalf("tools count = %d, want 1", len(s.tools)) + } + + tool := s.tools[0] + if tool.Name != "chat" { + t.Fatalf("tool name = %q, want %q", tool.Name, "chat") + } + if !tool.SupportsStream { + t.Fatalf("expected SupportsStream=true") + } + if tool.ResponseType == nil { + t.Fatalf("expected ResponseType to be set") + } + if tool.ResponseType.TypeName != "string" { + t.Fatalf("ResponseType.TypeName = %q, want string", tool.ResponseType.TypeName) + } + // Output schema should include response field for stream tools. + props, _ := tool.OutputSchema["properties"].(map[string]any) + if props == nil { + t.Fatalf("output schema properties missing") + } + respProp, ok := props["response"] + if !ok { + t.Fatalf("output schema should have response property for stream tool") + } + respMap, _ := respProp.(map[string]any) + if respMap["type"] != "array" { + t.Fatalf("response schema type = %v, want array for stream", respMap["type"]) + } +} + +func TestCollectToolsIncludesResponseHandler(t *testing.T) { + root := &redant.Command{Use: "app"} + root.Children = append(root.Children, &redant.Command{ + Use: "greet", + Short: "unary greet", + ResponseHandler: redant.Unary(func(ctx context.Context, inv *redant.Invocation) (string, error) { + return "hi", nil + }), + }) + + s := New(root) + if len(s.tools) != 1 { + t.Fatalf("tools count = %d, want 1", len(s.tools)) + } + + tool := s.tools[0] + if tool.Name != "greet" { + t.Fatalf("tool name = %q, want %q", tool.Name, "greet") + } + if tool.SupportsStream { + t.Fatalf("expected SupportsStream=false for ResponseHandler") + } + if tool.ResponseType == nil { + t.Fatalf("expected ResponseType to be set for ResponseHandler") + } + if tool.ResponseType.TypeName != "string" { + t.Fatalf("ResponseType.TypeName = %q, want string", tool.ResponseType.TypeName) + } + // Output schema should include response field (non-array) for unary. + props, _ := tool.OutputSchema["properties"].(map[string]any) + respProp, ok := props["response"] + if !ok { + t.Fatalf("output schema should have response property for unary tool") + } + respMap, _ := respProp.(map[string]any) + if _, hasType := respMap["type"]; hasType { + t.Fatalf("unary response schema should not have array type") + } +} + +func TestCallToolWithStreamHandler(t *testing.T) { + root := &redant.Command{Use: "app"} + root.Children = append(root.Children, &redant.Command{ + Use: "chat", + ResponseStreamHandler: redant.Stream(func(ctx context.Context, inv *redant.Invocation, out *redant.TypedWriter[string]) error { + if err := out.Send("chunk-1"); err != nil { + return err + } + return out.Send("chunk-2") + }), + }) + + s := New(root) + result, err := s.callTool(context.Background(), toolsCallParams{ + Name: "chat", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatalf("callTool error: %v", err) + } + + isError, _ := result["isError"].(bool) + if isError { + t.Fatalf("expected success result, got error") + } + + structured, ok := result["structuredContent"].(map[string]any) + if !ok { + t.Fatalf("structuredContent missing") + } + stdout, _ := structured["stdout"].(string) + if !strings.Contains(stdout, "chunk-1") || !strings.Contains(stdout, "chunk-2") { + t.Fatalf("stdout = %q, want contains chunk-1 and chunk-2", stdout) + } + // Verify typed response array is collected. + responses, ok := structured["response"].([]any) + if !ok { + t.Fatalf("structured response should be an array, got %T", structured["response"]) + } + if len(responses) != 2 { + t.Fatalf("expected 2 response chunks, got %d", len(responses)) + } + if responses[0] != "chunk-1" || responses[1] != "chunk-2" { + t.Fatalf("response = %v, want [chunk-1, chunk-2]", responses) + } +} + +func TestCallToolWithResponseHandler(t *testing.T) { + root := &redant.Command{Use: "app"} + root.Children = append(root.Children, &redant.Command{ + Use: "greet", + ResponseHandler: redant.Unary(func(ctx context.Context, inv *redant.Invocation) (string, error) { + return "hello-unary", nil + }), + }) + + s := New(root) + result, err := s.callTool(context.Background(), toolsCallParams{ + Name: "greet", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatalf("callTool error: %v", err) + } + + isError, _ := result["isError"].(bool) + if isError { + t.Fatalf("expected success result, got error") + } + + structured, ok := result["structuredContent"].(map[string]any) + if !ok { + t.Fatalf("structuredContent missing") + } + stdout, _ := structured["stdout"].(string) + if !strings.Contains(stdout, "hello-unary") { + t.Fatalf("stdout = %q, want contains hello-unary", stdout) + } + // Verify typed unary response is collected (single value, not array). + respVal, ok := structured["response"] + if !ok { + t.Fatalf("structured response should exist for unary handler") + } + if respVal != "hello-unary" { + t.Fatalf("response = %v, want hello-unary", respVal) + } +} + +func TestServeSDKClientCallStreamTool(t *testing.T) { + root := &redant.Command{Use: "app"} + root.Children = append(root.Children, &redant.Command{ + Use: "chat", + Short: "streaming chat tool", + ResponseStreamHandler: redant.Stream(func(ctx context.Context, inv *redant.Invocation, out *redant.TypedWriter[string]) error { + if err := out.Send("hello"); err != nil { + return err + } + return out.Send("world") + }), + }) + + srv := New(root) + serverTransport, clientTransport := mcp.NewInMemoryTransports() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverErrCh := make(chan error, 1) + go func() { + serverErrCh <- srv.server.Run(ctx, serverTransport) + }() + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "v1"}, nil) + session, err := client.Connect(ctx, clientTransport, nil) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer func() { _ = session.Close() }() + + listRes, err := session.ListTools(ctx, &mcp.ListToolsParams{}) + if err != nil { + t.Fatalf("list tools: %v", err) + } + if len(listRes.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(listRes.Tools)) + } + if listRes.Tools[0].Name != "chat" { + t.Fatalf("tool name = %q, want chat", listRes.Tools[0].Name) + } + + callRes, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "chat", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatalf("call tool: %v", err) + } + if callRes.IsError { + t.Fatalf("expected success, got error: %q", firstText(callRes.Content)) + } + + text := firstText(callRes.Content) + if !strings.Contains(text, "hello") || !strings.Contains(text, "world") { + t.Fatalf("content text = %q, want contains hello and world", text) + } + + cancel() + if err := <-serverErrCh; err != nil && !strings.Contains(err.Error(), "context canceled") { + t.Fatalf("server run error: %v", err) + } +} diff --git a/internal/mcpserver/tools.go b/internal/mcpserver/tools.go index 82dc221..c6a2893 100644 --- a/internal/mcpserver/tools.go +++ b/internal/mcpserver/tools.go @@ -14,13 +14,15 @@ import ( ) type toolDef struct { - Name string - Description string - PathTokens []string - Command *redant.Command - Options redant.OptionSet - InputSchema map[string]any - OutputSchema map[string]any + Name string + Description string + PathTokens []string + Command *redant.Command + Options redant.OptionSet + InputSchema map[string]any + OutputSchema map[string]any + SupportsStream bool + ResponseType *redant.ResponseTypeInfo } func collectTools(root *redant.Command) []toolDef { @@ -39,15 +41,25 @@ func collectTools(root *redant.Command) []toolDef { effectiveOptions = append(effectiveOptions, inheritedOptions...) effectiveOptions = append(effectiveOptions, cmd.Options...) - if cmd.Handler != nil { + if cmd.Handler != nil || cmd.ResponseHandler != nil || cmd.ResponseStreamHandler != nil { + var respType *redant.ResponseTypeInfo + if cmd.ResponseHandler != nil { + ti := cmd.ResponseHandler.TypeInfo() + respType = &ti + } else if cmd.ResponseStreamHandler != nil { + ti := cmd.ResponseStreamHandler.TypeInfo() + respType = &ti + } tools = append(tools, toolDef{ - Name: strings.Join(path, "."), - Description: commandDescription(cmd), - PathTokens: append([]string(nil), path...), - Command: cmd, - Options: append(redant.OptionSet(nil), effectiveOptions...), - InputSchema: buildInputSchema(cmd.Args, effectiveOptions), - OutputSchema: buildOutputSchema(), + Name: strings.Join(path, "."), + Description: commandDescription(cmd), + PathTokens: append([]string(nil), path...), + Command: cmd, + Options: append(redant.OptionSet(nil), effectiveOptions...), + InputSchema: buildInputSchema(cmd.Args, effectiveOptions), + OutputSchema: buildOutputSchema(respType, cmd.ResponseStreamHandler != nil), + SupportsStream: cmd.ResponseStreamHandler != nil, + ResponseType: respType, }) } @@ -104,28 +116,43 @@ func buildInputSchema(args redant.ArgSet, options redant.OptionSet) map[string]a return schema } -func buildOutputSchema() map[string]any { +func buildOutputSchema(respType *redant.ResponseTypeInfo, isStream bool) map[string]any { + props := map[string]any{ + "ok": map[string]any{ + "type": "boolean", + }, + "stdout": map[string]any{ + "type": "string", + }, + "stderr": map[string]any{ + "type": "string", + }, + "error": map[string]any{ + "type": "string", + }, + "combined": map[string]any{ + "type": "string", + }, + } + required := []string{"ok", "stdout", "stderr", "error", "combined"} + + if respType != nil && respType.TypeName != "" { + respSchema := map[string]any{ + "description": "typed response payload (" + respType.TypeName + ")", + "x-redant-type": respType.TypeName, + } + if isStream { + respSchema["type"] = "array" + respSchema["items"] = map[string]any{} + } + props["response"] = respSchema + } + return map[string]any{ "type": "object", "additionalProperties": false, - "properties": map[string]any{ - "ok": map[string]any{ - "type": "boolean", - }, - "stdout": map[string]any{ - "type": "string", - }, - "stderr": map[string]any{ - "type": "string", - }, - "error": map[string]any{ - "type": "string", - }, - "combined": map[string]any{ - "type": "string", - }, - }, - "required": []string{"ok", "stdout", "stderr", "error", "combined"}, + "properties": props, + "required": required, } } @@ -300,6 +327,24 @@ func (s *Server) callTool(ctx context.Context, params toolsCallParams) (map[stri inv.Stderr = &stderr inv.Stdin = bytes.NewReader(nil) + // For commands with typed response, use RunCallback to collect structured data. + if tool.ResponseType != nil { + var responses []any + runErr := redant.RunCallback[any](inv.WithContext(ctx), func(v any) error { + responses = append(responses, v) + return nil + }) + result := buildToolResult(stdout.String(), stderr.String(), runErr) + if structured, ok := result["structuredContent"].(map[string]any); ok && len(responses) > 0 { + if tool.SupportsStream { + structured["response"] = responses + } else { + structured["response"] = responses[0] + } + } + return result, nil + } + runErr := inv.WithContext(ctx).Run() return buildToolResult(stdout.String(), stderr.String(), runErr), nil } diff --git a/internal/webui/server.go b/internal/webui/server.go index 2b2108f..420625d 100644 --- a/internal/webui/server.go +++ b/internal/webui/server.go @@ -49,18 +49,19 @@ type FlagMeta struct { } type CommandMeta struct { - ID string `json:"id"` - Name string `json:"name"` - Use string `json:"use"` - Aliases []string `json:"aliases,omitempty"` - Short string `json:"short,omitempty"` - Long string `json:"long,omitempty"` - Deprecated string `json:"deprecated,omitempty"` - RawArgs bool `json:"rawArgs"` - Path []string `json:"path"` - Description string `json:"description,omitempty"` - Flags []FlagMeta `json:"flags"` - Args []ArgMeta `json:"args"` + ID string `json:"id"` + Name string `json:"name"` + Use string `json:"use"` + Aliases []string `json:"aliases,omitempty"` + Short string `json:"short,omitempty"` + Long string `json:"long,omitempty"` + Deprecated string `json:"deprecated,omitempty"` + RawArgs bool `json:"rawArgs"` + Path []string `json:"path"` + Description string `json:"description,omitempty"` + Flags []FlagMeta `json:"flags"` + Args []ArgMeta `json:"args"` + SupportsStream bool `json:"supportsStream,omitempty"` } type RunRequest struct { @@ -92,6 +93,7 @@ type commandListResponse struct { type wsRunMessage struct { Type string `json:"type"` Request *RunRequest `json:"request,omitempty"` + Event any `json:"event,omitempty"` Data string `json:"data,omitempty"` Rows int `json:"rows,omitempty"` Cols int `json:"cols,omitempty"` @@ -133,10 +135,138 @@ func (a *App) Handler() http.Handler { mux.HandleFunc("/api/commands", a.handleCommands) mux.HandleFunc("/api/run", a.handleRun) mux.HandleFunc("/api/run/ws", a.handleRunWS) + mux.HandleFunc("/api/run/stream/ws", a.handleRunStreamWS) mux.HandleFunc("/api/terminal/ws", a.handleTerminalWS) return mux } +func (a *App) handleRunStreamWS(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() + + var sendMu sync.Mutex + send := func(msg wsRunMessage) error { + sendMu.Lock() + defer sendMu.Unlock() + + writeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return wsjson.Write(writeCtx, conn, msg) + } + + startCtx, cancelStart := context.WithTimeout(r.Context(), wsStartTimeout) + defer cancelStart() + + var start wsRunMessage + if err := wsjson.Read(startCtx, conn, &start); err != nil { + _ = send(wsRunMessage{Type: "error", Error: fmt.Sprintf("read start message failed: %v", err)}) + _ = conn.Close(websocket.StatusPolicyViolation, "invalid start message") + return + } + + if start.Type != "start" || start.Request == nil { + _ = send(wsRunMessage{Type: "error", Error: "first message must be {type:start, request:{...}}"}) + _ = conn.Close(websocket.StatusPolicyViolation, "missing start request") + return + } + + req := *start.Request + meta, ok := a.byID[strings.TrimSpace(req.Command)] + if !ok { + _ = send(wsRunMessage{Type: "error", Error: fmt.Sprintf("unknown command: %q", req.Command)}) + _ = conn.Close(websocket.StatusPolicyViolation, "unknown command") + return + } + + argv, program, invocation, err := buildInvocation(meta, req) + if err != nil { + _ = send(wsRunMessage{Type: "error", Error: err.Error()}) + _ = conn.Close(websocket.StatusPolicyViolation, "invalid invocation") + return + } + + runCtx, cancelRun := context.WithTimeout(r.Context(), resolveRunTimeout(req.TimeoutSeconds)) + defer cancelRun() + + if err := send(wsRunMessage{ + Type: "started", + Command: meta.ID, + Program: program, + Argv: append([]string(nil), argv...), + Invocation: invocation, + }); err != nil { + return + } + + root := cloneCommandTree(a.root) + inv := root.Invoke(argv...) + var stdout bytes.Buffer + var stderr bytes.Buffer + inv.Stdout = &stdout + inv.Stderr = &stderr + inv.Stdin = bytes.NewReader([]byte(req.Stdin)) + inv = inv.WithContext(runCtx) + stream := inv.ResponseStream() + + streamErrCh := make(chan error, 1) + go func() { + defer close(streamErrCh) + for evt := range stream { + if err := send(wsRunMessage{Type: "stream", Event: evt}); err != nil { + streamErrCh <- err + cancelRun() + return + } + } + }() + + runErrCh := make(chan error, 1) + go func() { + a.mu.Lock() + defer a.mu.Unlock() + runErrCh <- inv.Run() + }() + + runErr := <-runErrCh + var streamErr error + for err := range streamErrCh { + if err != nil { + streamErr = err + break + } + } + runErr = errors.Join(runErr, streamErr) + timedOut := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(runCtx.Err(), context.DeadlineExceeded) + displayErr := withInteractiveWSTimeoutHint(runErr, timedOut) + + resp := RunResponse{ + OK: displayErr == nil, + TimedOut: timedOut, + Command: meta.ID, + Program: program, + Argv: append([]string(nil), argv...), + Invocation: invocation, + Stdout: stdout.String(), + Stderr: stderr.String(), + Combined: combineOutput(stdout.String(), stderr.String(), displayErr), + } + if displayErr != nil { + resp.Error = displayErr.Error() + } + + _ = send(wsRunMessage{Type: "result", OK: resp.OK, TimedOut: resp.TimedOut, Error: resp.Error, Data: resp.Combined, Command: resp.Command, Program: resp.Program, Argv: resp.Argv, Invocation: resp.Invocation}) + if displayErr != nil { + _ = conn.Close(websocket.StatusInternalError, "command failed") + } else { + _ = conn.Close(websocket.StatusNormalClosure, "command completed") + } +} + func (a *App) handleIndex(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") _, _ = w.Write([]byte(indexHTML)) @@ -1101,7 +1231,7 @@ func collectCommands(root *redant.Command) []CommandMeta { effective = append(effective, inherited...) effective = append(effective, cmd.Options...) - if cmd.Handler != nil && len(path) > 0 && path[0] != "web" { + if (cmd.Handler != nil || cmd.ResponseHandler != nil || cmd.ResponseStreamHandler != nil) && len(path) > 0 && path[0] != "web" { out = append(out, toCommandMeta(cmd, path, effective)) } @@ -1120,18 +1250,19 @@ func collectCommands(root *redant.Command) []CommandMeta { func toCommandMeta(cmd *redant.Command, path []string, opts redant.OptionSet) CommandMeta { return CommandMeta{ - ID: strings.Join(path, " "), - Name: cmd.Name(), - Use: cmd.Use, - Aliases: append([]string(nil), cmd.Aliases...), - Short: strings.TrimSpace(cmd.Short), - Long: strings.TrimSpace(cmd.Long), - Deprecated: strings.TrimSpace(cmd.Deprecated), - RawArgs: cmd.RawArgs, - Path: append([]string(nil), path...), - Description: commandDescription(cmd), - Flags: toFlagMeta(opts), - Args: toArgMeta(cmd.Args), + ID: strings.Join(path, " "), + Name: cmd.Name(), + Use: cmd.Use, + Aliases: append([]string(nil), cmd.Aliases...), + Short: strings.TrimSpace(cmd.Short), + Long: strings.TrimSpace(cmd.Long), + Deprecated: strings.TrimSpace(cmd.Deprecated), + RawArgs: cmd.RawArgs, + Path: append([]string(nil), path...), + Description: commandDescription(cmd), + Flags: toFlagMeta(opts), + Args: toArgMeta(cmd.Args), + SupportsStream: cmd.ResponseStreamHandler != nil, } } diff --git a/internal/webui/server_test.go b/internal/webui/server_test.go index a24aa9f..13f18f9 100644 --- a/internal/webui/server_test.go +++ b/internal/webui/server_test.go @@ -117,6 +117,46 @@ func TestCommandsEndpoint(t *testing.T) { } } +func TestCommandsEndpointIncludesStreamMetadata(t *testing.T) { + root := &redant.Command{Use: "testapp"} + root.Children = append(root.Children, &redant.Command{ + Use: "chat", + ResponseStreamHandler: redant.Stream(func(ctx context.Context, inv *redant.Invocation, out *redant.TypedWriter[string]) error { + return out.Send("hello") + }), + }) + + ts := httptest.NewServer(New(root).Handler()) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/api/commands") + if err != nil { + t.Fatalf("request commands: %v", err) + } + defer closeResponseBody(t, resp) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + var payload commandListResponse + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + t.Fatalf("decode commands response: %v", err) + } + + if len(payload.Commands) != 1 { + t.Fatalf("expected 1 command, got %d", len(payload.Commands)) + } + + cmd := payload.Commands[0] + if cmd.ID != "chat" { + t.Fatalf("expected chat command, got %s", cmd.ID) + } + if !cmd.SupportsStream { + t.Fatalf("expected supportsStream=true") + } +} + func TestIndexPageServedFromStatic(t *testing.T) { root := &redant.Command{Use: "testapp"} root.Children = append(root.Children, &redant.Command{Use: "echo", Handler: func(ctx context.Context, inv *redant.Invocation) error { @@ -463,6 +503,66 @@ func TestRunWSEndpointInteractive(t *testing.T) { } } +func TestRunStreamWSEndpoint(t *testing.T) { + root := &redant.Command{Use: "testapp"} + root.Children = append(root.Children, &redant.Command{ + Use: "chat", + ResponseStreamHandler: redant.Stream(func(ctx context.Context, inv *redant.Invocation, out *redant.TypedWriter[string]) error { + if err := out.Send("hello"); err != nil { + return err + } + return out.Send("done") + }), + }) + + ts := httptest.NewServer(New(root).Handler()) + defer ts.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/run/stream/ws" + conn, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { _ = conn.Close(websocket.StatusNormalClosure, "done") }() + + if err := wsjson.Write(ctx, conn, wsRunMessage{Type: "start", Request: &RunRequest{Command: "chat"}}); err != nil { + t.Fatalf("write start message: %v", err) + } + + sawStarted := false + sawStream := false + sawResult := false + + for !sawResult { + var msg wsRunMessage + if err := wsjson.Read(ctx, conn, &msg); err != nil { + t.Fatalf("read ws message: %v", err) + } + + switch msg.Type { + case "started": + sawStarted = true + case "stream": + sawStream = true + case "result": + sawResult = true + if !msg.OK { + t.Fatalf("expected ok result, got error=%s", msg.Error) + } + } + } + + if !sawStarted { + t.Fatalf("expected started message") + } + if !sawStream { + t.Fatalf("expected at least one stream message") + } +} + func TestTerminalWSEndpointStartAndClose(t *testing.T) { root := &redant.Command{Use: "testapp"} root.Children = append(root.Children, &redant.Command{ diff --git a/internal/webui/static/index.html b/internal/webui/static/index.html index 2557dac..a8bf86a 100644 --- a/internal/webui/static/index.html +++ b/internal/webui/static/index.html @@ -342,12 +342,21 @@

Args

一次性运行(HTTP)

- +
+ + +
@@ -367,6 +376,8 @@

一次性运行(HTTP)

会展示 curl 风格请求与完整 CLI 调用过程
+
流式运行会连接 + /api/run/stream/ws,实时展示结构化事件(stream/result)。
@@ -375,7 +386,8 @@

可交互运行(WebSocket + PTY)