Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cmd/hubauth-ext/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/flynn/hubauth/pkg/datastore"
"github.com/flynn/hubauth/pkg/httpapi"
"github.com/flynn/hubauth/pkg/idp"
"github.com/flynn/hubauth/pkg/idp/token"
"github.com/flynn/hubauth/pkg/kmssign"
"github.com/flynn/hubauth/pkg/rp/google"
"go.opencensus.io/plugin/ochttp"
Expand Down Expand Up @@ -77,10 +78,12 @@ func main() {
os.Getenv("RP_GOOGLE_CLIENT_SECRET"),
os.Getenv("BASE_URL")+"/rp/google",
),
kmsClient,
[]byte(secret("CODE_KEY_SECRET")),
refreshKey,
idp.AudienceKeyNameFunc(os.Getenv("PROJECT_ID"), os.Getenv("KMS_LOCATION"), os.Getenv("KMS_KEYRING")),
token.NewSignedPBBuilder(
kmsClient,
kmssign.AudienceKeyNameFunc(os.Getenv("PROJECT_ID"), os.Getenv("KMS_LOCATION"), os.Getenv("KMS_KEYRING")),
),
),
CookieKey: []byte(secret("COOKIE_KEY_SECRET")),
ProjectID: os.Getenv("PROJECT_ID"),
Expand Down
62 changes: 22 additions & 40 deletions pkg/idp/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ package idp

import (
"context"
"crypto"
"encoding/base64"
"net/url"
"strings"
"time"

"github.com/flynn/hubauth/pkg/clog"
"github.com/flynn/hubauth/pkg/hmacpb"
"github.com/flynn/hubauth/pkg/hubauth"
"github.com/flynn/hubauth/pkg/kmssign"
"github.com/flynn/hubauth/pkg/idp/token"
"github.com/flynn/hubauth/pkg/pb"
"github.com/flynn/hubauth/pkg/rp"
"github.com/flynn/hubauth/pkg/signpb"
Expand All @@ -22,22 +20,10 @@ import (
"golang.org/x/sync/errgroup"
)

type AudienceKeyNamer func(audience string) string

const oobRedirectURI = "urn:ietf:wg:oauth:2.0:oob"
const codeExpiry = 30 * time.Second
const accessTokenDuration = 5 * time.Minute

func AudienceKeyNameFunc(projectID, location, keyRing string) func(string) string {
return func(aud string) string {
u, err := url.Parse(aud)
if err != nil {
return ""
}
return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/1", projectID, location, keyRing, strings.Replace(u.Host, ".", "_", -1))
}
}

type clock interface {
Now() time.Time
}
Expand All @@ -61,34 +47,31 @@ type idpSteps interface {
SignRefreshToken(ctx context.Context, signKey signpb.PrivateKey, t *signedRefreshTokenData) (string, error)
RenewRefreshToken(ctx context.Context, clientID, oldTokenID string, oldTokenIssueTime, now time.Time) (*hubauth.RefreshToken, error)
VerifyRefreshToken(ctx context.Context, rt *hubauth.RefreshToken, now time.Time) error
SignAccessToken(ctx context.Context, signKey signpb.PrivateKey, t *accessTokenData, now time.Time) (string, error)
SignAccessToken(ctx context.Context, audience string, t *token.AccessTokenData, now time.Time) (string, error)
}

type idpService struct {
db hubauth.DataStore
rp rp.AuthService
kms kmssign.KMSClient
db hubauth.DataStore
rp rp.AuthService

codeKey hmacpb.Key
refreshKey signpb.Key
audienceKey AudienceKeyNamer
codeKey hmacpb.Key
refreshKey signpb.Key

steps idpSteps
clock clock
}

var _ hubauth.IdPService = (*idpService)(nil)

func New(db hubauth.DataStore, rp rp.AuthService, kms kmssign.KMSClient, codeKey hmacpb.Key, refreshKey signpb.Key, audienceKey AudienceKeyNamer) hubauth.IdPService {
func New(db hubauth.DataStore, rp rp.AuthService, codeKey hmacpb.Key, refreshKey signpb.Key, tokenBuilder token.AccessTokenBuilder) hubauth.IdPService {
return &idpService{
db: db,
rp: rp,
kms: kms,
codeKey: codeKey,
refreshKey: refreshKey,
audienceKey: audienceKey,
db: db,
rp: rp,
codeKey: codeKey,
refreshKey: refreshKey,
steps: &steps{
db: db,
db: db,
builder: tokenBuilder,
},
clock: clockImpl{},
}
Expand Down Expand Up @@ -325,11 +308,11 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan
if req.Audience == "" {
return nil
}
signKey := kmssign.NewPrivateKey(s.kms, s.audienceKey(req.Audience), crypto.SHA256)
accessToken, err = s.steps.SignAccessToken(ctx, signKey, &accessTokenData{
clientID: req.ClientID,
userID: codeInfo.UserId,
userEmail: codeInfo.UserEmail,

accessToken, err = s.steps.SignAccessToken(ctx, req.Audience, &token.AccessTokenData{
ClientID: req.ClientID,
UserID: codeInfo.UserId,
UserEmail: codeInfo.UserEmail,
}, now)
return err
})
Expand Down Expand Up @@ -399,11 +382,10 @@ func (s *idpService) RefreshToken(ctx context.Context, req *hubauth.RefreshToken
if req.Audience == "" {
return nil
}
signKey := kmssign.NewPrivateKey(s.kms, s.audienceKey(req.Audience), crypto.SHA256)
accessToken, err = s.steps.SignAccessToken(ctx, signKey, &accessTokenData{
clientID: req.ClientID,
userID: oldToken.UserID,
userEmail: oldToken.UserEmail,
accessToken, err = s.steps.SignAccessToken(ctx, req.Audience, &token.AccessTokenData{
ClientID: req.ClientID,
UserID: oldToken.UserID,
UserEmail: oldToken.UserEmail,
}, now)
return err
})
Expand Down
42 changes: 12 additions & 30 deletions pkg/idp/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package idp

import (
"context"
"crypto"
"crypto/rand"
"errors"
"fmt"
Expand All @@ -14,6 +13,7 @@ import (
"github.com/flynn/hubauth/pkg/datastore"
"github.com/flynn/hubauth/pkg/hmacpb"
"github.com/flynn/hubauth/pkg/hubauth"
"github.com/flynn/hubauth/pkg/idp/token"
"github.com/flynn/hubauth/pkg/kmssign"
"github.com/flynn/hubauth/pkg/kmssign/kmssim"
"github.com/flynn/hubauth/pkg/pb"
Expand All @@ -40,10 +40,6 @@ func (m *mockAuthService) Exchange(ctx context.Context, rr *rp.RedirectResult) (
return args.Get(0).(*rp.Token), args.Error(1)
}

func audienceKeyNamer(s string) string {
return fmt.Sprintf("%s_named", s)
}

type mockSteps struct {
mock.Mock
}
Expand Down Expand Up @@ -82,8 +78,8 @@ func (m *mockSteps) SignRefreshToken(ctx context.Context, signKey signpb.Private
args := m.Called(ctx, signKey, t)
return args.String(0), args.Error(1)
}
func (m *mockSteps) SignAccessToken(ctx context.Context, signKey signpb.PrivateKey, t *accessTokenData, now time.Time) (string, error) {
args := m.Called(ctx, signKey, t, now)
func (m *mockSteps) SignAccessToken(ctx context.Context, audience string, t *token.AccessTokenData, now time.Time) (string, error) {
args := m.Called(ctx, audience, t, now)
return args.String(0), args.Error(1)
}
func (m *mockSteps) RenewRefreshToken(ctx context.Context, clientID, oldTokenID string, oldTokenIssueTime, now time.Time) (*hubauth.RefreshToken, error) {
Expand Down Expand Up @@ -124,7 +120,7 @@ func newTestIdPService(t *testing.T, kmsKeys ...string) *idpService {
refreshKey, err := kmssign.NewKey(context.Background(), kms, refreshKeyName)
require.NoError(t, err)

s := New(db, authService, kms, codeKey, refreshKey, audienceKeyNamer).(*idpService)
s := New(db, authService, codeKey, refreshKey, nil).(*idpService)
s.steps = &mockSteps{}
s.clock = &mockClock{}

Expand Down Expand Up @@ -669,10 +665,10 @@ func TestExchangeCode(t *testing.T) {
}).Return(verifiedCode, nil)
idpService.steps.(*mockSteps).On("SaveRefreshToken", mock.Anything, b64CodeID, redirectURI, rtData).Return(client, nil)
idpService.steps.(*mockSteps).On("SignRefreshToken", mock.Anything, idpService.refreshKey, signedRTData).Return(refreshToken, nil)
idpService.steps.(*mockSteps).On("SignAccessToken", mock.Anything, kmssign.NewPrivateKey(idpService.kms, audienceKeyNamer(audienceURL), crypto.SHA256), &accessTokenData{
clientID: clientID,
userID: userID,
userEmail: userEmail,
idpService.steps.(*mockSteps).On("SignAccessToken", mock.Anything, audienceURL, &token.AccessTokenData{
ClientID: clientID,
UserID: userID,
UserEmail: userEmail,
}, now).Return(accessToken, nil)

req := &hubauth.ExchangeCodeRequest{
Expand Down Expand Up @@ -898,11 +894,10 @@ func TestRefreshToken(t *testing.T) {
},
ExpiryTime: expireTimeProto.AsTime(),
}).Return(newRefreshTokenStr, nil)
signKey := kmssign.NewPrivateKey(idpService.kms, audienceKeyNamer(testCase.AudienceURL), crypto.SHA256)
idpService.steps.(*mockSteps).On("SignAccessToken", mock.Anything, signKey, &accessTokenData{
clientID: b64ClientID,
userID: userID,
userEmail: userEmail,
idpService.steps.(*mockSteps).On("SignAccessToken", mock.Anything, testCase.AudienceURL, &token.AccessTokenData{
ClientID: b64ClientID,
UserID: userID,
UserEmail: userEmail,
}, now).Return(newAccessTokenStr, nil)

oldTokenSigned, err := signpb.SignMarshal(context.Background(), idpService.refreshKey, &pb.RefreshToken{
Expand Down Expand Up @@ -1015,12 +1010,6 @@ func TestRefreshTokenStepErrors(t *testing.T) {
}

func prepareInvalidRefreshTokenTestCases(t *testing.T, idpService *idpService, wrongKeyName string) []*invalidRefreshTokenTestCase {
wrongKey, err := kmssign.NewKey(context.Background(), idpService.kms, wrongKeyName)
require.NoError(t, err)

wrongKeyRefreshToken, err := signpb.SignMarshal(context.Background(), wrongKey, &pb.RefreshToken{})
require.NoError(t, err)

now := time.Now()
expiredTime, _ := ptypes.TimestampProto(now.Add(-1 * time.Second))
expiredRefreshToken, err := signpb.SignMarshal(context.Background(), idpService.refreshKey, &pb.RefreshToken{
Expand All @@ -1045,13 +1034,6 @@ func prepareInvalidRefreshTokenTestCases(t *testing.T, idpService *idpService, w
Description: "invalid refresh_token",
},
},
{
RefreshToken: base64Encode(wrongKeyRefreshToken),
Err: &hubauth.OAuthError{
Code: "invalid_grant",
Description: "invalid refresh_token",
},
},
{
RefreshToken: base64Encode(expiredRefreshToken),
Err: &hubauth.OAuthError{
Expand Down
37 changes: 12 additions & 25 deletions pkg/idp/steps.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/flynn/hubauth/pkg/clog"
"github.com/flynn/hubauth/pkg/hmacpb"
"github.com/flynn/hubauth/pkg/hubauth"
"github.com/flynn/hubauth/pkg/idp/token"
"github.com/flynn/hubauth/pkg/pb"
"github.com/flynn/hubauth/pkg/signpb"
"github.com/golang/protobuf/ptypes"
Expand All @@ -19,7 +20,8 @@ import (
)

type steps struct {
db hubauth.DataStore
db hubauth.DataStore
builder token.AccessTokenBuilder
}

var _ idpSteps = (*steps)(nil)
Expand Down Expand Up @@ -349,43 +351,28 @@ func (s *steps) VerifyRefreshToken(ctx context.Context, rt *hubauth.RefreshToken
return nil
}

type accessTokenData struct {
clientID string
userID string
userEmail string
}

func (s *steps) SignAccessToken(ctx context.Context, signKey signpb.PrivateKey, t *accessTokenData, now time.Time) (token string, err error) {
func (s *steps) SignAccessToken(ctx context.Context, audience string, t *token.AccessTokenData, now time.Time) (token string, err error) {
ctx, span := trace.StartSpan(ctx, "idp.SignAccessToken")
span.AddAttributes(
trace.StringAttribute("client_id", t.clientID),
trace.StringAttribute("user_id", t.userID),
trace.StringAttribute("user_email", t.userEmail),
trace.StringAttribute("client_id", t.ClientID),
trace.StringAttribute("user_id", t.UserID),
trace.StringAttribute("user_email", t.UserEmail),
)
defer span.End()

exp, _ := ptypes.TimestampProto(now.Add(accessTokenDuration))
iss, _ := ptypes.TimestampProto(now)
msg := &pb.AccessToken{
ClientId: t.clientID,
UserId: t.userID,
UserEmail: t.userEmail,
IssueTime: iss,
ExpireTime: exp,
}
tokenBytes, err := signpb.SignMarshal(ctx, signKey, msg)
tokenBytes, err := s.builder.Build(ctx, audience, t, now, accessTokenDuration)
if err != nil {
return "", fmt.Errorf("idp: error signing access token: %w", err)
return "", fmt.Errorf("idp: error building access token: %w", err)
}
idBytes := sha256.Sum256(tokenBytes)

token = base64.URLEncoding.EncodeToString(tokenBytes)
idBytes := sha256.Sum256(tokenBytes)
accessTokenID := base64Encode(idBytes[:])
span.AddAttributes(trace.StringAttribute("access_token_id", accessTokenID))

clog.Set(ctx,
zap.String("issued_access_token_id", accessTokenID),
zap.Duration("issued_access_token_expires_in", accessTokenDuration),
)
return token, nil

return base64.URLEncoding.EncodeToString(tokenBytes), nil
}
Loading