Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func NewADCMechanism(ctx context.Context) (*Mechanism, error) {
}
email, err := getADCPrincipalEmail(creds)
if err != nil {
return nil, fmt.Errorf("error fetching principal email for Appication Default Credentials: %w", err)
return nil, fmt.Errorf("error fetching principal email for Appication Default Credentials: %w. Please set the %q environment variable", err, principalEmailEnvVar)
}
return &Mechanism{
emailAddress: email,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"io"
"net/http"
"net/mail"
"os"
"regexp"
"strings"

Expand All @@ -35,10 +36,19 @@ const (
tokenInfoAPIURL = "https://www.googleapis.com/oauth2/v3/tokeninfo/?access_token="
// Metadatadata Server name for the default Service Account
defaultServiceAccountMetadataServerName = "default"
// Environment variable name for the principal email
principalEmailEnvVar = "GOOGLE_MANAGED_KAFKA_AUTH_PRINCIPAL"
)

// Returns the principal email address for ADC credentials
func getADCPrincipalEmail(creds *google.Credentials) (string, error) {
// Check if the user has explicitly set the principal email via an environment variable
if email, ok := os.LookupEnv(principalEmailEnvVar); ok && email != "" {
if err := validatePrincipalEmail(email); err != nil {
return "", fmt.Errorf("principal email (%q) from environment variable (%q) did not pass validation: %w", email, principalEmailEnvVar, err)
}
return email, nil
}

// If we are in GCE - then we can fetch the email address of the default Service Account
// directly from the Metadata server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,61 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)

func TestGetADCPrincipalEmail(t *testing.T) {
tests := []struct {
name string
setupFunc func()
cleanUpFunc func()
expectErr bool
expectEmail string
}{
{
name: "Env Var Set - Valid Email",
setupFunc: func() { _ = os.Setenv(principalEmailEnvVar, "human-principal@example.com") },
cleanUpFunc: func() { _ = os.Unsetenv(principalEmailEnvVar) },
expectEmail: "human-principal@example.com",
},
{
name: "Env Var Set - Invalid Email",
setupFunc: func() { _ = os.Setenv(principalEmailEnvVar, "this-is-not-an-email") },
cleanUpFunc: func() { _ = os.Unsetenv(principalEmailEnvVar) },
expectErr: true,
expectEmail: "",
},
{
name: "Env Var Empty",
setupFunc: func() { _ = os.Setenv(principalEmailEnvVar, "") },
cleanUpFunc: func() { _ = os.Unsetenv(principalEmailEnvVar) },
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setupFunc != nil {
tt.setupFunc()
}
if tt.cleanUpFunc != nil {
defer tt.cleanUpFunc()
}
gotEmail, gotErr := getADCPrincipalEmail(&google.Credentials{})
if tt.expectErr {
assert.Error(t, gotErr)
return
}
assert.NoError(t, gotErr)
assert.Equal(t, tt.expectEmail, gotEmail)
})
}
}

func TestValidatePrincipalEmail(t *testing.T) {

tests := []struct {
Expand Down