Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func main() {
os.Exit(1)
}

userManager := user.NewManager(accountClient, secretClient)
userManager := user.NewManager(accountManager, secretClient)
userReconciler := controller.NewUserReconciler(
mgr.GetClient(),
mgr.GetScheme(),
Expand Down
41 changes: 41 additions & 0 deletions internal/account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/WirelessCar/nauth/internal/domain"
"github.com/WirelessCar/nauth/internal/k8s"
"github.com/WirelessCar/nauth/internal/ports"
"github.com/WirelessCar/nauth/internal/user"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nkeys"
)
Expand Down Expand Up @@ -349,6 +350,43 @@ func (a *Manager) Delete(ctx context.Context, state *v1alpha1.Account) error {
return nil
}

func (a *Manager) SignUserJWT(ctx context.Context, accountRef domain.NamespacedName, claims *jwt.UserClaims) (*user.SignedUserJWT, error) {
if err := accountRef.Validate(); err != nil {
return nil, fmt.Errorf("invalid account reference %q: %w", accountRef, err)
}
account, err := a.accountReader.Get(ctx, accountRef)
if err != nil {
return nil, fmt.Errorf("failed to get account for user JWT signing: %w", err)
}
accountID := account.GetLabels()[k8s.LabelAccountID]
if accountID == "" {
return nil, fmt.Errorf("account ID is missing for account %s during user JWT signing", accountRef)
}
if claims.IssuerAccount != "" && claims.IssuerAccount != accountID {
return nil, fmt.Errorf("claims issuer account ID %s does not match %s bound to account %q during user JWT signing", claims.IssuerAccount, accountID, accountRef)
}
if claims.IssuerAccount == "" {
claims.IssuerAccount = accountID
}
accountSecrets, err := a.secretManager.GetSecrets(ctx, accountRef, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get account secrets for user JWT signing: %w", err)
}
signPubKey, err := accountSecrets.Sign.PublicKey()
if err != nil {
return nil, fmt.Errorf("failed to get account signing public key for user JWT signing: %w", err)
}
userJWT, err := claims.Encode(accountSecrets.Sign)
if err != nil {
return nil, fmt.Errorf("failed to sign user JWT using %s for account %s (%q): %w", signPubKey, accountID, accountRef, err)
}
return &user.SignedUserJWT{
UserJWT: userJWT,
AccountID: accountID,
SignedBy: signPubKey,
}, nil
}

func (a *Manager) resolveClusterTarget(ctx context.Context, account *v1alpha1.Account) (*clusterTarget, error) {
natsClusterRef := account.Spec.NatsClusterRef
if natsClusterRef != nil && natsClusterRef.Namespace == "" {
Expand All @@ -365,3 +403,6 @@ func getDisplayName(account *v1alpha1.Account) string {
}
return fmt.Sprintf("%s/%s", account.GetNamespace(), account.GetName())
}

var _ controller.AccountManager = (*Manager)(nil)
var _ user.JWTSigner = (*Manager)(nil)
98 changes: 98 additions & 0 deletions internal/account/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,104 @@ func (t *ManagerTestSuite) Test_Delete_ShouldSucceed() {
t.Equal([]interface{}{accountID}, deleteClaims.Data["accounts"])
}

func (t *ManagerTestSuite) Test_SignUserJWT_ShouldSucceed() {
// Given
accountRef := domain.NewNamespacedName("account-namespace", "account-name")
accountRootKey, _ := nkeys.CreateAccount()
accountID, _ := accountRootKey.PublicKey()
accountSignKey, _ := nkeys.CreateAccount()
accountSignKeyPublic, _ := accountSignKey.PublicKey()

account := &v1alpha1.Account{
ObjectMeta: metav1.ObjectMeta{
Namespace: "account-namespace",
Name: "account-name",
Labels: map[string]string{
k8s.LabelAccountID: accountID,
},
},
}
t.accountReaderMock.mockGet(t.ctx, accountRef, account)
t.secretManagerMock.mockGetSecrets(t.ctx, accountRef, accountID, &Secrets{
Root: accountRootKey,
Sign: accountSignKey,
})

userKey, _ := nkeys.CreateUser()
userKeyPublic, _ := userKey.PublicKey()
claims := jwt.NewUserClaims(userKeyPublic)

// When
result, err := t.unitUnderTest.SignUserJWT(t.ctx, accountRef, claims)

// Then
t.NoError(err)
t.NotNil(result)
t.Equal(accountID, result.AccountID)
t.Equal(accountSignKeyPublic, result.SignedBy)

// Verify the JWT is signed with the account's signing key
parsedClaims, err := jwt.DecodeUserClaims(result.UserJWT)
t.NoError(err, "failed to decode signed user JWT")
t.Equal(accountID, parsedClaims.IssuerAccount)
t.Equal(accountSignKeyPublic, parsedClaims.Issuer)
}

func (t *ManagerTestSuite) Test_SignUserJWT_ShouldFailWhenAccountIsNotReady() {
// Given
accountRef := domain.NewNamespacedName("account-namespace", "account-name")

account := &v1alpha1.Account{
ObjectMeta: metav1.ObjectMeta{
Namespace: "account-namespace",
Name: "account-name",
},
}
t.accountReaderMock.mockGet(t.ctx, accountRef, account)

userKey, _ := nkeys.CreateUser()
userKeyPublic, _ := userKey.PublicKey()
claims := jwt.NewUserClaims(userKeyPublic)

// When
result, err := t.unitUnderTest.SignUserJWT(t.ctx, accountRef, claims)

// Then
t.Nil(result)
t.ErrorContains(err, "account ID is missing for account account-namespace/account-name during user JWT signing")
}

func (t *ManagerTestSuite) Test_SignUserJWT_ShouldFailWhenClaimsIssuerAccountDoesNotMatchFoundAccountID() {
// Given
accountRef := domain.NewNamespacedName("account-namespace", "account-name")
accountRootKey, _ := nkeys.CreateAccount()
accountID, _ := accountRootKey.PublicKey()

account := &v1alpha1.Account{
ObjectMeta: metav1.ObjectMeta{
Namespace: "account-namespace",
Name: "account-name",
Labels: map[string]string{
k8s.LabelAccountID: accountID,
},
},
}
t.accountReaderMock.mockGet(t.ctx, accountRef, account)

userKey, _ := nkeys.CreateUser()
userKeyPublic, _ := userKey.PublicKey()
claims := jwt.NewUserClaims(userKeyPublic)
claims.IssuerAccount = "some-other-account-id"

// When
result, err := t.unitUnderTest.SignUserJWT(t.ctx, accountRef, claims)

// Then
t.Nil(result)
t.ErrorContains(err, "claims issuer account ID some-other-account-id does not match "+
accountID+" bound to account \"account-namespace/account-name\" during user JWT signing")
}

/* ****************************************************
* Helpers
*****************************************************/
Expand Down
11 changes: 5 additions & 6 deletions internal/account/claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,14 @@ func TestClaims(t *testing.T) {

ctx := context.Background()
accountReaderMock := NewAccountReaderMock()
getAccountCall := accountReaderMock.On("Get", mock.Anything, mock.Anything)
getAccountCall.RunFn = func(args mock.Arguments) {
accountID := fakeAccountId(args.Get(1).(domain.NamespacedName))
getAccountCall := accountReaderMock.mockGetCallback(mock.Anything, mock.Anything, func(accountRef domain.NamespacedName) (*v1alpha1.Account, error) {
accountID := fakeAccountId(accountRef)
account := &v1alpha1.Account{}
account.Labels = map[string]string{
k8s.LabelAccountID: accountID,
}
getAccountCall.Return(*account, nil)
}
return account, nil
})

// Build NATS JWT AccountClaims from AccountSpec
builder := newClaimsBuilder(ctx, testClaimsDisplayName, *spec, testClaimsAccountPubKey, accountReaderMock)
Expand Down Expand Up @@ -95,7 +94,7 @@ func TestClaims(t *testing.T) {
account.Labels = map[string]string{
k8s.LabelAccountID: testClaimsFakeAccountID,
}
getAccountCall.Return(*account, nil)
getAccountCall.Return(account, nil)
}

// Verify that the resulting NAuth AccountClaim generates the same NATS JWT when encoded
Expand Down
15 changes: 13 additions & 2 deletions internal/account/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,19 @@ func NewAccountReaderMock() *AccountReaderMock {

func (a *AccountReaderMock) Get(ctx context.Context, accountRef domain.NamespacedName) (account *v1alpha1.Account, err error) {
args := a.Called(ctx, accountRef)
anAccount := args.Get(0).(v1alpha1.Account)
return &anAccount, args.Error(1)
return args.Get(0).(*v1alpha1.Account), args.Error(1)
}

func (a *AccountReaderMock) mockGet(ctx context.Context, accountRef domain.NamespacedName, result *v1alpha1.Account) {
a.On("Get", ctx, accountRef).Return(result, nil)
}

func (a *AccountReaderMock) mockGetCallback(ctx interface{}, accountRef interface{}, generator func(accountRef domain.NamespacedName) (*v1alpha1.Account, error)) *mock.Call {
call := a.On("Get", ctx, accountRef)
call.RunFn = func(args mock.Arguments) {
call.Return(generator(args.Get(1).(domain.NamespacedName)))
}
return call
}

var _ ports.AccountReader = &AccountReaderMock{}
Expand Down
66 changes: 28 additions & 38 deletions internal/user/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,41 @@ package user
import (
"context"

"github.com/WirelessCar/nauth/api/v1alpha1"
"github.com/WirelessCar/nauth/internal/domain"
"github.com/WirelessCar/nauth/internal/ports"
"github.com/nats-io/jwt/v2"
"github.com/stretchr/testify/mock"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

/* ****************************************************
* JWTSigner Mock
*****************************************************/

func NewUserJWTSignerMock() *UserJWTSignerMock {
return &UserJWTSignerMock{}
}

type UserJWTSignerMock struct {
mock.Mock
}

func (m *UserJWTSignerMock) SignUserJWT(ctx context.Context, accountRef domain.NamespacedName, claims *jwt.UserClaims) (*SignedUserJWT, error) {
args := m.Called(ctx, accountRef, claims)
return args.Get(0).(*SignedUserJWT), args.Error(1)
}

func (m *UserJWTSignerMock) mockSignUserJWT(ctx context.Context, accountRef domain.NamespacedName, callback func(claims *jwt.UserClaims) *SignedUserJWT) {
call := m.On("SignUserJWT", ctx, accountRef, mock.Anything)
call.RunFn = func(args mock.Arguments) {
claims := args.Get(2).(*jwt.UserClaims)
call.Return(callback(claims), nil)
}
}

var _ JWTSigner = &UserJWTSignerMock{}

/* ****************************************************
* ports.SecretClient Mock
*****************************************************/
Expand Down Expand Up @@ -39,20 +66,12 @@ func (s *SecretClientMock) Get(ctx context.Context, secretRef domain.NamespacedN
return args.Get(0).(map[string]string), args.Error(1)
}

func (s *SecretClientMock) mockGet(ctx context.Context, namespacedName domain.NamespacedName, result map[string]string) {
s.On("Get", ctx, namespacedName).Return(result, nil)
}

// GetByLabels implements ports.SecretStorer.
func (s *SecretClientMock) GetByLabels(ctx context.Context, namespace domain.Namespace, labels map[string]string) (*corev1.SecretList, error) {
args := s.Called(ctx, namespace, labels)
return args.Get(0).(*corev1.SecretList), args.Error(1)
}

func (s *SecretClientMock) mockGetByLabels(ctx context.Context, namespace domain.Namespace, labels interface{}, list *corev1.SecretList) {
s.On("GetByLabels", ctx, namespace, labels).Return(list, nil)
}

// DeleteSecret implements ports.SecretStorer.
func (s *SecretClientMock) Delete(ctx context.Context, secretRef domain.NamespacedName) error {
args := s.Called(ctx, secretRef)
Expand All @@ -71,34 +90,5 @@ func (s *SecretClientMock) Label(ctx context.Context, secretRef domain.Namespace
return args.Error(0)
}

func (s *SecretClientMock) mockLabel(namespacedName domain.NamespacedName, labels map[string]string) {
s.On("Label", mock.Anything, namespacedName, labels).Return(nil)
}

// Compile-time assertion that implementation satisfies the ports interface
var _ ports.SecretClient = &SecretClientMock{}

/* ****************************************************
* ports.AccountReader Mock
*****************************************************/

type AccountReaderMock struct {
mock.Mock
}

func NewAccountReaderMock() *AccountReaderMock {
return &AccountReaderMock{}
}

func (a *AccountReaderMock) Get(ctx context.Context, accountRef domain.NamespacedName) (account *v1alpha1.Account, err error) {
args := a.Called(ctx, accountRef)
anAccount := args.Get(0).(v1alpha1.Account)
return &anAccount, args.Error(1)
}

func (a *AccountReaderMock) mockGet(ctx context.Context, accountRef domain.NamespacedName, result v1alpha1.Account) *mock.Call {
return a.On("Get", ctx, accountRef).Return(result, nil)
}

// Compile-time assertion that implementation satisfies the ports interface
var _ ports.AccountReader = &AccountReaderMock{}
Loading
Loading