From 16932a703f3c1fd1923e6bec25c8165d00250bef Mon Sep 17 00:00:00 2001 From: Ryan Fowler Date: Mon, 26 Jan 2026 19:05:31 -0800 Subject: [PATCH] Add support for gRPC and protocol buffers --- README.md | 1 + docs/CONFIG.md | 20 ++ docs/USAGE.md | 139 ++++++++++ go.sum | 2 + integration/integration_test.go | 110 ++++++++ internal/cli/app.go | 86 +++++++ internal/complete/complete.go | 2 +- internal/fetch/fetch.go | 91 ++++++- internal/fetch/proto.go | 111 ++++++++ internal/format/protobuf.go | 39 +++ internal/grpc/framing.go | 43 ++++ internal/grpc/framing_test.go | 148 +++++++++++ internal/grpc/headers.go | 19 ++ internal/grpc/status.go | 87 +++++++ internal/grpc/status_test.go | 136 ++++++++++ internal/proto/compile.go | 92 +++++++ internal/proto/compile_test.go | 285 +++++++++++++++++++++ internal/proto/descriptor.go | 29 +++ internal/proto/descriptor_test.go | 112 +++++++++ internal/proto/message.go | 60 +++++ internal/proto/message_test.go | 404 ++++++++++++++++++++++++++++++ internal/proto/schema.go | 144 +++++++++++ internal/proto/schema_test.go | 386 ++++++++++++++++++++++++++++ main.go | 4 + 24 files changed, 2544 insertions(+), 6 deletions(-) create mode 100644 internal/fetch/proto.go create mode 100644 internal/grpc/framing.go create mode 100644 internal/grpc/framing_test.go create mode 100644 internal/grpc/headers.go create mode 100644 internal/grpc/status.go create mode 100644 internal/grpc/status_test.go create mode 100644 internal/proto/compile.go create mode 100644 internal/proto/compile_test.go create mode 100644 internal/proto/descriptor.go create mode 100644 internal/proto/descriptor_test.go create mode 100644 internal/proto/message.go create mode 100644 internal/proto/message_test.go create mode 100644 internal/proto/schema.go create mode 100644 internal/proto/schema_test.go diff --git a/README.md b/README.md index 380667e..ee849fd 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ - **Compression**: automatic gzip and zstd response body decompression - **Authentication**: support for Basic Auth, Bearer Token, and AWS Signature V4 - **Form body**: send multipart or urlencoded form bodies +- **gRPC support**: make gRPC calls with automatic JSON-to-protobuf conversion - **Editor integration**: use an editor to modify the request body - **Configuration**: global and per-host configuration - _and more!_ diff --git a/docs/CONFIG.md b/docs/CONFIG.md index 5d408e4..99c97b0 100644 --- a/docs/CONFIG.md +++ b/docs/CONFIG.md @@ -45,6 +45,7 @@ option = host_specific_value ### Auto-Update Options #### `auto-update` + **Type**: Boolean or duration interval **Default**: `false` (disabled) @@ -66,6 +67,7 @@ auto-update = 1d ### Output Control Options #### `color` / `colour` + **Type**: String **Values**: `auto`, `off`, `on` **Default**: `auto` @@ -84,6 +86,7 @@ color = on ``` #### `format` + **Type**: String **Values**: `auto`, `off`, `on` **Default**: `auto` @@ -102,6 +105,7 @@ format = on ``` #### `image` + **Type**: String **Values**: `auto`, `native`, `off` **Default**: `auto` @@ -120,6 +124,7 @@ image = off ``` #### `no-pager` + **Type**: Boolean **Default**: `false` @@ -134,6 +139,7 @@ no-pager = false ``` #### `silent` + **Type**: Boolean **Default**: `false` @@ -148,6 +154,7 @@ silent = false ``` #### `verbosity` + **Type**: Integer **Values**: `0` or greater **Default**: `0` @@ -171,6 +178,7 @@ verbosity = 3 ### Network Options #### `ca-cert` + **Type**: CA certificate path **Default**: System default @@ -182,6 +190,7 @@ ca-cert = ca-cert.pem ``` #### `dns-server` + **Type**: IP address with optional port, or HTTPS URL **Default**: System default @@ -203,6 +212,7 @@ dns-server = https://dns.google/dns-query ``` #### `proxy` + **Type**: URL **Default**: None @@ -220,6 +230,7 @@ proxy = socks5://localhost:1080 ``` #### `timeout` + **Type**: Number (seconds) **Default**: System default @@ -234,6 +245,7 @@ timeout = 2.5 ``` #### `redirects` + **Type**: Integer **Default**: System default @@ -248,6 +260,7 @@ redirects = 10 ``` #### `http` + **Type**: String **Values**: `1`, `2`, `3` @@ -262,6 +275,7 @@ http = 2 ``` #### `tls` + **Type**: String **Values**: `1.0`, `1.1`, `1.2`, `1.3` **Default**: System default @@ -277,6 +291,7 @@ tls = 1.3 ``` #### `insecure` + **Type**: Boolean **Default**: `false` @@ -291,6 +306,7 @@ insecure = false ``` #### `no-encode` + **Type**: Boolean **Default**: `false` @@ -307,6 +323,7 @@ no-encode = false ### Request Options #### `header` + **Type**: String (name:value format) **Repeatable**: Yes @@ -323,6 +340,7 @@ header = User-Agent: MyApp/1.0 ``` #### `query` + **Type**: String (key=value format) **Repeatable**: Yes @@ -339,6 +357,7 @@ query = sort=name ``` #### `ignore-status` + **Type**: Boolean **Default**: `false` @@ -453,6 +472,7 @@ config file '/home/user/.config/fetch/config': line 15: invalid option: 'invalid ``` Common validation errors include: + - Invalid option names - Invalid values for specific options (e.g., `color = invalid`) - Malformed key=value pairs diff --git a/docs/USAGE.md b/docs/USAGE.md index ff26455..fbbd1e8 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -211,6 +211,7 @@ fetch --timeout 2.5 example.com **Flag**: `--dns-server IP[:PORT]|URL` Use a custom DNS server. Can be either: + - IP address with optional port for UDP DNS - HTTPS URL for DNS-over-HTTPS @@ -333,6 +334,7 @@ fetch --colour on example.com Set whether output should be formatted. Options: `auto`, `off`, `on`. Supported formats for automatic formatting and syntax highlighting: + - JSON (`application/json`) - HTML (`text/html`) - XML (`application/xml`, `text/xml`) @@ -519,6 +521,143 @@ fetch -j '{"message": "Hello \"World\""}' example.com fetch -H 'Authorization: Bearer token-with-$pecial-chars' example.com ``` +## gRPC Support + +`fetch` supports making gRPC calls with automatic protocol handling, JSON-to-protobuf conversion, and response formatting. + +### Basic gRPC Request + +**Flag**: `--grpc` + +Enable gRPC mode. This flag: + +- Uses the HTTP/2 protocol +- Sets the method to `POST` +- Adds gRPC headers (`Content-Type: application/grpc+proto`, `TE: trailers`, etc.) +- Applies gRPC framing to the request body +- Handles gRPC framing on the response + +The service and method are specified in the URL path in the format `/package.Service/Method`: + +```sh +fetch https://localhost:50051/mypackage.MyService/MyMethod --grpc --insecure +``` + +### Proto Schema Options + +To enable JSON-to-protobuf conversion for request bodies and rich protobuf formatting for responses, provide a proto schema using one of these options: + +**Flag**: `--proto-file PATH` + +Compile `.proto` file(s) using `protoc`. Requires `protoc` to be installed. Can be specified multiple times or with comma-separated paths. + +```sh +fetch https://localhost:50051/echo.EchoService/Echo \ + --grpc \ + --proto-file service.proto \ + -j '{"message": "hello", "count": 42}' \ + --insecure +``` + +**Flag**: `--proto-desc PATH` + +Use a pre-compiled descriptor set file. This is useful when `protoc` isn't available at runtime or to avoid recompilation. + +Generate a descriptor set with: + +```sh +protoc --descriptor_set_out=service.pb --include_imports service.proto +``` + +Then use it: + +```sh +fetch https://localhost:50051/echo.EchoService/Echo \ + --grpc \ + --proto-desc service.pb \ + -j '{"message": "hello", "count": 42}' \ + --insecure +``` + +**Flag**: `--proto-import PATH` + +Add import paths for proto compilation. Use with `--proto-file` when your proto files have imports. + +```sh +fetch https://localhost:50051/mypackage.MyService/MyMethod \ + --grpc \ + --proto-file service.proto \ + --proto-import ./proto \ + --proto-import /usr/local/include \ + -j '{"field": "value"}' \ + --insecure +``` + +### How It Works + +When `--grpc` is used with a proto schema (`--proto-file` or `--proto-desc`): + +1. The service and method are extracted from the URL path +2. The method's input/output message types are looked up in the schema +3. JSON request bodies are automatically converted to protobuf +4. Protobuf responses are formatted with field names from the schema + +Without a proto schema, `fetch` still handles gRPC framing but: + +- Request bodies must be raw protobuf (not JSON) +- Responses are formatted using generic protobuf parsing (field numbers instead of names) + +### Examples + +**Simple gRPC call with JSON body:** + +```sh +fetch https://api.example.com/package.Service/Method \ + --grpc \ + --proto-file service.proto \ + -j '{"name": "test", "value": 123}' +``` + +**gRPC call with verbose output:** + +```sh +fetch https://api.example.com/package.Service/Method \ + --grpc \ + --proto-file service.proto \ + -j '{"request": "data"}' \ + -vv +``` + +**gRPC call to local server with self-signed certificate:** + +```sh +fetch https://localhost:50051/echo.EchoService/Echo \ + --grpc \ + --proto-desc echo.pb \ + -j '{"message": "hello"}' \ + --insecure +``` + +**Dry run to inspect the request:** + +```sh +fetch https://api.example.com/package.Service/Method \ + --grpc \ + --proto-file service.proto \ + -j '{"field": "value"}' \ + --dry-run +``` + +**Use an editor to modify the request body:** + +```sh +fetch https://api.example.com/package.Service/Method \ + --grpc \ + --proto-file service.proto \ + -j '{"field": "value"}' \ + --edit +``` + ## Advanced Usage ### Combining Options diff --git a/go.sum b/go.sum index fed436c..53779f0 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/clipperhouse/uax29/v2 v2.2.0 h1:ChwIKnQN3kcZteTXMgb1wztSgaU+ZemkgWdoh github.com/clipperhouse/uax29/v2 v2.2.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw= github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= diff --git a/integration/integration_test.go b/integration/integration_test.go index 0112a01..868754d 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -5,6 +5,7 @@ import ( "archive/zip" "bytes" "encoding/base64" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -28,6 +29,7 @@ import ( "github.com/klauspost/compress/gzip" "github.com/klauspost/compress/zstd" + "google.golang.org/protobuf/encoding/protowire" ) func TestMain(t *testing.T) { @@ -949,6 +951,114 @@ func TestMain(t *testing.T) { assertBufNotContains(t, res.stderr, "zstd") }) + t.Run("protobuf response formatting", func(t *testing.T) { + // Build a simple protobuf message: field 1 = 123 (varint), field 2 = "hello" (string) + var protoData []byte + protoData = protowire.AppendTag(protoData, 1, protowire.VarintType) + protoData = protowire.AppendVarint(protoData, 123) + protoData = protowire.AppendTag(protoData, 2, protowire.BytesType) + protoData = protowire.AppendString(protoData, "hello") + + server := startServer(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/protobuf") + w.WriteHeader(200) + w.Write(protoData) + }) + defer server.Close() + + // Without formatting, output is the raw protobuf. + res := runFetch(t, fetchPath, server.URL, "--format", "off") + assertExitCode(t, 0, res) + if !bytes.Equal(res.stdout.Bytes(), protoData) { + t.Fatalf("expected raw protobuf data") + } + + // With formatting, protobuf is parsed and displayed. + res = runFetch(t, fetchPath, server.URL, "--format", "on") + assertExitCode(t, 0, res) + assertBufContains(t, res.stdout, "1:") + assertBufContains(t, res.stdout, "123") + assertBufContains(t, res.stdout, "2:") + assertBufContains(t, res.stdout, "hello") + }) + + t.Run("connect rpc error response", func(t *testing.T) { + // Simulate a Connect RPC error response. + server := startServer(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(404) + io.WriteString(w, `{"code":"not_found","message":"resource not found"}`) + }) + defer server.Close() + + res := runFetch(t, fetchPath, server.URL, "--format", "on") + assertExitCode(t, 4, res) // 4xx status code + assertBufContains(t, res.stdout, "not_found") + assertBufContains(t, res.stdout, "resource not found") + }) + + t.Run("grpc response unframing", func(t *testing.T) { + // Build a gRPC-framed protobuf response. + var protoData []byte + protoData = protowire.AppendTag(protoData, 1, protowire.VarintType) + protoData = protowire.AppendVarint(protoData, 42) + protoData = protowire.AppendTag(protoData, 2, protowire.BytesType) + protoData = protowire.AppendString(protoData, "grpc test") + + // gRPC framing: [compressed:1][length:4][data] + framedData := make([]byte, 5+len(protoData)) + framedData[0] = 0 // not compressed + binary.BigEndian.PutUint32(framedData[1:5], uint32(len(protoData))) + copy(framedData[5:], protoData) + + server := startServer(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/grpc+proto") + w.WriteHeader(200) + w.Write(framedData) + }) + defer server.Close() + + // With formatting, gRPC response is unframed and protobuf is parsed. + res := runFetch(t, fetchPath, server.URL, "--format", "on") + assertExitCode(t, 0, res) + assertBufContains(t, res.stdout, "1:") + assertBufContains(t, res.stdout, "42") + assertBufContains(t, res.stdout, "2:") + assertBufContains(t, res.stdout, "grpc test") + }) + + t.Run("proto flags mutual exclusivity", func(t *testing.T) { + // proto-file and proto-desc cannot be used together + // Create temp files so we get past file existence validation + tmpDir := t.TempDir() + protoFile := filepath.Join(tmpDir, "a.proto") + descFile := filepath.Join(tmpDir, "b.pb") + os.WriteFile(protoFile, []byte("syntax = \"proto3\";"), 0644) + os.WriteFile(descFile, []byte{}, 0644) + + res := runFetch(t, fetchPath, "http://example.com/svc/Method", "--grpc", "--proto-file", protoFile, "--proto-desc", descFile) + assertExitCode(t, 1, res) + assertBufContains(t, res.stderr, "cannot be used together") + }) + + t.Run("proto-file requires protoc", func(t *testing.T) { + // If protoc isn't found, we should get a helpful error. + // This test will only fail if protoc is not installed. + // When protoc IS installed, it should fail because the file doesn't exist. + res := runFetch(t, fetchPath, "http://example.com/svc/Method", "--grpc", "--proto-file", "/nonexistent/file.proto") + assertExitCode(t, 1, res) + // Should either complain about protoc not found or file not found + if !strings.Contains(res.stderr.String(), "protoc") && !strings.Contains(res.stderr.String(), "exist") { + t.Fatalf("expected error about protoc or file not found, got: %s", res.stderr.String()) + } + }) + + t.Run("proto-desc file not found", func(t *testing.T) { + res := runFetch(t, fetchPath, "http://example.com/svc/Method", "--grpc", "--proto-desc", "/nonexistent/file.pb") + assertExitCode(t, 1, res) + assertBufContains(t, res.stderr, "does not exist") + }) + t.Run("update", func(t *testing.T) { var empty string var urlStr atomic.Pointer[string] diff --git a/internal/cli/app.go b/internal/cli/app.go index ef17162..9f24d79 100644 --- a/internal/cli/app.go +++ b/internal/cli/app.go @@ -32,10 +32,14 @@ type App struct { DryRun bool Edit bool Form []core.KeyVal[string] + GRPC bool Help bool Method string Multipart []core.KeyVal[string] Output string + ProtoDesc string + ProtoFiles []string + ProtoImports []string Range []string RemoteHeaderName bool RemoteName bool @@ -102,8 +106,12 @@ func (a *App) CLI() *CLI { {"aws-sigv4", "basic", "bearer"}, {"data", "form", "json", "multipart", "xml"}, {"output", "remote-name"}, + {"proto-file", "proto-desc"}, }, RequiredFlags: []core.KeyVal[[]string]{ + {Key: "proto-desc", Val: []string{"grpc"}}, + {Key: "proto-file", Val: []string{"grpc"}}, + {Key: "proto-import", Val: []string{"proto-file"}}, {Key: "remote-header-name", Val: []string{"remote-name"}}, }, Flags: []Flag{ @@ -395,6 +403,20 @@ func (a *App) CLI() *CLI { return a.Cfg.ParseFormat(value) }, }, + { + Short: "", + Long: "grpc", + Args: "", + Description: "Enable gRPC mode", + Default: "", + IsSet: func() bool { + return a.GRPC + }, + Fn: func(value string) error { + a.GRPC = true + return nil + }, + }, { Short: "H", Long: "header", @@ -625,6 +647,59 @@ func (a *App) CLI() *CLI { return nil }, }, + { + Short: "", + Long: "proto-desc", + Args: "PATH", + Description: "Pre-compiled descriptor set file", + Default: "", + IsSet: func() bool { + return a.ProtoDesc != "" + }, + Fn: func(value string) error { + a.ProtoDesc = value + return checkFileExists(value) + }, + }, + { + Short: "", + Long: "proto-file", + Args: "PATH", + Description: "Compile .proto file(s) via protoc", + Default: "", + IsSet: func() bool { + return len(a.ProtoFiles) > 0 + }, + Fn: func(value string) error { + // Support comma-separated paths. + for p := range strings.SplitSeq(value, ",") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + err := checkFileExists(p) + if err != nil { + return err + } + a.ProtoFiles = append(a.ProtoFiles, p) + } + return nil + }, + }, + { + Short: "", + Long: "proto-import", + Args: "PATH", + Description: "Import path for proto compilation", + Default: "", + IsSet: func() bool { + return len(a.ProtoImports) > 0 + }, + Fn: func(value string) error { + a.ProtoImports = append(a.ProtoImports, value) + return checkFileExists(value) + }, + }, { Short: "", Long: "proxy", @@ -902,6 +977,17 @@ func requestBody(value string) (io.Reader, string, error) { } } +func checkFileExists(value string) error { + _, err := os.Stat(value) + if err == nil { + return nil + } + if os.IsNotExist(err) { + return core.FileNotExistsError(value) + } + return err +} + func cut(s, sep string) (string, string, bool) { key, val, ok := strings.Cut(s, sep) key = strings.TrimSpace(key) diff --git a/internal/complete/complete.go b/internal/complete/complete.go index f876d90..4bebd07 100644 --- a/internal/complete/complete.go +++ b/internal/complete/complete.go @@ -220,7 +220,7 @@ func completeValue(flag cli.Flag, prefix, value string) []core.KeyVal[string] { } switch flag.Long { - case "ca-cert", "config", "output", "unix": + case "ca-cert", "config", "output", "proto-desc", "proto-file", "proto-import", "unix": return completePath(prefix, value) case "data", "json", "xml": path, ok := strings.CutPrefix(value, "@") diff --git a/internal/fetch/fetch.go b/internal/fetch/fetch.go index a461798..f14ce8a 100644 --- a/internal/fetch/fetch.go +++ b/internal/fetch/fetch.go @@ -20,8 +20,12 @@ import ( "github.com/ryanfowler/fetch/internal/client" "github.com/ryanfowler/fetch/internal/core" "github.com/ryanfowler/fetch/internal/format" + fetchgrpc "github.com/ryanfowler/fetch/internal/grpc" "github.com/ryanfowler/fetch/internal/image" "github.com/ryanfowler/fetch/internal/multipart" + "github.com/ryanfowler/fetch/internal/proto" + + "google.golang.org/protobuf/reflect/protoreflect" ) type ContentType int @@ -30,6 +34,7 @@ const ( TypeUnknown ContentType = iota TypeCSS TypeCSV + TypeGRPC TypeHTML TypeImage TypeJSON @@ -53,6 +58,7 @@ type Request struct { Edit bool Form []core.KeyVal[string] Format core.Format + GRPC bool Headers []core.KeyVal[string] HTTP core.HTTPVersion IgnoreStatus bool @@ -64,6 +70,9 @@ type Request struct { Multipart *multipart.Multipart Output string PrinterHandle *core.Handle + ProtoDesc string + ProtoFiles []string + ProtoImports []string Proxy *url.URL QueryParams []core.KeyVal[string] Range []string @@ -75,6 +84,9 @@ type Request struct { UnixSocket string URL *url.URL Verbosity core.Verbosity + + // responseDescriptor is set internally after proto setup for response formatting. + responseDescriptor protoreflect.MessageDescriptor } func Fetch(ctx context.Context, r *Request) int { @@ -96,6 +108,27 @@ func Fetch(ctx context.Context, r *Request) int { } func fetch(ctx context.Context, r *Request) (int, error) { + // 1. Load proto schema if configured. + var schema *proto.Schema + if len(r.ProtoFiles) > 0 || r.ProtoDesc != "" { + var err error + schema, err = loadProtoSchema(r) + if err != nil { + return 0, err + } + } + + // 2. Setup gRPC (adds headers, sets HTTP version, finds descriptors). + var requestDesc protoreflect.MessageDescriptor + if r.GRPC { + var err error + requestDesc, r.responseDescriptor, err = setupGRPC(r, schema) + if err != nil { + return 0, err + } + } + + // 3. Create HTTP client and request. c := client.NewClient(client.ClientConfig{ CACerts: r.CACerts, DNSServer: r.DNSServer, @@ -131,7 +164,7 @@ func fetch(ctx context.Context, r *Request) (int, error) { } }() - // Open an editor to modify the request body, if necessary. + // 4. Edit step (user edits request body). if r.Edit { err = editRequestBody(req) if err != nil { @@ -139,6 +172,30 @@ func fetch(ctx context.Context, r *Request) (int, error) { } } + // 5. Convert JSON to protobuf AFTER edit. + if requestDesc != nil && req.Body != nil && req.Body != http.NoBody { + // Read the body and convert. + converted, err := convertJSONToProtobuf(req.Body, requestDesc) + if err != nil { + return 0, err + } + req.Body = io.NopCloser(converted) + if req.Header.Get("Content-Type") == "" { + req.Header.Set("Content-Type", "application/protobuf") + } + } + + // 6. Frame gRPC request AFTER conversion. + // gRPC requires framing even for empty messages. + if r.GRPC { + framed, err := frameGRPCRequest(req.Body) + if err != nil { + return 0, err + } + req.Body = io.NopCloser(framed) + } + + // 7. Print request metadata / dry-run. if r.Verbosity >= core.VExtraVerbose || r.DryRun { errPrinter := r.PrinterHandle.Stderr() printRequestMetadata(errPrinter, req, r.HTTP) @@ -152,12 +209,12 @@ func fetch(ctx context.Context, r *Request) (int, error) { errPrinter.WriteString("\n") errPrinter.Flush() - ok, r, err := isPrintable(req.Body) + ok, rdr, err := isPrintable(req.Body) if err != nil { return 0, err } if ok { - _, err = io.Copy(os.Stderr, r) + _, err = io.Copy(os.Stderr, rdr) return 0, err } @@ -178,6 +235,7 @@ func fetch(ctx context.Context, r *Request) (int, error) { req = req.WithContext(ctx) } + // 8. Make request. return makeRequest(ctx, r, c, req) } @@ -309,8 +367,29 @@ func formatResponse(ctx context.Context, r *Request, resp *http.Response) (io.Re if format.FormatMsgPack(buf, p) == nil { buf = p.Bytes() } + case TypeGRPC: + // Unframe gRPC response before processing. + unframedBuf, _, err := fetchgrpc.Unframe(buf) + if err != nil { + // If unframing fails, try to process as raw protobuf. + unframedBuf = buf + } + if r.responseDescriptor != nil { + err = format.FormatProtobufWithDescriptor(unframedBuf, r.responseDescriptor, p) + } else { + err = format.FormatProtobuf(unframedBuf, p) + } + if err == nil { + buf = p.Bytes() + } case TypeProtobuf: - if format.FormatProtobuf(buf, p) == nil { + var err error + if r.responseDescriptor != nil { + err = format.FormatProtobufWithDescriptor(buf, r.responseDescriptor, p) + } else { + err = format.FormatProtobuf(buf, p) + } + if err == nil { buf = p.Bytes() } case TypeXML: @@ -340,13 +419,15 @@ func getContentType(headers http.Header) ContentType { switch subtype { case "csv": return TypeCSV + case "grpc", "grpc+proto": + return TypeGRPC case "json": return TypeJSON case "msgpack", "x-msgpack", "vnd.msgpack": return TypeMsgPack case "x-ndjson", "ndjson", "x-jsonl", "jsonl", "x-jsonlines": return TypeNDJSON - case "protobuf", "x-protobuf", "grpc+proto", "x-google-protobuf", "vnd.google.protobuf": + case "protobuf", "x-protobuf", "x-google-protobuf", "vnd.google.protobuf": return TypeProtobuf case "xml": return TypeXML diff --git a/internal/fetch/proto.go b/internal/fetch/proto.go new file mode 100644 index 0000000..63d43b5 --- /dev/null +++ b/internal/fetch/proto.go @@ -0,0 +1,111 @@ +package fetch + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strings" + + "github.com/ryanfowler/fetch/internal/core" + fetchgrpc "github.com/ryanfowler/fetch/internal/grpc" + "github.com/ryanfowler/fetch/internal/proto" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +// loadProtoSchema loads schema from files or descriptor set. +func loadProtoSchema(r *Request) (*proto.Schema, error) { + if len(r.ProtoFiles) > 0 { + return proto.CompileProtos(r.ProtoFiles, r.ProtoImports) + } + if r.ProtoDesc != "" { + return proto.LoadDescriptorSetFile(r.ProtoDesc) + } + return nil, nil +} + +// parseGRPCPath extracts service and method names from URL path. +// Expected format: /package.Service/Method +func parseGRPCPath(urlPath string) (serviceName, methodName string, err error) { + path := strings.TrimPrefix(urlPath, "/") + + idx := strings.LastIndex(path, "/") + if idx < 0 { + return "", "", fmt.Errorf("invalid gRPC path: expected '/Service/Method' format") + } + + serviceName = path[:idx] + methodName = path[idx+1:] + + if serviceName == "" || methodName == "" { + return "", "", fmt.Errorf("invalid gRPC path: service and method cannot be empty") + } + + return serviceName, methodName, nil +} + +// setupGRPC configures request for gRPC protocol. +// Returns headers to add, HTTP version, and request/response descriptors. +func setupGRPC(r *Request, schema *proto.Schema) (protoreflect.MessageDescriptor, protoreflect.MessageDescriptor, error) { + var requestDesc, responseDesc protoreflect.MessageDescriptor + if schema != nil && r.URL != nil { + serviceName, methodName, err := parseGRPCPath(r.URL.Path) + if err != nil { + return nil, nil, err + } + + fullMethod := serviceName + "/" + methodName + method, err := schema.FindMethod(fullMethod) + if err != nil { + return nil, nil, err + } + requestDesc = method.Input() + responseDesc = method.Output() + } + + if r.HTTP == core.HTTPDefault { + r.HTTP = core.HTTP2 + } + if r.Method == "" { + r.Method = "POST" + } + r.Headers = append(r.Headers, fetchgrpc.Headers()...) + r.Headers = append(r.Headers, fetchgrpc.AcceptHeader()) + + return requestDesc, responseDesc, nil +} + +// convertJSONToProtobuf converts JSON body to protobuf. +func convertJSONToProtobuf(data io.Reader, desc protoreflect.MessageDescriptor) (io.Reader, error) { + // Read all the JSON data. + jsonData, err := io.ReadAll(data) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + + // Convert JSON to protobuf. + protoData, err := proto.JSONToProtobuf(jsonData, desc) + if err != nil { + return nil, fmt.Errorf("failed to convert JSON to protobuf: %w", err) + } + + return bytes.NewReader(protoData), nil +} + +// frameGRPCRequest wraps data in gRPC framing. +// Handles nil/empty body by sending an empty framed message. +func frameGRPCRequest(data io.Reader) (io.Reader, error) { + var rawData []byte + if data != nil && data != http.NoBody { + var err error + rawData, err = io.ReadAll(data) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + } + + // Frame with gRPC format (works for empty data too). + framedData := fetchgrpc.Frame(rawData, false) + return bytes.NewReader(framedData), nil +} diff --git a/internal/format/protobuf.go b/internal/format/protobuf.go index af3a3a1..6dd425b 100644 --- a/internal/format/protobuf.go +++ b/internal/format/protobuf.go @@ -7,8 +7,13 @@ import ( "unicode/utf8" "github.com/ryanfowler/fetch/internal/core" + fetchproto "github.com/ryanfowler/fetch/internal/proto" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) // FormatProtobuf formats the provided raw protobuf data to the Printer. @@ -20,6 +25,40 @@ func FormatProtobuf(buf []byte, p *core.Printer) error { return err } +// FormatProtobufWithSchema formats protobuf data as JSON using the provided schema. +func FormatProtobufWithSchema(buf []byte, schema *fetchproto.Schema, typeName string, p *core.Printer) error { + md, err := schema.FindMessage(typeName) + if err != nil { + return err + } + + return FormatProtobufWithDescriptor(buf, md, p) +} + +// FormatProtobufWithDescriptor formats protobuf data as JSON using a message descriptor. +func FormatProtobufWithDescriptor(buf []byte, md protoreflect.MessageDescriptor, p *core.Printer) error { + msg := dynamicpb.NewMessage(md) + if err := proto.Unmarshal(buf, msg); err != nil { + return err + } + + // Marshal to JSON. + opts := protojson.MarshalOptions{ + Multiline: true, + Indent: " ", + EmitUnpopulated: false, + UseProtoNames: true, + } + + jsonBytes, err := opts.Marshal(msg) + if err != nil { + return err + } + + // Format the JSON with syntax highlighting. + return FormatJSON(jsonBytes, p) +} + func formatProtobuf(buf []byte, p *core.Printer, indent int) error { for len(buf) > 0 { num, wtype, n := protowire.ConsumeTag(buf) diff --git a/internal/grpc/framing.go b/internal/grpc/framing.go new file mode 100644 index 0000000..429eb60 --- /dev/null +++ b/internal/grpc/framing.go @@ -0,0 +1,43 @@ +package grpc + +import ( + "encoding/binary" + "fmt" +) + +// Frame wraps message in gRPC length-prefixed format. +// Format: [compressed:1][length:4][data] +func Frame(data []byte, compressed bool) []byte { + buf := make([]byte, 5+len(data)) + if compressed { + buf[0] = 1 + } else { + buf[0] = 0 + } + binary.BigEndian.PutUint32(buf[1:5], uint32(len(data))) + copy(buf[5:], data) + return buf +} + +// Unframe extracts a gRPC length-prefixed message from the data. +// Returns the message data and whether it was compressed. +func Unframe(data []byte) ([]byte, bool, error) { + if len(data) < 5 { + return nil, false, fmt.Errorf("failed to read gRPC frame header: insufficient data") + } + + compressed := data[0] != 0 + length := binary.BigEndian.Uint32(data[1:5]) + + // Sanity check on length. + const maxMessageSize = 64 * 1024 * 1024 // 64MB + if length > maxMessageSize { + return nil, false, fmt.Errorf("gRPC message too large: %d bytes", length) + } + + if len(data) < 5+int(length) { + return nil, false, fmt.Errorf("failed to read gRPC message: insufficient data") + } + + return data[5 : 5+length], compressed, nil +} diff --git a/internal/grpc/framing_test.go b/internal/grpc/framing_test.go new file mode 100644 index 0000000..2669b15 --- /dev/null +++ b/internal/grpc/framing_test.go @@ -0,0 +1,148 @@ +package grpc + +import ( + "bytes" + "testing" +) + +func TestFrame(t *testing.T) { + tests := []struct { + name string + data []byte + compressed bool + want []byte + }{ + { + name: "empty uncompressed", + data: []byte{}, + compressed: false, + want: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, + }, + { + name: "simple uncompressed", + data: []byte{0x01, 0x02, 0x03}, + compressed: false, + want: []byte{0x00, 0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03}, + }, + { + name: "simple compressed", + data: []byte{0x01, 0x02, 0x03}, + compressed: true, + want: []byte{0x01, 0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03}, + }, + { + name: "larger message", + data: bytes.Repeat([]byte{0xAB}, 256), + compressed: false, + want: append([]byte{0x00, 0x00, 0x00, 0x01, 0x00}, bytes.Repeat([]byte{0xAB}, 256)...), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Frame(tt.data, tt.compressed) + if !bytes.Equal(got, tt.want) { + t.Errorf("Frame() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnframe(t *testing.T) { + tests := []struct { + name string + input []byte + wantData []byte + wantCompressed bool + wantErr bool + }{ + { + name: "empty message", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, + wantData: []byte{}, + wantCompressed: false, + wantErr: false, + }, + { + name: "simple uncompressed", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03}, + wantData: []byte{0x01, 0x02, 0x03}, + wantCompressed: false, + wantErr: false, + }, + { + name: "simple compressed", + input: []byte{0x01, 0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03}, + wantData: []byte{0x01, 0x02, 0x03}, + wantCompressed: true, + wantErr: false, + }, + { + name: "truncated header", + input: []byte{0x00, 0x00, 0x00}, + wantErr: true, + }, + { + name: "truncated data", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x05, 0x01, 0x02}, // claims 5 bytes, has 2 + wantErr: true, + }, + { + name: "empty input", + input: []byte{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, compressed, err := Unframe(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Unframe() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if !bytes.Equal(data, tt.wantData) { + t.Errorf("Unframe() data = %v, want %v", data, tt.wantData) + } + if compressed != tt.wantCompressed { + t.Errorf("Unframe() compressed = %v, want %v", compressed, tt.wantCompressed) + } + }) + } +} + +func TestFrameUnframeRoundTrip(t *testing.T) { + testData := [][]byte{ + {}, + {0x00}, + {0x01, 0x02, 0x03, 0x04, 0x05}, + bytes.Repeat([]byte{0xAB}, 1000), + } + + for _, data := range testData { + framed := Frame(data, false) + unframed, compressed, err := Unframe(framed) + if err != nil { + t.Errorf("Unframe() error = %v", err) + continue + } + if compressed { + t.Error("expected uncompressed") + } + if !bytes.Equal(unframed, data) { + t.Errorf("round trip failed: got %v, want %v", unframed, data) + } + } +} + +func TestUnframeLargeMessageRejected(t *testing.T) { + // Create a header claiming a very large message + header := []byte{0x00, 0x10, 0x00, 0x00, 0x00} // 256MB + _, _, err := Unframe(header) + if err == nil { + t.Error("expected error for large message") + } +} diff --git a/internal/grpc/headers.go b/internal/grpc/headers.go new file mode 100644 index 0000000..fba53b8 --- /dev/null +++ b/internal/grpc/headers.go @@ -0,0 +1,19 @@ +package grpc + +import "github.com/ryanfowler/fetch/internal/core" + +// ContentType is the Content-Type header value for gRPC requests. +const ContentType = "application/grpc+proto" + +// Headers returns the standard headers for gRPC requests. +func Headers() []core.KeyVal[string] { + return []core.KeyVal[string]{ + {Key: "Content-Type", Val: ContentType}, + {Key: "Te", Val: "trailers"}, + } +} + +// AcceptHeader returns the Accept header for gRPC requests. +func AcceptHeader() core.KeyVal[string] { + return core.KeyVal[string]{Key: "Accept", Val: ContentType} +} diff --git a/internal/grpc/status.go b/internal/grpc/status.go new file mode 100644 index 0000000..967ee88 --- /dev/null +++ b/internal/grpc/status.go @@ -0,0 +1,87 @@ +package grpc + +import ( + "fmt" + "strconv" +) + +// Status represents a gRPC status. +type Status struct { + Code Code + Message string +} + +func (s *Status) Error() string { + if s.Message != "" { + return fmt.Sprintf("grpc error: %s: %s", s.Code.String(), s.Message) + } + return fmt.Sprintf("grpc error: %s", s.Code.String()) +} + +// OK returns true if the status represents success. +func (s *Status) OK() bool { + return s.Code == OK +} + +// Code represents a gRPC status code. +type Code int + +// gRPC status codes. +const ( + OK Code = 0 + Canceled Code = 1 + Unknown Code = 2 + InvalidArgument Code = 3 + DeadlineExceeded Code = 4 + NotFound Code = 5 + AlreadyExists Code = 6 + PermissionDenied Code = 7 + ResourceExhausted Code = 8 + FailedPrecondition Code = 9 + Aborted Code = 10 + OutOfRange Code = 11 + Unimplemented Code = 12 + Internal Code = 13 + Unavailable Code = 14 + DataLoss Code = 15 + Unauthenticated Code = 16 +) + +var codeNames = map[Code]string{ + OK: "OK", + Canceled: "CANCELED", + Unknown: "UNKNOWN", + InvalidArgument: "INVALID_ARGUMENT", + DeadlineExceeded: "DEADLINE_EXCEEDED", + NotFound: "NOT_FOUND", + AlreadyExists: "ALREADY_EXISTS", + PermissionDenied: "PERMISSION_DENIED", + ResourceExhausted: "RESOURCE_EXHAUSTED", + FailedPrecondition: "FAILED_PRECONDITION", + Aborted: "ABORTED", + OutOfRange: "OUT_OF_RANGE", + Unimplemented: "UNIMPLEMENTED", + Internal: "INTERNAL", + Unavailable: "UNAVAILABLE", + DataLoss: "DATA_LOSS", + Unauthenticated: "UNAUTHENTICATED", +} + +// String returns the name of the status code. +func (c Code) String() string { + if name, ok := codeNames[c]; ok { + return name + } + return fmt.Sprintf("CODE(%d)", c) +} + +// ParseStatus parses gRPC status from HTTP trailers. +func ParseStatus(grpcStatus, grpcMessage string) *Status { + code := Unknown + if grpcStatus != "" { + if n, err := strconv.Atoi(grpcStatus); err == nil { + code = Code(n) + } + } + return &Status{Code: code, Message: grpcMessage} +} diff --git a/internal/grpc/status_test.go b/internal/grpc/status_test.go new file mode 100644 index 0000000..7c50f27 --- /dev/null +++ b/internal/grpc/status_test.go @@ -0,0 +1,136 @@ +package grpc + +import "testing" + +func TestCodeString(t *testing.T) { + tests := []struct { + code Code + want string + }{ + {OK, "OK"}, + {Canceled, "CANCELED"}, + {Unknown, "UNKNOWN"}, + {InvalidArgument, "INVALID_ARGUMENT"}, + {DeadlineExceeded, "DEADLINE_EXCEEDED"}, + {NotFound, "NOT_FOUND"}, + {AlreadyExists, "ALREADY_EXISTS"}, + {PermissionDenied, "PERMISSION_DENIED"}, + {ResourceExhausted, "RESOURCE_EXHAUSTED"}, + {FailedPrecondition, "FAILED_PRECONDITION"}, + {Aborted, "ABORTED"}, + {OutOfRange, "OUT_OF_RANGE"}, + {Unimplemented, "UNIMPLEMENTED"}, + {Internal, "INTERNAL"}, + {Unavailable, "UNAVAILABLE"}, + {DataLoss, "DATA_LOSS"}, + {Unauthenticated, "UNAUTHENTICATED"}, + {Code(100), "CODE(100)"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := tt.code.String() + if got != tt.want { + t.Errorf("Code.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStatusError(t *testing.T) { + tests := []struct { + name string + status *Status + want string + wantOK bool + }{ + { + name: "OK status", + status: &Status{Code: OK}, + want: "grpc error: OK", + wantOK: true, + }, + { + name: "error with message", + status: &Status{Code: NotFound, Message: "resource not found"}, + want: "grpc error: NOT_FOUND: resource not found", + wantOK: false, + }, + { + name: "error without message", + status: &Status{Code: Internal}, + want: "grpc error: INTERNAL", + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.status.Error() + if got != tt.want { + t.Errorf("Status.Error() = %v, want %v", got, tt.want) + } + if tt.status.OK() != tt.wantOK { + t.Errorf("Status.OK() = %v, want %v", tt.status.OK(), tt.wantOK) + } + }) + } +} + +func TestParseStatus(t *testing.T) { + tests := []struct { + name string + grpcStatus string + grpcMessage string + wantCode Code + wantMessage string + }{ + { + name: "OK", + grpcStatus: "0", + grpcMessage: "", + wantCode: OK, + wantMessage: "", + }, + { + name: "NotFound with message", + grpcStatus: "5", + grpcMessage: "not found", + wantCode: NotFound, + wantMessage: "not found", + }, + { + name: "invalid status string", + grpcStatus: "invalid", + grpcMessage: "some message", + wantCode: Unknown, + wantMessage: "some message", + }, + { + name: "empty status string", + grpcStatus: "", + grpcMessage: "error occurred", + wantCode: Unknown, + wantMessage: "error occurred", + }, + { + name: "Unauthenticated", + grpcStatus: "16", + grpcMessage: "invalid token", + wantCode: Unauthenticated, + wantMessage: "invalid token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + status := ParseStatus(tt.grpcStatus, tt.grpcMessage) + if status.Code != tt.wantCode { + t.Errorf("ParseStatus() Code = %v, want %v", status.Code, tt.wantCode) + } + if status.Message != tt.wantMessage { + t.Errorf("ParseStatus() Message = %v, want %v", status.Message, tt.wantMessage) + } + }) + } +} diff --git a/internal/proto/compile.go b/internal/proto/compile.go new file mode 100644 index 0000000..f4020c4 --- /dev/null +++ b/internal/proto/compile.go @@ -0,0 +1,92 @@ +package proto + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// CompileProtos compiles .proto files via protoc and returns the loaded schema. +// protoFiles is a list of .proto file paths. +// importPaths is a list of directories to search for imports (-I flags to protoc). +func CompileProtos(protoFiles, importPaths []string) (*Schema, error) { + // Check that protoc is available. + protocPath, err := exec.LookPath("protoc") + if err != nil { + return nil, &ProtocNotFoundError{} + } + + // Create temp file for descriptor output. + tmpFile, err := os.CreateTemp("", "fetch-proto-*.pb") + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + tmpFile.Close() + defer os.Remove(tmpPath) + + // Build protoc command. + args := []string{ + "--descriptor_set_out=" + tmpPath, + "--include_imports", + } + + // Add import paths. + // If no import paths specified, add the directory of each proto file. + seenDirs := make(map[string]bool) + if len(importPaths) == 0 { + for _, f := range protoFiles { + dir := filepath.Dir(f) + absDir, err := filepath.Abs(dir) + if err != nil { + absDir = dir + } + if !seenDirs[absDir] { + seenDirs[absDir] = true + args = append(args, "-I="+absDir) + } + } + } else { + for _, imp := range importPaths { + args = append(args, "-I="+imp) + } + } + + // Add proto files. + args = append(args, protoFiles...) + + // Execute protoc. + var stderr bytes.Buffer + cmd := exec.Command(protocPath, args...) + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + errMsg := strings.TrimSpace(stderr.String()) + if errMsg == "" { + errMsg = err.Error() + } + return nil, &ProtocError{Message: errMsg} + } + + // Load the generated descriptor set. + return LoadDescriptorSetFile(tmpPath) +} + +// ProtocNotFoundError indicates protoc is not installed or not in PATH. +type ProtocNotFoundError struct{} + +func (e *ProtocNotFoundError) Error() string { + return "protoc not found in PATH. Install protoc from https://github.com/protocolbuffers/protobuf/releases" +} + +// ProtocError indicates protoc execution failed. +type ProtocError struct { + Message string +} + +func (e *ProtocError) Error() string { + return fmt.Sprintf("protoc failed: %s", e.Message) +} diff --git a/internal/proto/compile_test.go b/internal/proto/compile_test.go new file mode 100644 index 0000000..59dd33e --- /dev/null +++ b/internal/proto/compile_test.go @@ -0,0 +1,285 @@ +package proto + +import ( + "os" + "os/exec" + "path/filepath" + "testing" +) + +func TestCompileProtosSuccess(t *testing.T) { + // Skip if protoc is not available. + if _, err := exec.LookPath("protoc"); err != nil { + t.Skip("protoc not found in PATH, skipping compile tests") + } + + // Create a test proto file. + tmpDir := t.TempDir() + protoFile := filepath.Join(tmpDir, "test.proto") + protoContent := ` +syntax = "proto3"; +package testcompile; + +message TestRequest { + int64 id = 1; + string name = 2; +} + +message TestResponse { + bool success = 1; + string message = 2; +} + +service TestService { + rpc GetTest(TestRequest) returns (TestResponse); +} +` + if err := os.WriteFile(protoFile, []byte(protoContent), 0644); err != nil { + t.Fatalf("os.WriteFile() error = %v", err) + } + + // Compile the proto. + schema, err := CompileProtos([]string{protoFile}, nil) + if err != nil { + t.Fatalf("CompileProtos() error = %v", err) + } + + // Verify messages were loaded. + md, err := schema.FindMessage("testcompile.TestRequest") + if err != nil { + t.Errorf("FindMessage(TestRequest) error = %v", err) + } + if md == nil { + t.Error("FindMessage(TestRequest) returned nil") + } + + md, err = schema.FindMessage("testcompile.TestResponse") + if err != nil { + t.Errorf("FindMessage(TestResponse) error = %v", err) + } + if md == nil { + t.Error("FindMessage(TestResponse) returned nil") + } + + // Verify service was loaded. + sd, err := schema.FindService("testcompile.TestService") + if err != nil { + t.Errorf("FindService() error = %v", err) + } + if sd == nil { + t.Error("FindService() returned nil") + } + + // Verify method was loaded. + method, err := schema.FindMethod("testcompile.TestService/GetTest") + if err != nil { + t.Errorf("FindMethod() error = %v", err) + } + if method == nil { + t.Error("FindMethod() returned nil") + } +} + +func TestCompileProtosWithImports(t *testing.T) { + // Skip if protoc is not available. + if _, err := exec.LookPath("protoc"); err != nil { + t.Skip("protoc not found in PATH, skipping compile tests") + } + + // Create a directory structure with imports. + tmpDir := t.TempDir() + commonDir := filepath.Join(tmpDir, "common") + serviceDir := filepath.Join(tmpDir, "service") + + if err := os.MkdirAll(commonDir, 0755); err != nil { + t.Fatalf("os.MkdirAll(common) error = %v", err) + } + if err := os.MkdirAll(serviceDir, 0755); err != nil { + t.Fatalf("os.MkdirAll(service) error = %v", err) + } + + // Create common proto. + commonProto := filepath.Join(commonDir, "common.proto") + commonContent := ` +syntax = "proto3"; +package common; + +message Timestamp { + int64 seconds = 1; + int32 nanos = 2; +} +` + if err := os.WriteFile(commonProto, []byte(commonContent), 0644); err != nil { + t.Fatalf("os.WriteFile(common) error = %v", err) + } + + // Create service proto that imports common. + serviceProto := filepath.Join(serviceDir, "service.proto") + serviceContent := ` +syntax = "proto3"; +package myservice; + +import "common/common.proto"; + +message Event { + string id = 1; + common.Timestamp timestamp = 2; +} +` + if err := os.WriteFile(serviceProto, []byte(serviceContent), 0644); err != nil { + t.Fatalf("os.WriteFile(service) error = %v", err) + } + + // Compile with import path. + schema, err := CompileProtos([]string{serviceProto}, []string{tmpDir}) + if err != nil { + t.Fatalf("CompileProtos() error = %v", err) + } + + // Verify message was loaded. + md, err := schema.FindMessage("myservice.Event") + if err != nil { + t.Errorf("FindMessage(Event) error = %v", err) + } + if md == nil { + t.Error("FindMessage(Event) returned nil") + } + + // Verify imported message is also available. + md, err = schema.FindMessage("common.Timestamp") + if err != nil { + t.Errorf("FindMessage(Timestamp) error = %v", err) + } + if md == nil { + t.Error("FindMessage(Timestamp) returned nil") + } +} + +func TestCompileProtosFileNotFound(t *testing.T) { + // Skip if protoc is not available. + if _, err := exec.LookPath("protoc"); err != nil { + t.Skip("protoc not found in PATH, skipping compile tests") + } + + _, err := CompileProtos([]string{"/nonexistent/path/to/file.proto"}, nil) + if err == nil { + t.Error("CompileProtos() expected error for nonexistent file") + } +} + +func TestCompileProtosInvalidSyntax(t *testing.T) { + // Skip if protoc is not available. + if _, err := exec.LookPath("protoc"); err != nil { + t.Skip("protoc not found in PATH, skipping compile tests") + } + + tmpDir := t.TempDir() + protoFile := filepath.Join(tmpDir, "invalid.proto") + + // Write invalid proto syntax. + invalidContent := ` +this is not valid proto syntax!!! +message { + broken = 1; +} +` + if err := os.WriteFile(protoFile, []byte(invalidContent), 0644); err != nil { + t.Fatalf("os.WriteFile() error = %v", err) + } + + _, err := CompileProtos([]string{protoFile}, nil) + if err == nil { + t.Error("CompileProtos() expected error for invalid proto syntax") + } + + // Should be a ProtocError. + protocErr, ok := err.(*ProtocError) + if !ok { + t.Errorf("expected ProtocError, got %T", err) + } + if protocErr != nil && protocErr.Message == "" { + t.Error("ProtocError.Message should not be empty") + } +} + +func TestCompileProtosMultipleFiles(t *testing.T) { + // Skip if protoc is not available. + if _, err := exec.LookPath("protoc"); err != nil { + t.Skip("protoc not found in PATH, skipping compile tests") + } + + tmpDir := t.TempDir() + + // Create first proto file. + proto1 := filepath.Join(tmpDir, "first.proto") + proto1Content := ` +syntax = "proto3"; +package first; + +message FirstMessage { + string value = 1; +} +` + if err := os.WriteFile(proto1, []byte(proto1Content), 0644); err != nil { + t.Fatalf("os.WriteFile(first) error = %v", err) + } + + // Create second proto file. + proto2 := filepath.Join(tmpDir, "second.proto") + proto2Content := ` +syntax = "proto3"; +package second; + +message SecondMessage { + int32 count = 1; +} +` + if err := os.WriteFile(proto2, []byte(proto2Content), 0644); err != nil { + t.Fatalf("os.WriteFile(second) error = %v", err) + } + + // Compile both. + schema, err := CompileProtos([]string{proto1, proto2}, nil) + if err != nil { + t.Fatalf("CompileProtos() error = %v", err) + } + + // Verify both messages are available. + md1, err := schema.FindMessage("first.FirstMessage") + if err != nil { + t.Errorf("FindMessage(FirstMessage) error = %v", err) + } + if md1 == nil { + t.Error("FindMessage(FirstMessage) returned nil") + } + + md2, err := schema.FindMessage("second.SecondMessage") + if err != nil { + t.Errorf("FindMessage(SecondMessage) error = %v", err) + } + if md2 == nil { + t.Error("FindMessage(SecondMessage) returned nil") + } +} + +func TestProtocNotFoundError(t *testing.T) { + err := &ProtocNotFoundError{} + msg := err.Error() + if msg == "" { + t.Error("ProtocNotFoundError.Error() returned empty string") + } + if len(msg) < 10 { + t.Error("ProtocNotFoundError.Error() message too short") + } +} + +func TestProtocError(t *testing.T) { + err := &ProtocError{Message: "test error message"} + msg := err.Error() + if msg == "" { + t.Error("ProtocError.Error() returned empty string") + } + if msg != "protoc failed: test error message" { + t.Errorf("ProtocError.Error() = %v, want 'protoc failed: test error message'", msg) + } +} diff --git a/internal/proto/descriptor.go b/internal/proto/descriptor.go new file mode 100644 index 0000000..506ebea --- /dev/null +++ b/internal/proto/descriptor.go @@ -0,0 +1,29 @@ +package proto + +import ( + "fmt" + "os" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" +) + +// LoadDescriptorSetFile loads a schema from a pre-compiled FileDescriptorSet file (.pb). +func LoadDescriptorSetFile(path string) (*Schema, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read descriptor set file: %w", err) + } + + return loadDescriptorSetBytes(data) +} + +// loadDescriptorSetBytes loads a schema from FileDescriptorSet bytes. +func loadDescriptorSetBytes(data []byte) (*Schema, error) { + fds := &descriptorpb.FileDescriptorSet{} + if err := proto.Unmarshal(data, fds); err != nil { + return nil, fmt.Errorf("failed to unmarshal FileDescriptorSet: %w", err) + } + + return LoadFromDescriptorSet(fds) +} diff --git a/internal/proto/descriptor_test.go b/internal/proto/descriptor_test.go new file mode 100644 index 0000000..49364e7 --- /dev/null +++ b/internal/proto/descriptor_test.go @@ -0,0 +1,112 @@ +package proto + +import ( + "os" + "path/filepath" + "testing" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" +) + +func TestLoadDescriptorSetFile(t *testing.T) { + // Create a temporary descriptor set file. + fds := createTestDescriptorSet() + data, err := proto.Marshal(fds) + if err != nil { + t.Fatalf("proto.Marshal() error = %v", err) + } + + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "test.pb") + if err := os.WriteFile(tmpFile, data, 0644); err != nil { + t.Fatalf("os.WriteFile() error = %v", err) + } + + // Test loading. + schema, err := LoadDescriptorSetFile(tmpFile) + if err != nil { + t.Fatalf("LoadDescriptorSetFile() error = %v", err) + } + + // Verify schema was loaded correctly. + md, err := schema.FindMessage("testpkg.TestMessage") + if err != nil { + t.Errorf("FindMessage() error = %v", err) + } + if md == nil { + t.Error("FindMessage() returned nil") + } +} + +func TestLoadDescriptorSetFileNotFound(t *testing.T) { + _, err := LoadDescriptorSetFile("/nonexistent/path/to/file.pb") + if err == nil { + t.Error("LoadDescriptorSetFile() expected error for nonexistent file") + } +} + +func TestLoadDescriptorSetFileInvalidContent(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "invalid.pb") + + // Write invalid protobuf data. + if err := os.WriteFile(tmpFile, []byte("not a valid protobuf"), 0644); err != nil { + t.Fatalf("os.WriteFile() error = %v", err) + } + + _, err := LoadDescriptorSetFile(tmpFile) + if err == nil { + t.Error("LoadDescriptorSetFile() expected error for invalid protobuf") + } +} + +func TestLoadDescriptorSetBytes(t *testing.T) { + fds := createTestDescriptorSet() + data, err := proto.Marshal(fds) + if err != nil { + t.Fatalf("proto.Marshal() error = %v", err) + } + + schema, err := loadDescriptorSetBytes(data) + if err != nil { + t.Fatalf("LoadDescriptorSetBytes() error = %v", err) + } + + // Verify schema was loaded correctly. + md, err := schema.FindMessage("testpkg.TestMessage") + if err != nil { + t.Errorf("FindMessage() error = %v", err) + } + if md == nil { + t.Error("FindMessage() returned nil") + } +} + +func TestLoadDescriptorSetBytesEmpty(t *testing.T) { + // Empty bytes should produce empty schema. + fds := &descriptorpb.FileDescriptorSet{} + data, err := proto.Marshal(fds) + if err != nil { + t.Fatalf("proto.Marshal() error = %v", err) + } + + schema, err := loadDescriptorSetBytes(data) + if err != nil { + t.Fatalf("LoadDescriptorSetBytes() error = %v", err) + } + + if len(schema.ListMessages()) != 0 { + t.Errorf("expected 0 messages, got %d", len(schema.ListMessages())) + } + if len(schema.ListServices()) != 0 { + t.Errorf("expected 0 services, got %d", len(schema.ListServices())) + } +} + +func TestLoadDescriptorSetBytesInvalid(t *testing.T) { + _, err := loadDescriptorSetBytes([]byte("not valid protobuf")) + if err == nil { + t.Error("LoadDescriptorSetBytes() expected error for invalid protobuf") + } +} diff --git a/internal/proto/message.go b/internal/proto/message.go new file mode 100644 index 0000000..f97973c --- /dev/null +++ b/internal/proto/message.go @@ -0,0 +1,60 @@ +package proto + +import ( + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +// JSONToProtobuf converts JSON data to protobuf binary format. +func JSONToProtobuf(jsonData []byte, md protoreflect.MessageDescriptor) ([]byte, error) { + msg := dynamicpb.NewMessage(md) + + // Configure unmarshaler to be lenient with field names. + opts := protojson.UnmarshalOptions{ + DiscardUnknown: true, + } + + if err := opts.Unmarshal(jsonData, msg); err != nil { + return nil, err + } + + return proto.Marshal(msg) +} + +// ProtobufToJSON converts protobuf binary data to JSON format. +func ProtobufToJSON(data []byte, md protoreflect.MessageDescriptor) ([]byte, error) { + msg := dynamicpb.NewMessage(md) + + if err := proto.Unmarshal(data, msg); err != nil { + return nil, err + } + + // Configure marshaler for readable output. + opts := protojson.MarshalOptions{ + Multiline: true, + Indent: " ", + EmitUnpopulated: false, + UseProtoNames: true, + } + + return opts.Marshal(msg) +} + +// ProtobufToJSONCompact converts protobuf binary data to compact JSON format. +func ProtobufToJSONCompact(data []byte, md protoreflect.MessageDescriptor) ([]byte, error) { + msg := dynamicpb.NewMessage(md) + + if err := proto.Unmarshal(data, msg); err != nil { + return nil, err + } + + opts := protojson.MarshalOptions{ + Multiline: false, + EmitUnpopulated: false, + UseProtoNames: true, + } + + return opts.Marshal(msg) +} diff --git a/internal/proto/message_test.go b/internal/proto/message_test.go new file mode 100644 index 0000000..f7ad8aa --- /dev/null +++ b/internal/proto/message_test.go @@ -0,0 +1,404 @@ +package proto + +import ( + "encoding/json" + "slices" + "testing" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/types/descriptorpb" +) + +func TestJSONToProtobuf(t *testing.T) { + fds := createTestDescriptorSet() + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + md, err := schema.FindMessage("testpkg.TestMessage") + if err != nil { + t.Fatalf("FindMessage() error = %v", err) + } + + tests := []struct { + name string + jsonInput string + wantErr bool + checkFunc func(t *testing.T, data []byte) + }{ + { + name: "simple message", + jsonInput: `{"id": 123, "name": "test"}`, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + // Verify the protobuf contains expected fields. + if len(data) == 0 { + t.Error("expected non-empty protobuf data") + } + }, + }, + { + name: "empty message", + jsonInput: `{}`, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + // Empty message should produce empty or minimal protobuf. + }, + }, + { + name: "partial message - id only", + jsonInput: `{"id": 456}`, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + if len(data) == 0 { + t.Error("expected non-empty protobuf data") + } + }, + }, + { + name: "partial message - name only", + jsonInput: `{"name": "only name"}`, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + if len(data) == 0 { + t.Error("expected non-empty protobuf data") + } + }, + }, + { + name: "unknown field is discarded", + jsonInput: `{"id": 1, "unknownField": "ignored"}`, + wantErr: false, + }, + { + name: "invalid JSON", + jsonInput: `{invalid`, + wantErr: true, + }, + { + name: "type mismatch - string for int", + jsonInput: `{"id": "not a number"}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := JSONToProtobuf([]byte(tt.jsonInput), md) + if (err != nil) != tt.wantErr { + t.Errorf("JSONToProtobuf() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && tt.checkFunc != nil { + tt.checkFunc(t, data) + } + }) + } +} + +func TestProtobufToJSON(t *testing.T) { + fds := createTestDescriptorSet() + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + md, err := schema.FindMessage("testpkg.TestMessage") + if err != nil { + t.Fatalf("FindMessage() error = %v", err) + } + + tests := []struct { + name string + protoInput []byte + wantErr bool + wantID string // protojson outputs int64 as strings + wantName string + }{ + { + name: "simple message", + protoInput: buildTestProtobuf(123, "test"), + wantErr: false, + wantID: "123", + wantName: "test", + }, + { + name: "empty message", + protoInput: []byte{}, + wantErr: false, + }, + { + name: "id only", + protoInput: buildTestProtobuf(999, ""), + wantErr: false, + wantID: "999", + }, + { + name: "name only", + protoInput: buildNameOnlyProtobuf("hello"), + wantErr: false, + wantName: "hello", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jsonData, err := ProtobufToJSON(tt.protoInput, md) + if (err != nil) != tt.wantErr { + t.Errorf("ProtobufToJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + + // Verify JSON is valid. + var result map[string]any + if err := json.Unmarshal(jsonData, &result); err != nil { + t.Errorf("ProtobufToJSON() produced invalid JSON: %v", err) + return + } + + // Check ID field (protojson outputs int64 as strings). + if tt.wantID != "" { + // protojson may output as string or number depending on config + switch v := result["id"].(type) { + case string: + if v != tt.wantID { + t.Errorf("id = %v, want %v", v, tt.wantID) + } + case float64: + // Also accept numeric if config allows + default: + if result["id"] != nil { + t.Errorf("id has unexpected type %T", result["id"]) + } + } + } + + // Check name field. + if tt.wantName != "" { + if name, ok := result["name"].(string); !ok || name != tt.wantName { + t.Errorf("name = %v, want %v", result["name"], tt.wantName) + } + } + }) + } +} + +func TestProtobufToJSONCompact(t *testing.T) { + fds := createTestDescriptorSet() + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + md, err := schema.FindMessage("testpkg.TestMessage") + if err != nil { + t.Fatalf("FindMessage() error = %v", err) + } + + protoData := buildTestProtobuf(123, "test") + jsonData, err := ProtobufToJSONCompact(protoData, md) + if err != nil { + t.Fatalf("ProtobufToJSONCompact() error = %v", err) + } + + // Compact JSON should not have newlines. + if slices.Contains(jsonData, byte('\n')) { + t.Error("ProtobufToJSONCompact() output contains newlines") + } + + // Verify it's valid JSON. + var result map[string]any + if err := json.Unmarshal(jsonData, &result); err != nil { + t.Errorf("ProtobufToJSONCompact() produced invalid JSON: %v", err) + } +} + +func TestJSONToProtobufRoundTrip(t *testing.T) { + fds := createTestDescriptorSet() + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + md, err := schema.FindMessage("testpkg.TestMessage") + if err != nil { + t.Fatalf("FindMessage() error = %v", err) + } + + tests := []struct { + name string + jsonInput string + wantID string // protojson outputs int64 as strings + wantName string + }{ + { + name: "full message", + jsonInput: `{"id": 42, "name": "roundtrip"}`, + wantID: "42", + wantName: "roundtrip", + }, + { + name: "zero values", + jsonInput: `{"id": 0, "name": ""}`, + wantID: "", + wantName: "", + }, + { + name: "large id", + jsonInput: `{"id": 9223372036854775807, "name": "max int64"}`, + wantID: "9223372036854775807", + wantName: "max int64", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // JSON -> Protobuf. + protoData, err := JSONToProtobuf([]byte(tt.jsonInput), md) + if err != nil { + t.Fatalf("JSONToProtobuf() error = %v", err) + } + + // Protobuf -> JSON. + jsonData, err := ProtobufToJSON(protoData, md) + if err != nil { + t.Fatalf("ProtobufToJSON() error = %v", err) + } + + // Parse result. + var result map[string]any + if err := json.Unmarshal(jsonData, &result); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + + // Verify values (only check non-zero since zero values may be omitted). + if tt.wantID != "" { + // protojson outputs int64 as strings + switch v := result["id"].(type) { + case string: + if v != tt.wantID { + t.Errorf("id = %v, want %v", v, tt.wantID) + } + case float64: + // Also accept numeric + default: + if result["id"] != nil { + t.Errorf("id has unexpected type %T", result["id"]) + } + } + } + if tt.wantName != "" { + if name, ok := result["name"].(string); !ok || name != tt.wantName { + t.Errorf("name = %v, want %v", result["name"], tt.wantName) + } + } + }) + } +} + +func TestJSONToProtobufWithNestedMessage(t *testing.T) { + // Create a schema with nested message. + fds := &descriptorpb.FileDescriptorSet{ + File: []*descriptorpb.FileDescriptorProto{ + { + Name: strPtr("nested.proto"), + Package: strPtr("nested"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: strPtr("Inner"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: strPtr("value"), + Number: int32Ptr(1), + Type: typePtr(descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + }, + { + Name: strPtr("Outer"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: strPtr("inner"), + Number: int32Ptr(1), + Type: typePtr(descriptorpb.FieldDescriptorProto_TYPE_MESSAGE), + TypeName: strPtr(".nested.Inner"), + }, + { + Name: strPtr("count"), + Number: int32Ptr(2), + Type: typePtr(descriptorpb.FieldDescriptorProto_TYPE_INT32), + }, + }, + }, + }, + }, + }, + } + + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + md, err := schema.FindMessage("nested.Outer") + if err != nil { + t.Fatalf("FindMessage() error = %v", err) + } + + jsonInput := `{"inner": {"value": "nested value"}, "count": 5}` + protoData, err := JSONToProtobuf([]byte(jsonInput), md) + if err != nil { + t.Fatalf("JSONToProtobuf() error = %v", err) + } + + // Convert back and verify. + jsonOutput, err := ProtobufToJSON(protoData, md) + if err != nil { + t.Fatalf("ProtobufToJSON() error = %v", err) + } + + var result map[string]any + if err := json.Unmarshal(jsonOutput, &result); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + + // Check nested value. + inner, ok := result["inner"].(map[string]any) + if !ok { + t.Fatal("inner field not found or not an object") + } + if inner["value"] != "nested value" { + t.Errorf("inner.value = %v, want 'nested value'", inner["value"]) + } + if result["count"] != float64(5) { + t.Errorf("count = %v, want 5", result["count"]) + } +} + +// Helper functions to build test protobuf data. + +func buildTestProtobuf(id int64, name string) []byte { + var buf []byte + // Field 1: id (int64, varint). + if id != 0 { + buf = protowire.AppendTag(buf, 1, protowire.VarintType) + buf = protowire.AppendVarint(buf, uint64(id)) + } + // Field 2: name (string, bytes). + if name != "" { + buf = protowire.AppendTag(buf, 2, protowire.BytesType) + buf = protowire.AppendString(buf, name) + } + return buf +} + +func buildNameOnlyProtobuf(name string) []byte { + var buf []byte + buf = protowire.AppendTag(buf, 2, protowire.BytesType) + buf = protowire.AppendString(buf, name) + return buf +} diff --git a/internal/proto/schema.go b/internal/proto/schema.go new file mode 100644 index 0000000..bb0557a --- /dev/null +++ b/internal/proto/schema.go @@ -0,0 +1,144 @@ +package proto + +import ( + "fmt" + "strings" + + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/descriptorpb" +) + +// Schema holds loaded protobuf type information from descriptors. +type Schema struct { + files *protoregistry.Files + messages map[string]protoreflect.MessageDescriptor + services map[string]protoreflect.ServiceDescriptor +} + +// NewSchema creates an empty schema. +func NewSchema() *Schema { + return &Schema{ + files: new(protoregistry.Files), + messages: make(map[string]protoreflect.MessageDescriptor), + services: make(map[string]protoreflect.ServiceDescriptor), + } +} + +// LoadFromDescriptorSet loads schema from a FileDescriptorSet. +func LoadFromDescriptorSet(fds *descriptorpb.FileDescriptorSet) (*Schema, error) { + schema := NewSchema() + + // Build file descriptors. + files, err := protodesc.NewFiles(fds) + if err != nil { + return nil, fmt.Errorf("failed to create file descriptors: %w", err) + } + schema.files = files + + // Index all messages and services. + files.RangeFiles(func(fd protoreflect.FileDescriptor) bool { + schema.indexFile(fd) + return true + }) + + return schema, nil +} + +// indexFile indexes all messages and services in a file descriptor. +func (s *Schema) indexFile(fd protoreflect.FileDescriptor) { + // Index top-level messages. + msgs := fd.Messages() + for i := 0; i < msgs.Len(); i++ { + s.indexMessage(msgs.Get(i)) + } + + // Index services. + svcs := fd.Services() + for i := 0; i < svcs.Len(); i++ { + svc := svcs.Get(i) + s.services[string(svc.FullName())] = svc + } +} + +// indexMessage indexes a message and its nested messages. +func (s *Schema) indexMessage(md protoreflect.MessageDescriptor) { + s.messages[string(md.FullName())] = md + + // Index nested messages. + nested := md.Messages() + for i := 0; i < nested.Len(); i++ { + s.indexMessage(nested.Get(i)) + } +} + +// FindMessage finds a message descriptor by full name. +// The name can be with or without leading dot. +func (s *Schema) FindMessage(name string) (protoreflect.MessageDescriptor, error) { + name = strings.TrimPrefix(name, ".") + + if md, ok := s.messages[name]; ok { + return md, nil + } + return nil, fmt.Errorf("message type not found: %s", name) +} + +// FindService finds a service descriptor by full name. +func (s *Schema) FindService(name string) (protoreflect.ServiceDescriptor, error) { + name = strings.TrimPrefix(name, ".") + + if sd, ok := s.services[name]; ok { + return sd, nil + } + return nil, fmt.Errorf("service not found: %s", name) +} + +// FindMethod finds a method in a service by service and method name. +// The format can be "package.Service/Method" or "package.Service.Method". +func (s *Schema) FindMethod(fullName string) (protoreflect.MethodDescriptor, error) { + // Try both "/" and "." as separators. + var serviceName, methodName string + if idx := strings.LastIndex(fullName, "/"); idx >= 0 { + serviceName = fullName[:idx] + methodName = fullName[idx+1:] + } else if idx := strings.LastIndex(fullName, "."); idx >= 0 { + serviceName = fullName[:idx] + methodName = fullName[idx+1:] + } else { + return nil, fmt.Errorf("invalid method name format: %s (expected 'Service/Method' or 'Service.Method')", fullName) + } + + sd, err := s.FindService(serviceName) + if err != nil { + return nil, err + } + + methods := sd.Methods() + for i := 0; i < methods.Len(); i++ { + m := methods.Get(i) + if string(m.Name()) == methodName { + return m, nil + } + } + + return nil, fmt.Errorf("method %s not found in service %s", methodName, serviceName) +} + +// ListMessages returns all message type names. +func (s *Schema) ListMessages() []string { + names := make([]string, 0, len(s.messages)) + for name := range s.messages { + names = append(names, name) + } + return names +} + +// ListServices returns all service names. +func (s *Schema) ListServices() []string { + names := make([]string, 0, len(s.services)) + for name := range s.services { + names = append(names, name) + } + return names +} diff --git a/internal/proto/schema_test.go b/internal/proto/schema_test.go new file mode 100644 index 0000000..0c24127 --- /dev/null +++ b/internal/proto/schema_test.go @@ -0,0 +1,386 @@ +package proto + +import ( + "testing" + + "google.golang.org/protobuf/types/descriptorpb" +) + +func TestNewSchema(t *testing.T) { + s := NewSchema() + if s == nil { + t.Fatal("NewSchema() returned nil") + } + if s.files == nil { + t.Error("NewSchema().files is nil") + } + if s.messages == nil { + t.Error("NewSchema().messages is nil") + } + if s.services == nil { + t.Error("NewSchema().services is nil") + } +} + +func TestLoadFromDescriptorSet(t *testing.T) { + // Create a minimal FileDescriptorSet with a message and service. + fds := &descriptorpb.FileDescriptorSet{ + File: []*descriptorpb.FileDescriptorProto{ + { + Name: strPtr("test.proto"), + Package: strPtr("testpkg"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: strPtr("TestMessage"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: strPtr("id"), + Number: int32Ptr(1), + Type: typePtr(descriptorpb.FieldDescriptorProto_TYPE_INT64), + }, + { + Name: strPtr("name"), + Number: int32Ptr(2), + Type: typePtr(descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + }, + { + Name: strPtr("NestedOuter"), + NestedType: []*descriptorpb.DescriptorProto{ + { + Name: strPtr("NestedInner"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: strPtr("value"), + Number: int32Ptr(1), + Type: typePtr(descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + }, + }, + }, + }, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: strPtr("TestService"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: strPtr("GetTest"), + InputType: strPtr(".testpkg.TestMessage"), + OutputType: strPtr(".testpkg.TestMessage"), + }, + { + Name: strPtr("CreateTest"), + InputType: strPtr(".testpkg.TestMessage"), + OutputType: strPtr(".testpkg.TestMessage"), + }, + }, + }, + }, + }, + }, + } + + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + // Verify messages were indexed. + messages := schema.ListMessages() + if len(messages) < 2 { + t.Errorf("expected at least 2 messages, got %d", len(messages)) + } + + // Verify services were indexed. + services := schema.ListServices() + if len(services) != 1 { + t.Errorf("expected 1 service, got %d", len(services)) + } +} + +func TestFindMessage(t *testing.T) { + fds := createTestDescriptorSet() + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + tests := []struct { + name string + msgName string + wantErr bool + }{ + { + name: "find by full name", + msgName: "testpkg.TestMessage", + wantErr: false, + }, + { + name: "find with leading dot", + msgName: ".testpkg.TestMessage", + wantErr: false, + }, + { + name: "find nested message", + msgName: "testpkg.NestedOuter.NestedInner", + wantErr: false, + }, + { + name: "not found", + msgName: "testpkg.NonExistent", + wantErr: true, + }, + { + name: "wrong package", + msgName: "wrongpkg.TestMessage", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + md, err := schema.FindMessage(tt.msgName) + if (err != nil) != tt.wantErr { + t.Errorf("FindMessage() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && md == nil { + t.Error("FindMessage() returned nil without error") + } + }) + } +} + +func TestFindService(t *testing.T) { + fds := createTestDescriptorSet() + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + tests := []struct { + name string + svcName string + wantErr bool + }{ + { + name: "find by full name", + svcName: "testpkg.TestService", + wantErr: false, + }, + { + name: "find with leading dot", + svcName: ".testpkg.TestService", + wantErr: false, + }, + { + name: "not found", + svcName: "testpkg.NonExistent", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sd, err := schema.FindService(tt.svcName) + if (err != nil) != tt.wantErr { + t.Errorf("FindService() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && sd == nil { + t.Error("FindService() returned nil without error") + } + }) + } +} + +func TestFindMethod(t *testing.T) { + fds := createTestDescriptorSet() + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + tests := []struct { + name string + methodName string + wantErr bool + }{ + { + name: "slash separator", + methodName: "testpkg.TestService/GetTest", + wantErr: false, + }, + { + name: "dot separator", + methodName: "testpkg.TestService.GetTest", + wantErr: false, + }, + { + name: "second method", + methodName: "testpkg.TestService/CreateTest", + wantErr: false, + }, + { + name: "service not found", + methodName: "testpkg.NonExistent/GetTest", + wantErr: true, + }, + { + name: "method not found", + methodName: "testpkg.TestService/NonExistent", + wantErr: true, + }, + { + name: "invalid format - no separator", + methodName: "InvalidMethodName", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + md, err := schema.FindMethod(tt.methodName) + if (err != nil) != tt.wantErr { + t.Errorf("FindMethod() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && md == nil { + t.Error("FindMethod() returned nil without error") + } + }) + } +} + +func TestListMessages(t *testing.T) { + fds := createTestDescriptorSet() + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + messages := schema.ListMessages() + // Should have TestMessage, NestedOuter, and NestedOuter.NestedInner + if len(messages) < 3 { + t.Errorf("expected at least 3 messages, got %d: %v", len(messages), messages) + } + + // Check that expected messages are present. + found := make(map[string]bool) + for _, m := range messages { + found[m] = true + } + if !found["testpkg.TestMessage"] { + t.Error("TestMessage not in list") + } + if !found["testpkg.NestedOuter"] { + t.Error("NestedOuter not in list") + } + if !found["testpkg.NestedOuter.NestedInner"] { + t.Error("NestedOuter.NestedInner not in list") + } +} + +func TestListServices(t *testing.T) { + fds := createTestDescriptorSet() + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Fatalf("LoadFromDescriptorSet() error = %v", err) + } + + services := schema.ListServices() + if len(services) != 1 { + t.Errorf("expected 1 service, got %d: %v", len(services), services) + } + if len(services) > 0 && services[0] != "testpkg.TestService" { + t.Errorf("expected testpkg.TestService, got %s", services[0]) + } +} + +func TestLoadFromDescriptorSetError(t *testing.T) { + // Empty FileDescriptorSet should still work. + fds := &descriptorpb.FileDescriptorSet{} + schema, err := LoadFromDescriptorSet(fds) + if err != nil { + t.Errorf("LoadFromDescriptorSet() with empty FDS error = %v", err) + } + if schema == nil { + t.Error("expected non-nil schema for empty FDS") + } +} + +// Helper functions to create test data. + +func createTestDescriptorSet() *descriptorpb.FileDescriptorSet { + return &descriptorpb.FileDescriptorSet{ + File: []*descriptorpb.FileDescriptorProto{ + { + Name: strPtr("test.proto"), + Package: strPtr("testpkg"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: strPtr("TestMessage"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: strPtr("id"), + Number: int32Ptr(1), + Type: typePtr(descriptorpb.FieldDescriptorProto_TYPE_INT64), + }, + { + Name: strPtr("name"), + Number: int32Ptr(2), + Type: typePtr(descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + }, + { + Name: strPtr("NestedOuter"), + NestedType: []*descriptorpb.DescriptorProto{ + { + Name: strPtr("NestedInner"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: strPtr("value"), + Number: int32Ptr(1), + Type: typePtr(descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + }, + }, + }, + }, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: strPtr("TestService"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: strPtr("GetTest"), + InputType: strPtr(".testpkg.TestMessage"), + OutputType: strPtr(".testpkg.TestMessage"), + }, + { + Name: strPtr("CreateTest"), + InputType: strPtr(".testpkg.TestMessage"), + OutputType: strPtr(".testpkg.TestMessage"), + }, + }, + }, + }, + }, + }, + } +} + +func strPtr(s string) *string { + return &s +} + +func int32Ptr(i int32) *int32 { + return &i +} + +func typePtr(t descriptorpb.FieldDescriptorProto_Type) *descriptorpb.FieldDescriptorProto_Type { + return &t +} diff --git a/main.go b/main.go index 55cc44c..7827ad6 100644 --- a/main.go +++ b/main.go @@ -134,6 +134,7 @@ func main() { Edit: app.Edit, Form: app.Form, Format: app.Cfg.Format, + GRPC: app.GRPC, Headers: app.Cfg.Headers, HTTP: app.Cfg.HTTP, IgnoreStatus: getValue(app.Cfg.IgnoreStatus), @@ -145,6 +146,9 @@ func main() { NoPager: getValue(app.Cfg.NoPager), Output: app.Output, PrinterHandle: handle, + ProtoDesc: app.ProtoDesc, + ProtoFiles: app.ProtoFiles, + ProtoImports: app.ProtoImports, Proxy: app.Cfg.Proxy, QueryParams: app.Cfg.QueryParams, Range: app.Range,