Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions desktop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"embed"
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
Expand All @@ -12,6 +13,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"

"github.com/wailsapp/wails/v2"
Expand Down Expand Up @@ -58,6 +60,7 @@ type App struct {
ctx context.Context
twirpHandler api.TwirpServer
configurationWatcher *api.ConfigurationWatcher
activeStreams sync.Map // streamID -> context.CancelFunc
}

// NewApp creates a new App application struct
Expand Down Expand Up @@ -308,6 +311,63 @@ func (a *App) targetTwirp(target string, method string, req []byte, headers map[
}, nil
}

// TargetServerStream starts a server-streaming gRPC call.
// Each response message is emitted as a Wails event "stream:<streamID>" with base64-encoded body.
// When the stream ends, "stream:<streamID>:end" is emitted.
// On error, "stream:<streamID>:error" is emitted with the error message.
func (a *App) TargetServerStream(target string, method string, req []byte, headersJson string, streamID string) error {
slog.Info("TargetServerStream called", "target", target, "method", method, "streamID", streamID)

if req == nil {
return fmt.Errorf("nil request")
}

headers := make(map[string]string)
if headersJson != "" && headersJson != "{}" {
if err := json.Unmarshal([]byte(headersJson), &headers); err != nil {
return fmt.Errorf("failed to parse headers: %w", err)
}
}

client, err := grpc.NewClientFromString(target)
if err != nil {
return fmt.Errorf("failed to create gRPC client: %w", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
a.activeStreams.Store(streamID, cancel)

messages, errc := client.ServerStream(ctx, method, req, headers)

go func() {
defer cancel()
defer a.activeStreams.Delete(streamID)

for msg := range messages {
encoded := base64.StdEncoding.EncodeToString(msg)
runtime.EventsEmit(a.ctx, "stream:"+streamID, encoded)
}

if err := <-errc; err != nil {
slog.Error("Server stream error", "streamID", streamID, "error", err)
runtime.EventsEmit(a.ctx, "stream:"+streamID+":error", err.Error())
} else {
runtime.EventsEmit(a.ctx, "stream:"+streamID+":end")
}
}()

return nil
}

// CancelStream cancels an active streaming call.
func (a *App) CancelStream(streamID string) error {
if cancel, ok := a.activeStreams.LoadAndDelete(streamID); ok {
cancel.(context.CancelFunc)()
return nil
}
return fmt.Errorf("stream not found: %s", streamID)
}

func main() {
// Get user's home directory and use ~/.kaja for config
homeDir, err := os.UserHomeDir()
Expand Down
78 changes: 78 additions & 0 deletions server/pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ package grpc
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net/url"
"strings"
"time"
Expand Down Expand Up @@ -167,3 +169,79 @@ func (c *Client) InvokeWithTimeout(method string, request []byte, timeout time.D
defer cancel()
return c.Invoke(ctx, method, request, headers)
}

// ServerStream opens a server-streaming gRPC call.
// It sends a single request and returns a channel that yields response messages.
// The channel is closed when the stream ends. Errors are sent on the error channel.
func (c *Client) ServerStream(ctx context.Context, method string, request []byte, headers map[string]string) (<-chan []byte, <-chan error) {
messages := make(chan []byte, 16)
errc := make(chan error, 1)

go func() {
defer close(messages)
defer close(errc)

if !strings.HasPrefix(method, "/") {
method = "/" + method
}

var creds credentials.TransportCredentials
if c.useTLS {
creds = credentials.NewTLS(&tls.Config{})
} else {
creds = insecure.NewCredentials()
}

codec := &grpcCodec{}
conn, err := grpc.NewClient(c.target, grpc.WithTransportCredentials(creds), grpc.WithDefaultCallOptions(grpc.ForceCodec(codec)))
if err != nil {
errc <- fmt.Errorf("failed to create gRPC client: %w", err)
return
}
defer conn.Close()

if len(headers) > 0 {
md := metadata.New(headers)
ctx = metadata.NewOutgoingContext(ctx, md)
}

streamDesc := &grpc.StreamDesc{
ServerStreams: true,
}

stream, err := conn.NewStream(ctx, streamDesc, method)
if err != nil {
errc <- fmt.Errorf("failed to open stream: %w", err)
return
}

if err := stream.SendMsg(request); err != nil {
errc <- fmt.Errorf("failed to send request: %w", err)
return
}

if err := stream.CloseSend(); err != nil {
errc <- fmt.Errorf("failed to close send: %w", err)
return
}

for {
var response []byte
err := stream.RecvMsg(&response)
if err != nil {
if errors.Is(err, io.EOF) || ctx.Err() != nil {
return
}
errc <- fmt.Errorf("stream receive failed: %w", err)
return
}
select {
case messages <- response:
case <-ctx.Done():
return
}
}
}()

return messages, errc
}
53 changes: 49 additions & 4 deletions ui/src/Console.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,21 @@ interface MethodCallRowProps {
}

Console.MethodCallRow = function ({ methodCall, isSelected, onClick, now }: MethodCallRowProps) {
const status = methodCall.error ? "error" : methodCall.output ? "success" : "pending";
const isStreaming = methodCall.streamOutputs !== undefined;
const isStreamActive = isStreaming && !methodCall.streamComplete && !methodCall.error;

const status = methodCall.error ? "error" : isStreamActive ? "streaming" : methodCall.output ? "success" : "pending";

const statusColor = {
pending: "var(--fgColor-muted)",
streaming: "var(--fgColor-accent)",
success: "var(--fgColor-success)",
error: "var(--fgColor-danger)",
}[status];

const statusIcon = {
pending: "○",
streaming: "◉",
success: "●",
error: "●",
}[status];
Expand All @@ -295,6 +300,18 @@ Console.MethodCallRow = function ({ methodCall, isSelected, onClick, now }: Meth
>
{methodId(methodCall.service, methodCall.method)}
</span>
{isStreaming && methodCall.streamOutputs!.length > 0 && (
<span
style={{
color: "var(--fgColor-muted)",
fontSize: 10,
marginLeft: 6,
flexShrink: 0,
}}
>
{methodCall.streamOutputs!.length}
</span>
)}
<span
style={{
color: "var(--fgColor-muted)",
Expand All @@ -318,6 +335,9 @@ interface DetailTabsProps {
}

Console.DetailTabs = function ({ methodCall, activeTab, onTabChange }: DetailTabsProps) {
const isStreaming = methodCall.streamOutputs !== undefined;
const streamCount = isStreaming ? methodCall.streamOutputs!.length : 0;

return (
<div style={{ display: "flex" }}>
<div className={`console-tab ${activeTab === "request" ? "active" : ""}`} onClick={() => onTabChange("request")}>
Expand All @@ -330,7 +350,7 @@ Console.DetailTabs = function ({ methodCall, activeTab, onTabChange }: DetailTab
color: methodCall.error ? "var(--fgColor-danger)" : activeTab === "response" ? "var(--fgColor-default)" : "var(--fgColor-muted)",
}}
>
Response
Response{isStreaming && streamCount > 0 ? ` (${streamCount})` : ""}
</div>
<div className={`console-tab ${activeTab === "headers" ? "active" : ""}`} onClick={() => onTabChange("headers")}>
Headers
Expand All @@ -348,7 +368,8 @@ interface DetailContentProps {
}

Console.DetailContent = function ({ methodCall, activeTab, onTabChange, colorMode = "night", jsonViewerRef }: DetailContentProps) {
const hasResponse = methodCall.output !== undefined || methodCall.error !== undefined;
const isStreaming = methodCall.streamOutputs !== undefined;
const hasResponse = methodCall.output !== undefined || methodCall.error !== undefined || (isStreaming && methodCall.streamOutputs!.length > 0);
const hasError = methodCall.error !== undefined;

// Switch to response tab when response arrives
Expand All @@ -362,7 +383,16 @@ Console.DetailContent = function ({ methodCall, activeTab, onTabChange, colorMod
return <Console.HeadersContent methodCall={methodCall} />;
}

const content = activeTab === "request" ? methodCall.input : methodCall.error || methodCall.output;
let content;
if (activeTab === "request") {
content = methodCall.input;
} else if (hasError) {
content = methodCall.error;
} else if (isStreaming) {
content = methodCall.streamOutputs;
} else {
content = methodCall.output;
}

return (
<div
Expand Down Expand Up @@ -402,6 +432,21 @@ Console.DetailContent = function ({ methodCall, activeTab, onTabChange, colorMod
POST {methodCall.url}
</div>
)}
{activeTab === "response" && isStreaming && !hasError && (
<div
style={{
padding: "4px 12px",
fontFamily: "monospace",
fontSize: 11,
color: methodCall.streamComplete ? "var(--fgColor-muted)" : "var(--fgColor-accent)",
borderBottom: "1px solid var(--borderColor-muted)",
}}
>
{methodCall.streamComplete
? `Stream complete - ${methodCall.streamOutputs!.length} message${methodCall.streamOutputs!.length !== 1 ? "s" : ""}`
: `Streaming - ${methodCall.streamOutputs!.length} message${methodCall.streamOutputs!.length !== 1 ? "s" : ""} received...`}
</div>
)}
<JsonViewer ref={jsonViewerRef} value={content} colorMode={colorMode} />
</>
)}
Expand Down
5 changes: 5 additions & 0 deletions ui/src/Sidebar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,11 @@ export function Sidebar({
current={currentMethod === method}
>
{method.name}
{method.serverStreaming && (
<TreeView.TrailingVisual>
<span style={{ fontSize: 10, color: "var(--fgColor-muted)" }}>stream</span>
</TreeView.TrailingVisual>
)}
</TreeView.Item>
);
})}
Expand Down
69 changes: 51 additions & 18 deletions ui/src/client.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { GrpcWebFetchTransport } from "@protobuf-ts/grpcweb-transport";
import { RpcOptions, UnaryCall } from "@protobuf-ts/runtime-rpc";
import type { RpcOptions, ServerStreamingCall, UnaryCall } from "@protobuf-ts/runtime-rpc";
import { TwirpFetchTransport } from "@protobuf-ts/twirp-transport";
import { MethodCall } from "./kaja";
import { Client, ProjectRef, Service, serviceId } from "./project";
Expand Down Expand Up @@ -64,6 +64,8 @@ export function createClient(service: Service, stub: Stub, projectRef: ProjectRe
};

for (const method of service.methods) {
const isServerStreaming = method.serverStreaming && !method.clientStreaming;

client.methods[method.name] = async (input: any) => {
// Capture request headers from projectRef at request time
const requestHeaders: { [key: string]: string } = { ...(projectRef.configuration.headers || {}) };
Expand All @@ -81,26 +83,57 @@ export function createClient(service: Service, stub: Stub, projectRef: ProjectRe

try {
const call = clientStub[lcfirst(method.name)](input, options);
const [response, headers, trailers] = await Promise.all([call.response, call.headers, call.trailers]);
methodCall.output = response;
methodCall.inputTypeName = call.method?.I?.typeName;
methodCall.inputType = call.method?.I;
methodCall.outputTypeName = call.method?.O?.typeName;
methodCall.outputType = call.method?.O;

// Capture response headers and trailers
const responseHeaders: { [key: string]: string } = {};
if (headers) {
for (const [key, value] of Object.entries(headers)) {
responseHeaders[key] = String(value);

if (isServerStreaming) {
const streamCall = call as ServerStreamingCall<any, any>;
methodCall.inputTypeName = streamCall.method?.I?.typeName;
methodCall.inputType = streamCall.method?.I;
methodCall.outputTypeName = streamCall.method?.O?.typeName;
methodCall.outputType = streamCall.method?.O;
methodCall.streamOutputs = [];

for await (const message of streamCall.responses) {
methodCall.streamOutputs.push(message);
methodCall.output = message;
client.kaja?._internal.methodCallUpdate(methodCall);
}
}
if (trailers) {
for (const [key, value] of Object.entries(trailers)) {
responseHeaders[key] = String(value);
methodCall.streamComplete = true;

const [headers, trailers] = await Promise.all([streamCall.headers, streamCall.trailers]);
const responseHeaders: { [key: string]: string } = {};
if (headers) {
for (const [key, value] of Object.entries(headers)) {
responseHeaders[key] = String(value);
}
}
if (trailers) {
for (const [key, value] of Object.entries(trailers)) {
responseHeaders[key] = String(value);
}
}
methodCall.responseHeaders = responseHeaders;
} else {
const [response, headers, trailers] = await Promise.all([call.response, call.headers, call.trailers]);
methodCall.output = response;
methodCall.inputTypeName = call.method?.I?.typeName;
methodCall.inputType = call.method?.I;
methodCall.outputTypeName = call.method?.O?.typeName;
methodCall.outputType = call.method?.O;

// Capture response headers and trailers
const responseHeaders: { [key: string]: string } = {};
if (headers) {
for (const [key, value] of Object.entries(headers)) {
responseHeaders[key] = String(value);
}
}
if (trailers) {
for (const [key, value] of Object.entries(trailers)) {
responseHeaders[key] = String(value);
}
}
methodCall.responseHeaders = responseHeaders;
}
methodCall.responseHeaders = responseHeaders;
} catch (error: any) {
methodCall.error = serializeError(error);
}
Expand Down
Loading