From 6c3c0b34b185ca0811f33a7d17c1076324afee13 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 17 Feb 2026 08:03:47 +0000 Subject: [PATCH] Added gRPC server streaming support for desktop/Wails MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Added ServerStream method to Go gRPC client using raw byte codec and grpc.NewStream - Added TargetServerStream and CancelStream Wails bindings that emit response chunks via Wails events - Implemented serverStreaming() in WailsTransport using RpcOutputStreamController and EventsOn - Extended Method interface with serverStreaming/clientStreaming flags from protobuf-ts MethodInfo - Updated client.ts to handle server streaming via async iteration of responses - Added streaming UI: status indicator (◉), message count, stream progress bar in Console - Added "stream" label on streaming methods in Sidebar TreeView https://claude.ai/code/session_01EDaG2s7TQn7gmdBzRo6sDK --- desktop/main.go | 60 +++++++++++++++++++ server/pkg/grpc/client.go | 78 +++++++++++++++++++++++++ ui/src/Console.tsx | 53 +++++++++++++++-- ui/src/Sidebar.tsx | 5 ++ ui/src/client.ts | 69 ++++++++++++++++------ ui/src/kaja.ts | 2 + ui/src/project.ts | 2 + ui/src/projectLoader.ts | 2 + ui/src/server/wails-transport.ts | 98 ++++++++++++++++++++++++++++++-- ui/src/wailsjs/go/main/App.d.ts | 4 ++ ui/src/wailsjs/go/main/App.js | 8 +++ 11 files changed, 355 insertions(+), 26 deletions(-) diff --git a/desktop/main.go b/desktop/main.go index 3c39ce67..d70efb5e 100644 --- a/desktop/main.go +++ b/desktop/main.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "embed" + "encoding/base64" "encoding/json" "fmt" "log/slog" @@ -12,6 +13,7 @@ import ( "os" "path/filepath" "strings" + "sync" "time" "github.com/wailsapp/wails/v2" @@ -58,6 +60,7 @@ type App struct { ctx context.Context twirpHandler api.TwirpServer configurationWatcher *api.ConfigurationWatcher + activeStreams sync.Map // streamID -> context.CancelFunc } // NewApp creates a new App application struct @@ -308,6 +311,63 @@ func (a *App) targetTwirp(target string, method string, req []byte, headers map[ }, nil } +// TargetServerStream starts a server-streaming gRPC call. +// Each response message is emitted as a Wails event "stream:" with base64-encoded body. +// When the stream ends, "stream::end" is emitted. +// On error, "stream::error" is emitted with the error message. +func (a *App) TargetServerStream(target string, method string, req []byte, headersJson string, streamID string) error { + slog.Info("TargetServerStream called", "target", target, "method", method, "streamID", streamID) + + if req == nil { + return fmt.Errorf("nil request") + } + + headers := make(map[string]string) + if headersJson != "" && headersJson != "{}" { + if err := json.Unmarshal([]byte(headersJson), &headers); err != nil { + return fmt.Errorf("failed to parse headers: %w", err) + } + } + + client, err := grpc.NewClientFromString(target) + if err != nil { + return fmt.Errorf("failed to create gRPC client: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + a.activeStreams.Store(streamID, cancel) + + messages, errc := client.ServerStream(ctx, method, req, headers) + + go func() { + defer cancel() + defer a.activeStreams.Delete(streamID) + + for msg := range messages { + encoded := base64.StdEncoding.EncodeToString(msg) + runtime.EventsEmit(a.ctx, "stream:"+streamID, encoded) + } + + if err := <-errc; err != nil { + slog.Error("Server stream error", "streamID", streamID, "error", err) + runtime.EventsEmit(a.ctx, "stream:"+streamID+":error", err.Error()) + } else { + runtime.EventsEmit(a.ctx, "stream:"+streamID+":end") + } + }() + + return nil +} + +// CancelStream cancels an active streaming call. +func (a *App) CancelStream(streamID string) error { + if cancel, ok := a.activeStreams.LoadAndDelete(streamID); ok { + cancel.(context.CancelFunc)() + return nil + } + return fmt.Errorf("stream not found: %s", streamID) +} + func main() { // Get user's home directory and use ~/.kaja for config homeDir, err := os.UserHomeDir() diff --git a/server/pkg/grpc/client.go b/server/pkg/grpc/client.go index d504868c..2b8e14b1 100644 --- a/server/pkg/grpc/client.go +++ b/server/pkg/grpc/client.go @@ -4,7 +4,9 @@ package grpc import ( "context" "crypto/tls" + "errors" "fmt" + "io" "net/url" "strings" "time" @@ -167,3 +169,79 @@ func (c *Client) InvokeWithTimeout(method string, request []byte, timeout time.D defer cancel() return c.Invoke(ctx, method, request, headers) } + +// ServerStream opens a server-streaming gRPC call. +// It sends a single request and returns a channel that yields response messages. +// The channel is closed when the stream ends. Errors are sent on the error channel. +func (c *Client) ServerStream(ctx context.Context, method string, request []byte, headers map[string]string) (<-chan []byte, <-chan error) { + messages := make(chan []byte, 16) + errc := make(chan error, 1) + + go func() { + defer close(messages) + defer close(errc) + + if !strings.HasPrefix(method, "/") { + method = "/" + method + } + + var creds credentials.TransportCredentials + if c.useTLS { + creds = credentials.NewTLS(&tls.Config{}) + } else { + creds = insecure.NewCredentials() + } + + codec := &grpcCodec{} + conn, err := grpc.NewClient(c.target, grpc.WithTransportCredentials(creds), grpc.WithDefaultCallOptions(grpc.ForceCodec(codec))) + if err != nil { + errc <- fmt.Errorf("failed to create gRPC client: %w", err) + return + } + defer conn.Close() + + if len(headers) > 0 { + md := metadata.New(headers) + ctx = metadata.NewOutgoingContext(ctx, md) + } + + streamDesc := &grpc.StreamDesc{ + ServerStreams: true, + } + + stream, err := conn.NewStream(ctx, streamDesc, method) + if err != nil { + errc <- fmt.Errorf("failed to open stream: %w", err) + return + } + + if err := stream.SendMsg(request); err != nil { + errc <- fmt.Errorf("failed to send request: %w", err) + return + } + + if err := stream.CloseSend(); err != nil { + errc <- fmt.Errorf("failed to close send: %w", err) + return + } + + for { + var response []byte + err := stream.RecvMsg(&response) + if err != nil { + if errors.Is(err, io.EOF) || ctx.Err() != nil { + return + } + errc <- fmt.Errorf("stream receive failed: %w", err) + return + } + select { + case messages <- response: + case <-ctx.Done(): + return + } + } + }() + + return messages, errc +} diff --git a/ui/src/Console.tsx b/ui/src/Console.tsx index 50aec560..d4ada496 100644 --- a/ui/src/Console.tsx +++ b/ui/src/Console.tsx @@ -267,16 +267,21 @@ interface MethodCallRowProps { } Console.MethodCallRow = function ({ methodCall, isSelected, onClick, now }: MethodCallRowProps) { - const status = methodCall.error ? "error" : methodCall.output ? "success" : "pending"; + const isStreaming = methodCall.streamOutputs !== undefined; + const isStreamActive = isStreaming && !methodCall.streamComplete && !methodCall.error; + + const status = methodCall.error ? "error" : isStreamActive ? "streaming" : methodCall.output ? "success" : "pending"; const statusColor = { pending: "var(--fgColor-muted)", + streaming: "var(--fgColor-accent)", success: "var(--fgColor-success)", error: "var(--fgColor-danger)", }[status]; const statusIcon = { pending: "○", + streaming: "◉", success: "●", error: "●", }[status]; @@ -295,6 +300,18 @@ Console.MethodCallRow = function ({ methodCall, isSelected, onClick, now }: Meth > {methodId(methodCall.service, methodCall.method)} + {isStreaming && methodCall.streamOutputs!.length > 0 && ( + + {methodCall.streamOutputs!.length} + + )}
onTabChange("request")}> @@ -330,7 +350,7 @@ Console.DetailTabs = function ({ methodCall, activeTab, onTabChange }: DetailTab color: methodCall.error ? "var(--fgColor-danger)" : activeTab === "response" ? "var(--fgColor-default)" : "var(--fgColor-muted)", }} > - Response + Response{isStreaming && streamCount > 0 ? ` (${streamCount})` : ""}
onTabChange("headers")}> Headers @@ -348,7 +368,8 @@ interface DetailContentProps { } Console.DetailContent = function ({ methodCall, activeTab, onTabChange, colorMode = "night", jsonViewerRef }: DetailContentProps) { - const hasResponse = methodCall.output !== undefined || methodCall.error !== undefined; + const isStreaming = methodCall.streamOutputs !== undefined; + const hasResponse = methodCall.output !== undefined || methodCall.error !== undefined || (isStreaming && methodCall.streamOutputs!.length > 0); const hasError = methodCall.error !== undefined; // Switch to response tab when response arrives @@ -362,7 +383,16 @@ Console.DetailContent = function ({ methodCall, activeTab, onTabChange, colorMod return ; } - const content = activeTab === "request" ? methodCall.input : methodCall.error || methodCall.output; + let content; + if (activeTab === "request") { + content = methodCall.input; + } else if (hasError) { + content = methodCall.error; + } else if (isStreaming) { + content = methodCall.streamOutputs; + } else { + content = methodCall.output; + } return (
)} + {activeTab === "response" && isStreaming && !hasError && ( +
+ {methodCall.streamComplete + ? `Stream complete - ${methodCall.streamOutputs!.length} message${methodCall.streamOutputs!.length !== 1 ? "s" : ""}` + : `Streaming - ${methodCall.streamOutputs!.length} message${methodCall.streamOutputs!.length !== 1 ? "s" : ""} received...`} +
+ )} )} diff --git a/ui/src/Sidebar.tsx b/ui/src/Sidebar.tsx index aadc26ad..370fd5c5 100644 --- a/ui/src/Sidebar.tsx +++ b/ui/src/Sidebar.tsx @@ -425,6 +425,11 @@ export function Sidebar({ current={currentMethod === method} > {method.name} + {method.serverStreaming && ( + + stream + + )} ); })} diff --git a/ui/src/client.ts b/ui/src/client.ts index 5bfaf655..57a3ea84 100644 --- a/ui/src/client.ts +++ b/ui/src/client.ts @@ -1,5 +1,5 @@ import { GrpcWebFetchTransport } from "@protobuf-ts/grpcweb-transport"; -import { RpcOptions, UnaryCall } from "@protobuf-ts/runtime-rpc"; +import type { RpcOptions, ServerStreamingCall, UnaryCall } from "@protobuf-ts/runtime-rpc"; import { TwirpFetchTransport } from "@protobuf-ts/twirp-transport"; import { MethodCall } from "./kaja"; import { Client, ProjectRef, Service, serviceId } from "./project"; @@ -64,6 +64,8 @@ export function createClient(service: Service, stub: Stub, projectRef: ProjectRe }; for (const method of service.methods) { + const isServerStreaming = method.serverStreaming && !method.clientStreaming; + client.methods[method.name] = async (input: any) => { // Capture request headers from projectRef at request time const requestHeaders: { [key: string]: string } = { ...(projectRef.configuration.headers || {}) }; @@ -81,26 +83,57 @@ export function createClient(service: Service, stub: Stub, projectRef: ProjectRe try { const call = clientStub[lcfirst(method.name)](input, options); - const [response, headers, trailers] = await Promise.all([call.response, call.headers, call.trailers]); - methodCall.output = response; - methodCall.inputTypeName = call.method?.I?.typeName; - methodCall.inputType = call.method?.I; - methodCall.outputTypeName = call.method?.O?.typeName; - methodCall.outputType = call.method?.O; - - // Capture response headers and trailers - const responseHeaders: { [key: string]: string } = {}; - if (headers) { - for (const [key, value] of Object.entries(headers)) { - responseHeaders[key] = String(value); + + if (isServerStreaming) { + const streamCall = call as ServerStreamingCall; + methodCall.inputTypeName = streamCall.method?.I?.typeName; + methodCall.inputType = streamCall.method?.I; + methodCall.outputTypeName = streamCall.method?.O?.typeName; + methodCall.outputType = streamCall.method?.O; + methodCall.streamOutputs = []; + + for await (const message of streamCall.responses) { + methodCall.streamOutputs.push(message); + methodCall.output = message; + client.kaja?._internal.methodCallUpdate(methodCall); } - } - if (trailers) { - for (const [key, value] of Object.entries(trailers)) { - responseHeaders[key] = String(value); + methodCall.streamComplete = true; + + const [headers, trailers] = await Promise.all([streamCall.headers, streamCall.trailers]); + const responseHeaders: { [key: string]: string } = {}; + if (headers) { + for (const [key, value] of Object.entries(headers)) { + responseHeaders[key] = String(value); + } + } + if (trailers) { + for (const [key, value] of Object.entries(trailers)) { + responseHeaders[key] = String(value); + } + } + methodCall.responseHeaders = responseHeaders; + } else { + const [response, headers, trailers] = await Promise.all([call.response, call.headers, call.trailers]); + methodCall.output = response; + methodCall.inputTypeName = call.method?.I?.typeName; + methodCall.inputType = call.method?.I; + methodCall.outputTypeName = call.method?.O?.typeName; + methodCall.outputType = call.method?.O; + + // Capture response headers and trailers + const responseHeaders: { [key: string]: string } = {}; + if (headers) { + for (const [key, value] of Object.entries(headers)) { + responseHeaders[key] = String(value); + } + } + if (trailers) { + for (const [key, value] of Object.entries(trailers)) { + responseHeaders[key] = String(value); + } } + methodCall.responseHeaders = responseHeaders; } - methodCall.responseHeaders = responseHeaders; } catch (error: any) { methodCall.error = serializeError(error); } diff --git a/ui/src/kaja.ts b/ui/src/kaja.ts index 4f39bd83..7d3bdc6f 100644 --- a/ui/src/kaja.ts +++ b/ui/src/kaja.ts @@ -24,6 +24,8 @@ export interface MethodCall { output?: any; outputTypeName?: string; outputType?: IMessageType; + streamOutputs?: any[]; + streamComplete?: boolean; error?: any; requestHeaders?: MethodCallHeaders; responseHeaders?: MethodCallHeaders; diff --git a/ui/src/project.ts b/ui/src/project.ts index b9e8961d..12fbdb02 100644 --- a/ui/src/project.ts +++ b/ui/src/project.ts @@ -48,6 +48,8 @@ export interface Service { export interface Method { name: string; + serverStreaming?: boolean; + clientStreaming?: boolean; } export interface Clients { diff --git a/ui/src/projectLoader.ts b/ui/src/projectLoader.ts index 1c4bf5fe..e4690c86 100644 --- a/ui/src/projectLoader.ts +++ b/ui/src/projectLoader.ts @@ -47,6 +47,8 @@ export async function loadProject(apiSources: ApiSource[], stubCode: string, con serviceInfo.methods.forEach((methodInfo) => { methods.push({ name: methodInfo.name, + serverStreaming: methodInfo.serverStreaming, + clientStreaming: methodInfo.clientStreaming, }); }); // Extract package name from typeName (e.g., "quirks.v1.Quirks" -> "quirks.v1") diff --git a/ui/src/server/wails-transport.ts b/ui/src/server/wails-transport.ts index b5d1cd0f..cb4ea9ec 100644 --- a/ui/src/server/wails-transport.ts +++ b/ui/src/server/wails-transport.ts @@ -6,11 +6,16 @@ import type { RpcOptions, RpcStatus, RpcTransport, - ServerStreamingCall, UnaryCall, } from "@protobuf-ts/runtime-rpc"; -import { UnaryCall as UnaryCallImpl } from "@protobuf-ts/runtime-rpc"; -import { Twirp, Target } from "../wailsjs/go/main/App"; +import { + Deferred, + RpcOutputStreamController, + ServerStreamingCall, + UnaryCall as UnaryCallImpl, +} from "@protobuf-ts/runtime-rpc"; +import { Twirp, Target, TargetServerStream, CancelStream } from "../wailsjs/go/main/App"; +import { EventsOn } from "../wailsjs/runtime"; import { RpcProtocol } from "./api"; import { ProjectRef } from "../project"; @@ -57,7 +62,92 @@ export class WailsTransport implements RpcTransport { } serverStreaming(method: MethodInfo, input: I, options: RpcOptions): ServerStreamingCall { - throw new Error(`Server streaming not supported in Wails ${this.mode} transport`); + if (this.mode !== "target" || this.protocol !== RpcProtocol.GRPC) { + throw new Error(`Server streaming only supported for gRPC targets in Wails transport`); + } + + const streamID = crypto.randomUUID(); + const responseStream = new RpcOutputStreamController(); + const headersDeferred = new Deferred(); + const statusDeferred = new Deferred(); + const trailersDeferred = new Deferred(); + + // Resolve headers immediately (gRPC headers arrive before messages, but we don't capture them yet) + headersDeferred.resolve({}); + + const unsubscribers: (() => void)[] = []; + + const cleanup = () => { + for (const unsub of unsubscribers) { + unsub(); + } + }; + + // Listen for streamed response messages + unsubscribers.push( + EventsOn("stream:" + streamID, (base64Data: string) => { + try { + const responseBytes = Uint8Array.from(atob(base64Data), (c) => c.charCodeAt(0)); + const message = method.O.fromBinary(responseBytes); + responseStream.notifyMessage(message); + } catch (err) { + responseStream.notifyError(err instanceof Error ? err : new Error(String(err))); + cleanup(); + } + }), + ); + + // Listen for stream end + unsubscribers.push( + EventsOn("stream:" + streamID + ":end", () => { + responseStream.notifyComplete(); + statusDeferred.resolve({ code: "OK", detail: "" }); + trailersDeferred.resolve({}); + cleanup(); + }), + ); + + // Listen for stream error + unsubscribers.push( + EventsOn("stream:" + streamID + ":error", (errorMessage: string) => { + const err = new Error(errorMessage); + responseStream.notifyError(err); + statusDeferred.reject(err); + trailersDeferred.reject(err); + cleanup(); + }), + ); + + // Start the stream + const inputBytes = method.I.toBinary(input, { writeUnknownFields: false }); + const inputArray = Array.from(inputBytes); + const fullMethodPath = `${method.service.typeName}/${method.name}`; + const headersJson = JSON.stringify(this.projectRef!.configuration.headers || {}); + + TargetServerStream(this.projectRef!.configuration.url, fullMethodPath, inputArray, headersJson, streamID).catch((err) => { + responseStream.notifyError(err instanceof Error ? err : new Error(String(err))); + statusDeferred.reject(err); + trailersDeferred.reject(err); + cleanup(); + }); + + // Handle abort signal + if (options.abort) { + options.abort.addEventListener("abort", () => { + CancelStream(streamID).catch(() => {}); + cleanup(); + }); + } + + return new ServerStreamingCall( + method, + options.meta || {}, + input, + headersDeferred.promise, + responseStream, + statusDeferred.promise, + trailersDeferred.promise, + ); } clientStreaming(method: MethodInfo, options: RpcOptions): ClientStreamingCall { diff --git a/ui/src/wailsjs/go/main/App.d.ts b/ui/src/wailsjs/go/main/App.d.ts index 99745e09..d2b62a59 100755 --- a/ui/src/wailsjs/go/main/App.d.ts +++ b/ui/src/wailsjs/go/main/App.d.ts @@ -2,10 +2,14 @@ // This file is automatically generated. DO NOT EDIT import {main} from '../models'; +export function CancelStream(arg1:string):Promise; + export function CheckForUpdate():Promise; export function OpenDirectoryDialog():Promise; export function Target(arg1:string,arg2:string,arg3:Array,arg4:number,arg5:string):Promise; +export function TargetServerStream(arg1:string,arg2:string,arg3:Array,arg4:string,arg5:string):Promise; + export function Twirp(arg1:string,arg2:Array):Promise>; diff --git a/ui/src/wailsjs/go/main/App.js b/ui/src/wailsjs/go/main/App.js index c567272b..14ee50e3 100755 --- a/ui/src/wailsjs/go/main/App.js +++ b/ui/src/wailsjs/go/main/App.js @@ -2,6 +2,10 @@ // Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL // This file is automatically generated. DO NOT EDIT +export function CancelStream(arg1) { + return window['go']['main']['App']['CancelStream'](arg1); +} + export function CheckForUpdate() { return window['go']['main']['App']['CheckForUpdate'](); } @@ -14,6 +18,10 @@ export function Target(arg1, arg2, arg3, arg4, arg5) { return window['go']['main']['App']['Target'](arg1, arg2, arg3, arg4, arg5); } +export function TargetServerStream(arg1, arg2, arg3, arg4, arg5) { + return window['go']['main']['App']['TargetServerStream'](arg1, arg2, arg3, arg4, arg5); +} + export function Twirp(arg1, arg2) { return window['go']['main']['App']['Twirp'](arg1, arg2); }