diff --git a/internal/client/client.go b/internal/client/client.go index e11a3b62..7656f4b4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/authzed-go/v1" @@ -214,6 +215,43 @@ func isNoneOf(routes ...string) func(_ context.Context, c interceptors.CallMeta) } } +func extraHeadersUnaryInterceptor(headers map[string]string) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if len(headers) > 0 { + md := metadata.New(headers) + ctx = metadata.NewOutgoingContext(ctx, md) + } + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +func extraHeadersStreamInterceptor(headers map[string]string) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if len(headers) > 0 { + md := metadata.New(headers) + ctx = metadata.NewOutgoingContext(ctx, md) + } + return streamer(ctx, desc, cc, method, opts...) + } +} + +func parseExtraHeaders(headerStrings []string) (map[string]string, error) { + headers := make(map[string]string) + for _, headerStr := range headerStrings { + parts := strings.SplitN(headerStr, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid header format '%s': expected 'key=value'", headerStr) + } + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + if key == "" { + return nil, fmt.Errorf("invalid header format '%s': key cannot be empty", headerStr) + } + headers[key] = value + } + return headers, nil +} + // DialOptsFromFlags returns the dial options from the CLI-specified flags. func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOption, error) { maxRetries := cobrautil.MustGetUint(cmd, "max-retries") @@ -239,6 +277,17 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute, watchRoute))), } + // Parse and add extra headers if provided + extraHeaderStrings := cobrautil.MustGetStringSlice(cmd, "extra-header") + if len(extraHeaderStrings) > 0 { + headers, err := parseExtraHeaders(extraHeaderStrings) + if err != nil { + return nil, fmt.Errorf("failed to parse extra headers: %w", err) + } + unaryInterceptors = append(unaryInterceptors, extraHeadersUnaryInterceptor(headers)) + streamInterceptors = append(streamInterceptors, extraHeadersStreamInterceptor(headers)) + } + if !cobrautil.MustGetBool(cmd, "skip-version-check") { unaryInterceptors = append(unaryInterceptors, zgrpcutil.CheckServerVersion) } diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 1573bd09..e71edd4f 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -180,6 +180,7 @@ func TestRetries(t *testing.T) { zedtesting.StringFlag{FlagName: "proxy", FlagValue: "", Changed: true}, zedtesting.StringFlag{FlagName: "hostname-override", FlagValue: "", Changed: true}, zedtesting.IntFlag{FlagName: "max-message-size", FlagValue: 1000, Changed: true}, + zedtesting.StringSliceFlag{FlagName: "extra-header", FlagValue: []string{}, Changed: false}, ) dialOpts, err := client.DialOptsFromFlags(cmd, storage.Token{Insecure: &secure}) require.NoError(t, err) @@ -224,6 +225,7 @@ func TestDoesNotRetry(t *testing.T) { zedtesting.StringFlag{FlagName: "proxy", FlagValue: "", Changed: true}, zedtesting.StringFlag{FlagName: "hostname-override", FlagValue: "", Changed: true}, zedtesting.IntFlag{FlagName: "max-message-size", FlagValue: 1000, Changed: true}, + zedtesting.StringSliceFlag{FlagName: "extra-header", FlagValue: []string{}, Changed: false}, ) dialOpts, err := client.DialOptsFromFlags(cmd, storage.Token{Insecure: &secure}) require.NoError(t, err) diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 795e6362..aa085b3c 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -91,6 +91,7 @@ zed permission check --explain document:firstdoc writer user:emilia rootCmd.PersistentFlags().Int("max-message-size", 0, "maximum size *in bytes* (defaults to 4_194_304 bytes ~= 4MB) of a gRPC message that can be sent or received by zed") rootCmd.PersistentFlags().String("proxy", "", "specify a SOCKS5 proxy address") rootCmd.PersistentFlags().Uint("max-retries", 10, "maximum number of sequential retries to attempt when a request fails") + rootCmd.PersistentFlags().StringSlice("extra-header", []string{}, "extra header(s) to add to gRPC requests in the format 'key=value' (can be specified multiple times)") _ = rootCmd.PersistentFlags().MarkHidden("debug") // This cannot return its error. versionCmd := &cobra.Command{ diff --git a/internal/testing/test_helpers.go b/internal/testing/test_helpers.go index d921fdbb..f92c2103 100644 --- a/internal/testing/test_helpers.go +++ b/internal/testing/test_helpers.go @@ -68,6 +68,12 @@ type StringFlag struct { Changed bool } +type StringSliceFlag struct { + FlagName string + FlagValue []string + Changed bool +} + type BoolFlag struct { FlagName string FlagValue bool @@ -107,6 +113,9 @@ func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co case StringFlag: c.Flags().String(f.FlagName, f.FlagValue, "") c.Flag(f.FlagName).Changed = f.Changed + case StringSliceFlag: + c.Flags().StringSlice(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case BoolFlag: c.Flags().Bool(f.FlagName, f.FlagValue, "") c.Flag(f.FlagName).Changed = f.Changed