diff --git a/go.mod b/go.mod index e18045b86..1917e1c58 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( github.com/twmb/franz-go v1.20.6 github.com/twmb/franz-go/pkg/kadm v1.17.1 github.com/twmb/franz-go/pkg/kmsg v1.12.0 + github.com/xdg-go/scram v1.1.2 go.opentelemetry.io/otel v1.40.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.40.0 @@ -281,7 +282,6 @@ require ( github.com/urfave/cli/v2 v2.27.7 // indirect github.com/wlynxg/anet v0.0.5 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect - github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect diff --git a/pkg/output/kafka/client.go b/pkg/output/kafka/client.go index 88a9892bd..09752e383 100644 --- a/pkg/output/kafka/client.go +++ b/pkg/output/kafka/client.go @@ -1,19 +1,62 @@ package kafka import ( + "crypto/sha256" + "crypto/sha512" "crypto/tls" "crypto/x509" "errors" "fmt" + "hash" "os" "strings" "time" "github.com/IBM/sarama" + "github.com/xdg-go/scram" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) +// xdgSCRAMClient implements sarama.SCRAMClient using the xdg-go/scram library. +// Sarama requires a SCRAMClientGeneratorFunc to perform the SCRAM handshake; +// without it the mechanism name is set but authentication always fails. +type xdgSCRAMClient struct { + *scram.Client + *scram.ClientConversation + scram.HashGeneratorFcn +} + +func (x *xdgSCRAMClient) Begin(userName, password, authzID string) error { + client, err := x.HashGeneratorFcn.NewClient(userName, password, authzID) + if err != nil { + return err + } + + x.Client = client + x.ClientConversation = client.NewConversation() + + return nil +} + +func (x *xdgSCRAMClient) Step(challenge string) (string, error) { + return x.ClientConversation.Step(challenge) +} + +func (x *xdgSCRAMClient) Done() bool { + return x.ClientConversation.Done() +} + +// scramSHA256Generator returns a new SCRAM-SHA-256 client. +func scramSHA256Generator() sarama.SCRAMClient { + return &xdgSCRAMClient{HashGeneratorFcn: scram.HashGeneratorFcn(func() hash.Hash { return sha256.New() })} +} + +// scramSHA512Generator returns a new SCRAM-SHA-512 client. +func scramSHA512Generator() sarama.SCRAMClient { + return &xdgSCRAMClient{HashGeneratorFcn: scram.HashGeneratorFcn(func() hash.Hash { return sha512.New() })} +} + // CompressionStrategy defines the compression codec for Kafka messages. type CompressionStrategy string @@ -250,8 +293,10 @@ func InitSaramaConfig(config *ProducerConfig, maxExportBatchSize int) (*sarama.C c.Net.SASL.Mechanism = sarama.SASLTypeOAuth case SASLTypeSCRAMSHA256: c.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA256 + c.Net.SASL.SCRAMClientGeneratorFunc = scramSHA256Generator case SASLTypeSCRAMSHA512: c.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA512 + c.Net.SASL.SCRAMClientGeneratorFunc = scramSHA512Generator case SASLTypeGSSAPI: c.Net.SASL.Mechanism = sarama.SASLTypeGSSAPI default: diff --git a/pkg/output/kafka/client_test.go b/pkg/output/kafka/client_test.go index 33039310a..18a7fac08 100644 --- a/pkg/output/kafka/client_test.go +++ b/pkg/output/kafka/client_test.go @@ -321,15 +321,16 @@ func TestInitSaramaConfig(t *testing.T) { t.Run("SASL mechanisms", func(t *testing.T) { tests := []struct { - mechanism SASLMechanism - want sarama.SASLMechanism + mechanism SASLMechanism + want sarama.SASLMechanism + wantSCRAMClient bool }{ - {SASLTypeOAuth, sarama.SASLTypeOAuth}, - {SASLTypeSCRAMSHA256, sarama.SASLTypeSCRAMSHA256}, - {SASLTypeSCRAMSHA512, sarama.SASLTypeSCRAMSHA512}, - {SASLTypeGSSAPI, sarama.SASLTypeGSSAPI}, - {SASLTypePlaintext, sarama.SASLTypePlaintext}, - {"UNKNOWN", sarama.SASLTypePlaintext}, + {SASLTypeOAuth, sarama.SASLTypeOAuth, false}, + {SASLTypeSCRAMSHA256, sarama.SASLTypeSCRAMSHA256, true}, + {SASLTypeSCRAMSHA512, sarama.SASLTypeSCRAMSHA512, true}, + {SASLTypeGSSAPI, sarama.SASLTypeGSSAPI, false}, + {SASLTypePlaintext, sarama.SASLTypePlaintext, false}, + {"UNKNOWN", sarama.SASLTypePlaintext, false}, } for _, tt := range tests { @@ -349,6 +350,17 @@ func TestInitSaramaConfig(t *testing.T) { assert.Equal(t, tt.want, sc.Net.SASL.Mechanism) assert.Equal(t, "user", sc.Net.SASL.User) assert.Equal(t, "pass", sc.Net.SASL.Password) + + if tt.wantSCRAMClient { + assert.NotNil(t, sc.Net.SASL.SCRAMClientGeneratorFunc, + "SCRAMClientGeneratorFunc must be set for SCRAM mechanisms") + + // Verify the generated client implements sarama.SCRAMClient. + client := sc.Net.SASL.SCRAMClientGeneratorFunc() + assert.NotNil(t, client) + } else { + assert.Nil(t, sc.Net.SASL.SCRAMClientGeneratorFunc) + } }) } })