diff --git a/cmd/hubauth-ext/main.go b/cmd/hubauth-ext/main.go index d78b7c7..ef75f1f 100644 --- a/cmd/hubauth-ext/main.go +++ b/cmd/hubauth-ext/main.go @@ -15,6 +15,7 @@ import ( "github.com/flynn/hubauth/pkg/datastore" "github.com/flynn/hubauth/pkg/httpapi" "github.com/flynn/hubauth/pkg/idp" + "github.com/flynn/hubauth/pkg/idp/token" "github.com/flynn/hubauth/pkg/kmssign" "github.com/flynn/hubauth/pkg/rp/google" "go.opencensus.io/plugin/ochttp" @@ -77,10 +78,12 @@ func main() { os.Getenv("RP_GOOGLE_CLIENT_SECRET"), os.Getenv("BASE_URL")+"/rp/google", ), - kmsClient, []byte(secret("CODE_KEY_SECRET")), refreshKey, - idp.AudienceKeyNameFunc(os.Getenv("PROJECT_ID"), os.Getenv("KMS_LOCATION"), os.Getenv("KMS_KEYRING")), + token.NewSignedPBBuilder( + kmsClient, + kmssign.AudienceKeyNameFunc(os.Getenv("PROJECT_ID"), os.Getenv("KMS_LOCATION"), os.Getenv("KMS_KEYRING")), + ), ), CookieKey: []byte(secret("COOKIE_KEY_SECRET")), ProjectID: os.Getenv("PROJECT_ID"), diff --git a/pkg/idp/oauth.go b/pkg/idp/oauth.go index 3815fad..1dccd8f 100644 --- a/pkg/idp/oauth.go +++ b/pkg/idp/oauth.go @@ -2,16 +2,14 @@ package idp import ( "context" - "crypto" "encoding/base64" - "net/url" "strings" "time" "github.com/flynn/hubauth/pkg/clog" "github.com/flynn/hubauth/pkg/hmacpb" "github.com/flynn/hubauth/pkg/hubauth" - "github.com/flynn/hubauth/pkg/kmssign" + "github.com/flynn/hubauth/pkg/idp/token" "github.com/flynn/hubauth/pkg/pb" "github.com/flynn/hubauth/pkg/rp" "github.com/flynn/hubauth/pkg/signpb" @@ -22,22 +20,10 @@ import ( "golang.org/x/sync/errgroup" ) -type AudienceKeyNamer func(audience string) string - const oobRedirectURI = "urn:ietf:wg:oauth:2.0:oob" const codeExpiry = 30 * time.Second const accessTokenDuration = 5 * time.Minute -func AudienceKeyNameFunc(projectID, location, keyRing string) func(string) string { - return func(aud string) string { - u, err := url.Parse(aud) - if err != nil { - return "" - } - return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/1", projectID, location, keyRing, strings.Replace(u.Host, ".", "_", -1)) - } -} - type clock interface { Now() time.Time } @@ -61,17 +47,15 @@ type idpSteps interface { SignRefreshToken(ctx context.Context, signKey signpb.PrivateKey, t *signedRefreshTokenData) (string, error) RenewRefreshToken(ctx context.Context, clientID, oldTokenID string, oldTokenIssueTime, now time.Time) (*hubauth.RefreshToken, error) VerifyRefreshToken(ctx context.Context, rt *hubauth.RefreshToken, now time.Time) error - SignAccessToken(ctx context.Context, signKey signpb.PrivateKey, t *accessTokenData, now time.Time) (string, error) + SignAccessToken(ctx context.Context, audience string, t *token.AccessTokenData, now time.Time) (string, error) } type idpService struct { - db hubauth.DataStore - rp rp.AuthService - kms kmssign.KMSClient + db hubauth.DataStore + rp rp.AuthService - codeKey hmacpb.Key - refreshKey signpb.Key - audienceKey AudienceKeyNamer + codeKey hmacpb.Key + refreshKey signpb.Key steps idpSteps clock clock @@ -79,16 +63,15 @@ type idpService struct { var _ hubauth.IdPService = (*idpService)(nil) -func New(db hubauth.DataStore, rp rp.AuthService, kms kmssign.KMSClient, codeKey hmacpb.Key, refreshKey signpb.Key, audienceKey AudienceKeyNamer) hubauth.IdPService { +func New(db hubauth.DataStore, rp rp.AuthService, codeKey hmacpb.Key, refreshKey signpb.Key, tokenBuilder token.AccessTokenBuilder) hubauth.IdPService { return &idpService{ - db: db, - rp: rp, - kms: kms, - codeKey: codeKey, - refreshKey: refreshKey, - audienceKey: audienceKey, + db: db, + rp: rp, + codeKey: codeKey, + refreshKey: refreshKey, steps: &steps{ - db: db, + db: db, + builder: tokenBuilder, }, clock: clockImpl{}, } @@ -325,11 +308,11 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan if req.Audience == "" { return nil } - signKey := kmssign.NewPrivateKey(s.kms, s.audienceKey(req.Audience), crypto.SHA256) - accessToken, err = s.steps.SignAccessToken(ctx, signKey, &accessTokenData{ - clientID: req.ClientID, - userID: codeInfo.UserId, - userEmail: codeInfo.UserEmail, + + accessToken, err = s.steps.SignAccessToken(ctx, req.Audience, &token.AccessTokenData{ + ClientID: req.ClientID, + UserID: codeInfo.UserId, + UserEmail: codeInfo.UserEmail, }, now) return err }) @@ -399,11 +382,10 @@ func (s *idpService) RefreshToken(ctx context.Context, req *hubauth.RefreshToken if req.Audience == "" { return nil } - signKey := kmssign.NewPrivateKey(s.kms, s.audienceKey(req.Audience), crypto.SHA256) - accessToken, err = s.steps.SignAccessToken(ctx, signKey, &accessTokenData{ - clientID: req.ClientID, - userID: oldToken.UserID, - userEmail: oldToken.UserEmail, + accessToken, err = s.steps.SignAccessToken(ctx, req.Audience, &token.AccessTokenData{ + ClientID: req.ClientID, + UserID: oldToken.UserID, + UserEmail: oldToken.UserEmail, }, now) return err }) diff --git a/pkg/idp/oauth_test.go b/pkg/idp/oauth_test.go index d513374..3f2e62c 100644 --- a/pkg/idp/oauth_test.go +++ b/pkg/idp/oauth_test.go @@ -2,7 +2,6 @@ package idp import ( "context" - "crypto" "crypto/rand" "errors" "fmt" @@ -14,6 +13,7 @@ import ( "github.com/flynn/hubauth/pkg/datastore" "github.com/flynn/hubauth/pkg/hmacpb" "github.com/flynn/hubauth/pkg/hubauth" + "github.com/flynn/hubauth/pkg/idp/token" "github.com/flynn/hubauth/pkg/kmssign" "github.com/flynn/hubauth/pkg/kmssign/kmssim" "github.com/flynn/hubauth/pkg/pb" @@ -40,10 +40,6 @@ func (m *mockAuthService) Exchange(ctx context.Context, rr *rp.RedirectResult) ( return args.Get(0).(*rp.Token), args.Error(1) } -func audienceKeyNamer(s string) string { - return fmt.Sprintf("%s_named", s) -} - type mockSteps struct { mock.Mock } @@ -82,8 +78,8 @@ func (m *mockSteps) SignRefreshToken(ctx context.Context, signKey signpb.Private args := m.Called(ctx, signKey, t) return args.String(0), args.Error(1) } -func (m *mockSteps) SignAccessToken(ctx context.Context, signKey signpb.PrivateKey, t *accessTokenData, now time.Time) (string, error) { - args := m.Called(ctx, signKey, t, now) +func (m *mockSteps) SignAccessToken(ctx context.Context, audience string, t *token.AccessTokenData, now time.Time) (string, error) { + args := m.Called(ctx, audience, t, now) return args.String(0), args.Error(1) } func (m *mockSteps) RenewRefreshToken(ctx context.Context, clientID, oldTokenID string, oldTokenIssueTime, now time.Time) (*hubauth.RefreshToken, error) { @@ -124,7 +120,7 @@ func newTestIdPService(t *testing.T, kmsKeys ...string) *idpService { refreshKey, err := kmssign.NewKey(context.Background(), kms, refreshKeyName) require.NoError(t, err) - s := New(db, authService, kms, codeKey, refreshKey, audienceKeyNamer).(*idpService) + s := New(db, authService, codeKey, refreshKey, nil).(*idpService) s.steps = &mockSteps{} s.clock = &mockClock{} @@ -669,10 +665,10 @@ func TestExchangeCode(t *testing.T) { }).Return(verifiedCode, nil) idpService.steps.(*mockSteps).On("SaveRefreshToken", mock.Anything, b64CodeID, redirectURI, rtData).Return(client, nil) idpService.steps.(*mockSteps).On("SignRefreshToken", mock.Anything, idpService.refreshKey, signedRTData).Return(refreshToken, nil) - idpService.steps.(*mockSteps).On("SignAccessToken", mock.Anything, kmssign.NewPrivateKey(idpService.kms, audienceKeyNamer(audienceURL), crypto.SHA256), &accessTokenData{ - clientID: clientID, - userID: userID, - userEmail: userEmail, + idpService.steps.(*mockSteps).On("SignAccessToken", mock.Anything, audienceURL, &token.AccessTokenData{ + ClientID: clientID, + UserID: userID, + UserEmail: userEmail, }, now).Return(accessToken, nil) req := &hubauth.ExchangeCodeRequest{ @@ -898,11 +894,10 @@ func TestRefreshToken(t *testing.T) { }, ExpiryTime: expireTimeProto.AsTime(), }).Return(newRefreshTokenStr, nil) - signKey := kmssign.NewPrivateKey(idpService.kms, audienceKeyNamer(testCase.AudienceURL), crypto.SHA256) - idpService.steps.(*mockSteps).On("SignAccessToken", mock.Anything, signKey, &accessTokenData{ - clientID: b64ClientID, - userID: userID, - userEmail: userEmail, + idpService.steps.(*mockSteps).On("SignAccessToken", mock.Anything, testCase.AudienceURL, &token.AccessTokenData{ + ClientID: b64ClientID, + UserID: userID, + UserEmail: userEmail, }, now).Return(newAccessTokenStr, nil) oldTokenSigned, err := signpb.SignMarshal(context.Background(), idpService.refreshKey, &pb.RefreshToken{ @@ -1015,12 +1010,6 @@ func TestRefreshTokenStepErrors(t *testing.T) { } func prepareInvalidRefreshTokenTestCases(t *testing.T, idpService *idpService, wrongKeyName string) []*invalidRefreshTokenTestCase { - wrongKey, err := kmssign.NewKey(context.Background(), idpService.kms, wrongKeyName) - require.NoError(t, err) - - wrongKeyRefreshToken, err := signpb.SignMarshal(context.Background(), wrongKey, &pb.RefreshToken{}) - require.NoError(t, err) - now := time.Now() expiredTime, _ := ptypes.TimestampProto(now.Add(-1 * time.Second)) expiredRefreshToken, err := signpb.SignMarshal(context.Background(), idpService.refreshKey, &pb.RefreshToken{ @@ -1045,13 +1034,6 @@ func prepareInvalidRefreshTokenTestCases(t *testing.T, idpService *idpService, w Description: "invalid refresh_token", }, }, - { - RefreshToken: base64Encode(wrongKeyRefreshToken), - Err: &hubauth.OAuthError{ - Code: "invalid_grant", - Description: "invalid refresh_token", - }, - }, { RefreshToken: base64Encode(expiredRefreshToken), Err: &hubauth.OAuthError{ diff --git a/pkg/idp/steps.go b/pkg/idp/steps.go index 5a2c6be..96769c9 100644 --- a/pkg/idp/steps.go +++ b/pkg/idp/steps.go @@ -9,6 +9,7 @@ import ( "github.com/flynn/hubauth/pkg/clog" "github.com/flynn/hubauth/pkg/hmacpb" "github.com/flynn/hubauth/pkg/hubauth" + "github.com/flynn/hubauth/pkg/idp/token" "github.com/flynn/hubauth/pkg/pb" "github.com/flynn/hubauth/pkg/signpb" "github.com/golang/protobuf/ptypes" @@ -19,7 +20,8 @@ import ( ) type steps struct { - db hubauth.DataStore + db hubauth.DataStore + builder token.AccessTokenBuilder } var _ idpSteps = (*steps)(nil) @@ -349,37 +351,21 @@ func (s *steps) VerifyRefreshToken(ctx context.Context, rt *hubauth.RefreshToken return nil } -type accessTokenData struct { - clientID string - userID string - userEmail string -} - -func (s *steps) SignAccessToken(ctx context.Context, signKey signpb.PrivateKey, t *accessTokenData, now time.Time) (token string, err error) { +func (s *steps) SignAccessToken(ctx context.Context, audience string, t *token.AccessTokenData, now time.Time) (token string, err error) { ctx, span := trace.StartSpan(ctx, "idp.SignAccessToken") span.AddAttributes( - trace.StringAttribute("client_id", t.clientID), - trace.StringAttribute("user_id", t.userID), - trace.StringAttribute("user_email", t.userEmail), + trace.StringAttribute("client_id", t.ClientID), + trace.StringAttribute("user_id", t.UserID), + trace.StringAttribute("user_email", t.UserEmail), ) defer span.End() - exp, _ := ptypes.TimestampProto(now.Add(accessTokenDuration)) - iss, _ := ptypes.TimestampProto(now) - msg := &pb.AccessToken{ - ClientId: t.clientID, - UserId: t.userID, - UserEmail: t.userEmail, - IssueTime: iss, - ExpireTime: exp, - } - tokenBytes, err := signpb.SignMarshal(ctx, signKey, msg) + tokenBytes, err := s.builder.Build(ctx, audience, t, now, accessTokenDuration) if err != nil { - return "", fmt.Errorf("idp: error signing access token: %w", err) + return "", fmt.Errorf("idp: error building access token: %w", err) } - idBytes := sha256.Sum256(tokenBytes) - token = base64.URLEncoding.EncodeToString(tokenBytes) + idBytes := sha256.Sum256(tokenBytes) accessTokenID := base64Encode(idBytes[:]) span.AddAttributes(trace.StringAttribute("access_token_id", accessTokenID)) @@ -387,5 +373,6 @@ func (s *steps) SignAccessToken(ctx context.Context, signKey signpb.PrivateKey, zap.String("issued_access_token_id", accessTokenID), zap.Duration("issued_access_token_expires_in", accessTokenDuration), ) - return token, nil + + return base64.URLEncoding.EncodeToString(tokenBytes), nil } diff --git a/pkg/idp/steps_test.go b/pkg/idp/steps_test.go index 46865c6..971b0a9 100644 --- a/pkg/idp/steps_test.go +++ b/pkg/idp/steps_test.go @@ -12,21 +12,39 @@ import ( "github.com/flynn/hubauth/pkg/datastore" "github.com/flynn/hubauth/pkg/hmacpb" "github.com/flynn/hubauth/pkg/hubauth" + "github.com/flynn/hubauth/pkg/idp/token" "github.com/flynn/hubauth/pkg/kmssign" "github.com/flynn/hubauth/pkg/kmssign/kmssim" "github.com/flynn/hubauth/pkg/pb" "github.com/flynn/hubauth/pkg/signpb" "github.com/golang/protobuf/ptypes" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" ) +const ( + testAudienceName = "audienceXYZ" +) + +type mockAccessTokenBuilder struct { + mock.Mock +} + +var _ token.AccessTokenBuilder = (*mockAccessTokenBuilder)(nil) + +func (m *mockAccessTokenBuilder) Build(ctx context.Context, audience string, t *token.AccessTokenData, now time.Time, duration time.Duration) ([]byte, error) { + args := m.Called(ctx, audience, t, now, duration) + return args.Get(0).([]byte), args.Error(1) +} + func newTestSteps(t *testing.T) *steps { dsc, err := gdatastore.NewClient(context.Background(), "test") require.NoError(t, err) return &steps{ - db: datastore.New(dsc), + db: datastore.New(dsc), + builder: &mockAccessTokenBuilder{}, } } @@ -754,36 +772,23 @@ func TestVerifyRefreshTokenErrors(t *testing.T) { func TestSignAccessToken(t *testing.T) { s := newTestSteps(t) - signKeyName := "refreshKey" - kms := kmssim.NewClient([]string{signKeyName}) - signKey, err := kmssign.NewKey(context.Background(), kms, signKeyName) - require.NoError(t, err) - now := time.Now() - data := &accessTokenData{ - clientID: "clientID", - userID: "userID", - userEmail: "userEmail", + data := &token.AccessTokenData{ + ClientID: "clientID", + UserID: "userID", + UserEmail: "userEmail", } - accessToken, err := s.SignAccessToken(context.Background(), signKey, data, now) + expectedAccessToken := []byte("expected-access-token") + + s.builder.(*mockAccessTokenBuilder).On("Build", mock.Anything, testAudienceName, data, now, accessTokenDuration).Return(expectedAccessToken, nil) + + accessToken, err := s.SignAccessToken(context.Background(), testAudienceName, data, now) require.NoError(t, err) require.NotEmpty(t, accessToken) - got := new(pb.AccessToken) - accessTokenBytes, err := base64Decode(accessToken) require.NoError(t, err) - - require.NoError(t, signpb.VerifyUnmarshal(signKey, accessTokenBytes, got)) - require.Equal(t, data.clientID, got.ClientId) - require.Equal(t, data.userID, got.UserId) - require.Equal(t, data.userEmail, got.UserEmail) - - nowPb, _ := ptypes.TimestampProto(now) - require.Equal(t, nowPb, got.IssueTime) - - expirePb, _ := ptypes.TimestampProto(now.Add(accessTokenDuration)) - require.Equal(t, expirePb, got.ExpireTime) + require.Equal(t, expectedAccessToken, accessTokenBytes) } diff --git a/pkg/idp/token/builder.go b/pkg/idp/token/builder.go new file mode 100644 index 0000000..5a60575 --- /dev/null +++ b/pkg/idp/token/builder.go @@ -0,0 +1,57 @@ +package token + +import ( + "context" + "crypto" + "fmt" + "time" + + "github.com/flynn/hubauth/pkg/kmssign" + "github.com/flynn/hubauth/pkg/pb" + "github.com/flynn/hubauth/pkg/signpb" + "github.com/golang/protobuf/ptypes" +) + +type AccessTokenData struct { + ClientID string + UserID string + UserEmail string +} + +type AccessTokenBuilder interface { + Build(ctx context.Context, audience string, t *AccessTokenData, now time.Time, duration time.Duration) ([]byte, error) +} + +type signedPbBuilder struct { + kms kmssign.KMSClient + audienceKey kmssign.AudienceKeyNamer +} + +var _ AccessTokenBuilder = (*signedPbBuilder)(nil) + +func NewSignedPBBuilder(kms kmssign.KMSClient, audienceKey kmssign.AudienceKeyNamer) AccessTokenBuilder { + return &signedPbBuilder{ + kms: kms, + audienceKey: audienceKey, + } +} + +func (b *signedPbBuilder) Build(ctx context.Context, audience string, t *AccessTokenData, now time.Time, duration time.Duration) ([]byte, error) { + signKey := kmssign.NewPrivateKey(b.kms, b.audienceKey(audience), crypto.SHA256) + + exp, _ := ptypes.TimestampProto(now.Add(duration)) + iss, _ := ptypes.TimestampProto(now) + msg := &pb.AccessToken{ + ClientId: t.ClientID, + UserId: t.UserID, + UserEmail: t.UserEmail, + IssueTime: iss, + ExpireTime: exp, + } + tokenBytes, err := signpb.SignMarshal(ctx, signKey, msg) + if err != nil { + return nil, fmt.Errorf("token: error signing access token: %w", err) + } + + return tokenBytes, nil +} diff --git a/pkg/idp/token/builder_test.go b/pkg/idp/token/builder_test.go new file mode 100644 index 0000000..9785c77 --- /dev/null +++ b/pkg/idp/token/builder_test.go @@ -0,0 +1,57 @@ +package token + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/flynn/hubauth/pkg/kmssign" + "github.com/flynn/hubauth/pkg/kmssign/kmssim" + "github.com/flynn/hubauth/pkg/pb" + "github.com/flynn/hubauth/pkg/signpb" + "github.com/golang/protobuf/ptypes" + "github.com/stretchr/testify/require" +) + +func audienceKeyNamer(s string) string { + return fmt.Sprintf("%s_named", s) +} + +func TestSignedPBBuilder(t *testing.T) { + audienceName := "audience_url" + audienceKeyName := audienceKeyNamer(audienceName) + kms := kmssim.NewClient([]string{audienceKeyName}) + + builder := NewSignedPBBuilder(kms, audienceKeyNamer) + + signKey, err := kmssign.NewKey(context.Background(), kms, audienceKeyName) + require.NoError(t, err) + + now := time.Now() + ctx := context.Background() + + data := &AccessTokenData{ + ClientID: "clientID", + UserEmail: "userEmail", + UserID: "userID", + } + + accessTokenDuration := 5 * time.Minute + + accessTokenBytes, err := builder.Build(ctx, audienceName, data, now, accessTokenDuration) + require.NoError(t, err) + + got := new(pb.AccessToken) + require.NoError(t, signpb.VerifyUnmarshal(signKey, accessTokenBytes, got)) + + require.Equal(t, data.ClientID, got.ClientId) + require.Equal(t, data.UserID, got.UserId) + require.Equal(t, data.UserEmail, got.UserEmail) + + nowPb, _ := ptypes.TimestampProto(now) + require.Equal(t, nowPb, got.IssueTime) + + expirePb, _ := ptypes.TimestampProto(now.Add(accessTokenDuration)) + require.Equal(t, expirePb, got.ExpireTime) +} diff --git a/pkg/kmssign/kms.go b/pkg/kmssign/kms.go index 8671e3e..766ffe3 100644 --- a/pkg/kmssign/kms.go +++ b/pkg/kmssign/kms.go @@ -8,6 +8,8 @@ import ( "encoding/pem" "io" "math/big" + "net/url" + "strings" gax "github.com/googleapis/gax-go/v2" "golang.org/x/crypto/cryptobyte" @@ -16,6 +18,18 @@ import ( kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1" ) +type AudienceKeyNamer func(audience string) string + +func AudienceKeyNameFunc(projectID, location, keyRing string) func(string) string { + return func(aud string) string { + u, err := url.Parse(aud) + if err != nil { + return "" + } + return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/1", projectID, location, keyRing, strings.Replace(u.Host, ".", "_", -1)) + } +} + type KMSClient interface { AsymmetricSign(ctx context.Context, req *kmspb.AsymmetricSignRequest, opts ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) GetPublicKey(ctx context.Context, req *kmspb.GetPublicKeyRequest, opts ...gax.CallOption) (*kmspb.PublicKey, error)