Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"

v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/authzed-go/v1"
Expand Down Expand Up @@ -214,6 +215,43 @@ func isNoneOf(routes ...string) func(_ context.Context, c interceptors.CallMeta)
}
}

func extraHeadersUnaryInterceptor(headers map[string]string) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if len(headers) > 0 {
md := metadata.New(headers)
ctx = metadata.NewOutgoingContext(ctx, md)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}

func extraHeadersStreamInterceptor(headers map[string]string) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if len(headers) > 0 {
md := metadata.New(headers)
ctx = metadata.NewOutgoingContext(ctx, md)
}
return streamer(ctx, desc, cc, method, opts...)
}
}

func parseExtraHeaders(headerStrings []string) (map[string]string, error) {
headers := make(map[string]string)
for _, headerStr := range headerStrings {
parts := strings.SplitN(headerStr, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid header format '%s': expected 'key=value'", headerStr)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
if key == "" {
return nil, fmt.Errorf("invalid header format '%s': key cannot be empty", headerStr)
}
headers[key] = value
}
return headers, nil
}

// DialOptsFromFlags returns the dial options from the CLI-specified flags.
func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOption, error) {
maxRetries := cobrautil.MustGetUint(cmd, "max-retries")
Expand All @@ -239,6 +277,17 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute, watchRoute))),
}

// Parse and add extra headers if provided
extraHeaderStrings := cobrautil.MustGetStringSlice(cmd, "extra-header")
if len(extraHeaderStrings) > 0 {
headers, err := parseExtraHeaders(extraHeaderStrings)
if err != nil {
return nil, fmt.Errorf("failed to parse extra headers: %w", err)
}
unaryInterceptors = append(unaryInterceptors, extraHeadersUnaryInterceptor(headers))
streamInterceptors = append(streamInterceptors, extraHeadersStreamInterceptor(headers))
}

if !cobrautil.MustGetBool(cmd, "skip-version-check") {
unaryInterceptors = append(unaryInterceptors, zgrpcutil.CheckServerVersion)
}
Expand Down
2 changes: 2 additions & 0 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ func TestRetries(t *testing.T) {
zedtesting.StringFlag{FlagName: "proxy", FlagValue: "", Changed: true},
zedtesting.StringFlag{FlagName: "hostname-override", FlagValue: "", Changed: true},
zedtesting.IntFlag{FlagName: "max-message-size", FlagValue: 1000, Changed: true},
zedtesting.StringSliceFlag{FlagName: "extra-header", FlagValue: []string{}, Changed: false},
)
dialOpts, err := client.DialOptsFromFlags(cmd, storage.Token{Insecure: &secure})
require.NoError(t, err)
Expand Down Expand Up @@ -224,6 +225,7 @@ func TestDoesNotRetry(t *testing.T) {
zedtesting.StringFlag{FlagName: "proxy", FlagValue: "", Changed: true},
zedtesting.StringFlag{FlagName: "hostname-override", FlagValue: "", Changed: true},
zedtesting.IntFlag{FlagName: "max-message-size", FlagValue: 1000, Changed: true},
zedtesting.StringSliceFlag{FlagName: "extra-header", FlagValue: []string{}, Changed: false},
)
dialOpts, err := client.DialOptsFromFlags(cmd, storage.Token{Insecure: &secure})
require.NoError(t, err)
Expand Down
1 change: 1 addition & 0 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ zed permission check --explain document:firstdoc writer user:emilia
rootCmd.PersistentFlags().Int("max-message-size", 0, "maximum size *in bytes* (defaults to 4_194_304 bytes ~= 4MB) of a gRPC message that can be sent or received by zed")
rootCmd.PersistentFlags().String("proxy", "", "specify a SOCKS5 proxy address")
rootCmd.PersistentFlags().Uint("max-retries", 10, "maximum number of sequential retries to attempt when a request fails")
rootCmd.PersistentFlags().StringSlice("extra-header", []string{}, "extra header(s) to add to gRPC requests in the format 'key=value' (can be specified multiple times)")
_ = rootCmd.PersistentFlags().MarkHidden("debug") // This cannot return its error.

versionCmd := &cobra.Command{
Expand Down
9 changes: 9 additions & 0 deletions internal/testing/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ type StringFlag struct {
Changed bool
}

type StringSliceFlag struct {
FlagName string
FlagValue []string
Changed bool
}

type BoolFlag struct {
FlagName string
FlagValue bool
Expand Down Expand Up @@ -107,6 +113,9 @@ func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co
case StringFlag:
c.Flags().String(f.FlagName, f.FlagValue, "")
c.Flag(f.FlagName).Changed = f.Changed
case StringSliceFlag:
c.Flags().StringSlice(f.FlagName, f.FlagValue, "")
c.Flag(f.FlagName).Changed = f.Changed
case BoolFlag:
c.Flags().Bool(f.FlagName, f.FlagValue, "")
c.Flag(f.FlagName).Changed = f.Changed
Expand Down