From a728521ff92d7758cf53b2a3d7536f7d1d088f87 Mon Sep 17 00:00:00 2001 From: Thobias Karlsson Date: Mon, 16 Mar 2026 09:59:47 +0100 Subject: [PATCH] feat: account manager to sign User JWT Today User Manager has its own (duplicated) code to look up the Account Signing Key. Since this has been strictened up and the Account Manager should be the sole manager dealing with Account related secrets, we should let the Account Manager be the entity (approving and) signing the User JWTs created during User reconciliation. Signed-off-by: Thobias Karlsson --- cmd/main.go | 2 +- internal/account/account.go | 41 ++++++ internal/account/account_test.go | 98 +++++++++++++++ internal/account/claims_test.go | 11 +- internal/account/mocks_test.go | 15 ++- internal/user/mocks_test.go | 66 +++++----- internal/user/user.go | 178 ++++---------------------- internal/user/user_test.go | 210 +++++++++---------------------- 8 files changed, 266 insertions(+), 355 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index 932d9aa..d15f8f6 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -301,7 +301,7 @@ func main() { os.Exit(1) } - userManager := user.NewManager(accountClient, secretClient) + userManager := user.NewManager(accountManager, secretClient) userReconciler := controller.NewUserReconciler( mgr.GetClient(), mgr.GetScheme(), diff --git a/internal/account/account.go b/internal/account/account.go index e079288..d2f68ae 100644 --- a/internal/account/account.go +++ b/internal/account/account.go @@ -10,6 +10,7 @@ import ( "github.com/WirelessCar/nauth/internal/domain" "github.com/WirelessCar/nauth/internal/k8s" "github.com/WirelessCar/nauth/internal/ports" + "github.com/WirelessCar/nauth/internal/user" "github.com/nats-io/jwt/v2" "github.com/nats-io/nkeys" ) @@ -349,6 +350,43 @@ func (a *Manager) Delete(ctx context.Context, state *v1alpha1.Account) error { return nil } +func (a *Manager) SignUserJWT(ctx context.Context, accountRef domain.NamespacedName, claims *jwt.UserClaims) (*user.SignedUserJWT, error) { + if err := accountRef.Validate(); err != nil { + return nil, fmt.Errorf("invalid account reference %q: %w", accountRef, err) + } + account, err := a.accountReader.Get(ctx, accountRef) + if err != nil { + return nil, fmt.Errorf("failed to get account for user JWT signing: %w", err) + } + accountID := account.GetLabels()[k8s.LabelAccountID] + if accountID == "" { + return nil, fmt.Errorf("account ID is missing for account %s during user JWT signing", accountRef) + } + if claims.IssuerAccount != "" && claims.IssuerAccount != accountID { + return nil, fmt.Errorf("claims issuer account ID %s does not match %s bound to account %q during user JWT signing", claims.IssuerAccount, accountID, accountRef) + } + if claims.IssuerAccount == "" { + claims.IssuerAccount = accountID + } + accountSecrets, err := a.secretManager.GetSecrets(ctx, accountRef, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get account secrets for user JWT signing: %w", err) + } + signPubKey, err := accountSecrets.Sign.PublicKey() + if err != nil { + return nil, fmt.Errorf("failed to get account signing public key for user JWT signing: %w", err) + } + userJWT, err := claims.Encode(accountSecrets.Sign) + if err != nil { + return nil, fmt.Errorf("failed to sign user JWT using %s for account %s (%q): %w", signPubKey, accountID, accountRef, err) + } + return &user.SignedUserJWT{ + UserJWT: userJWT, + AccountID: accountID, + SignedBy: signPubKey, + }, nil +} + func (a *Manager) resolveClusterTarget(ctx context.Context, account *v1alpha1.Account) (*clusterTarget, error) { natsClusterRef := account.Spec.NatsClusterRef if natsClusterRef != nil && natsClusterRef.Namespace == "" { @@ -365,3 +403,6 @@ func getDisplayName(account *v1alpha1.Account) string { } return fmt.Sprintf("%s/%s", account.GetNamespace(), account.GetName()) } + +var _ controller.AccountManager = (*Manager)(nil) +var _ user.JWTSigner = (*Manager)(nil) diff --git a/internal/account/account_test.go b/internal/account/account_test.go index 71ab3c6..28616b0 100644 --- a/internal/account/account_test.go +++ b/internal/account/account_test.go @@ -363,6 +363,104 @@ func (t *ManagerTestSuite) Test_Delete_ShouldSucceed() { t.Equal([]interface{}{accountID}, deleteClaims.Data["accounts"]) } +func (t *ManagerTestSuite) Test_SignUserJWT_ShouldSucceed() { + // Given + accountRef := domain.NewNamespacedName("account-namespace", "account-name") + accountRootKey, _ := nkeys.CreateAccount() + accountID, _ := accountRootKey.PublicKey() + accountSignKey, _ := nkeys.CreateAccount() + accountSignKeyPublic, _ := accountSignKey.PublicKey() + + account := &v1alpha1.Account{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "account-namespace", + Name: "account-name", + Labels: map[string]string{ + k8s.LabelAccountID: accountID, + }, + }, + } + t.accountReaderMock.mockGet(t.ctx, accountRef, account) + t.secretManagerMock.mockGetSecrets(t.ctx, accountRef, accountID, &Secrets{ + Root: accountRootKey, + Sign: accountSignKey, + }) + + userKey, _ := nkeys.CreateUser() + userKeyPublic, _ := userKey.PublicKey() + claims := jwt.NewUserClaims(userKeyPublic) + + // When + result, err := t.unitUnderTest.SignUserJWT(t.ctx, accountRef, claims) + + // Then + t.NoError(err) + t.NotNil(result) + t.Equal(accountID, result.AccountID) + t.Equal(accountSignKeyPublic, result.SignedBy) + + // Verify the JWT is signed with the account's signing key + parsedClaims, err := jwt.DecodeUserClaims(result.UserJWT) + t.NoError(err, "failed to decode signed user JWT") + t.Equal(accountID, parsedClaims.IssuerAccount) + t.Equal(accountSignKeyPublic, parsedClaims.Issuer) +} + +func (t *ManagerTestSuite) Test_SignUserJWT_ShouldFailWhenAccountIsNotReady() { + // Given + accountRef := domain.NewNamespacedName("account-namespace", "account-name") + + account := &v1alpha1.Account{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "account-namespace", + Name: "account-name", + }, + } + t.accountReaderMock.mockGet(t.ctx, accountRef, account) + + userKey, _ := nkeys.CreateUser() + userKeyPublic, _ := userKey.PublicKey() + claims := jwt.NewUserClaims(userKeyPublic) + + // When + result, err := t.unitUnderTest.SignUserJWT(t.ctx, accountRef, claims) + + // Then + t.Nil(result) + t.ErrorContains(err, "account ID is missing for account account-namespace/account-name during user JWT signing") +} + +func (t *ManagerTestSuite) Test_SignUserJWT_ShouldFailWhenClaimsIssuerAccountDoesNotMatchFoundAccountID() { + // Given + accountRef := domain.NewNamespacedName("account-namespace", "account-name") + accountRootKey, _ := nkeys.CreateAccount() + accountID, _ := accountRootKey.PublicKey() + + account := &v1alpha1.Account{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "account-namespace", + Name: "account-name", + Labels: map[string]string{ + k8s.LabelAccountID: accountID, + }, + }, + } + t.accountReaderMock.mockGet(t.ctx, accountRef, account) + + userKey, _ := nkeys.CreateUser() + userKeyPublic, _ := userKey.PublicKey() + claims := jwt.NewUserClaims(userKeyPublic) + claims.IssuerAccount = "some-other-account-id" + + // When + result, err := t.unitUnderTest.SignUserJWT(t.ctx, accountRef, claims) + + // Then + t.Nil(result) + t.ErrorContains(err, "claims issuer account ID some-other-account-id does not match "+ + accountID+" bound to account \"account-namespace/account-name\" during user JWT signing") +} + /* **************************************************** * Helpers *****************************************************/ diff --git a/internal/account/claims_test.go b/internal/account/claims_test.go index 19375a6..02800a8 100644 --- a/internal/account/claims_test.go +++ b/internal/account/claims_test.go @@ -50,15 +50,14 @@ func TestClaims(t *testing.T) { ctx := context.Background() accountReaderMock := NewAccountReaderMock() - getAccountCall := accountReaderMock.On("Get", mock.Anything, mock.Anything) - getAccountCall.RunFn = func(args mock.Arguments) { - accountID := fakeAccountId(args.Get(1).(domain.NamespacedName)) + getAccountCall := accountReaderMock.mockGetCallback(mock.Anything, mock.Anything, func(accountRef domain.NamespacedName) (*v1alpha1.Account, error) { + accountID := fakeAccountId(accountRef) account := &v1alpha1.Account{} account.Labels = map[string]string{ k8s.LabelAccountID: accountID, } - getAccountCall.Return(*account, nil) - } + return account, nil + }) // Build NATS JWT AccountClaims from AccountSpec builder := newClaimsBuilder(ctx, testClaimsDisplayName, *spec, testClaimsAccountPubKey, accountReaderMock) @@ -95,7 +94,7 @@ func TestClaims(t *testing.T) { account.Labels = map[string]string{ k8s.LabelAccountID: testClaimsFakeAccountID, } - getAccountCall.Return(*account, nil) + getAccountCall.Return(account, nil) } // Verify that the resulting NAuth AccountClaim generates the same NATS JWT when encoded diff --git a/internal/account/mocks_test.go b/internal/account/mocks_test.go index 6ce1fb9..9ed4b5d 100644 --- a/internal/account/mocks_test.go +++ b/internal/account/mocks_test.go @@ -237,8 +237,19 @@ func NewAccountReaderMock() *AccountReaderMock { func (a *AccountReaderMock) Get(ctx context.Context, accountRef domain.NamespacedName) (account *v1alpha1.Account, err error) { args := a.Called(ctx, accountRef) - anAccount := args.Get(0).(v1alpha1.Account) - return &anAccount, args.Error(1) + return args.Get(0).(*v1alpha1.Account), args.Error(1) +} + +func (a *AccountReaderMock) mockGet(ctx context.Context, accountRef domain.NamespacedName, result *v1alpha1.Account) { + a.On("Get", ctx, accountRef).Return(result, nil) +} + +func (a *AccountReaderMock) mockGetCallback(ctx interface{}, accountRef interface{}, generator func(accountRef domain.NamespacedName) (*v1alpha1.Account, error)) *mock.Call { + call := a.On("Get", ctx, accountRef) + call.RunFn = func(args mock.Arguments) { + call.Return(generator(args.Get(1).(domain.NamespacedName))) + } + return call } var _ ports.AccountReader = &AccountReaderMock{} diff --git a/internal/user/mocks_test.go b/internal/user/mocks_test.go index 154bd0e..e207904 100644 --- a/internal/user/mocks_test.go +++ b/internal/user/mocks_test.go @@ -3,14 +3,41 @@ package user import ( "context" - "github.com/WirelessCar/nauth/api/v1alpha1" "github.com/WirelessCar/nauth/internal/domain" "github.com/WirelessCar/nauth/internal/ports" + "github.com/nats-io/jwt/v2" "github.com/stretchr/testify/mock" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +/* **************************************************** +* JWTSigner Mock +*****************************************************/ + +func NewUserJWTSignerMock() *UserJWTSignerMock { + return &UserJWTSignerMock{} +} + +type UserJWTSignerMock struct { + mock.Mock +} + +func (m *UserJWTSignerMock) SignUserJWT(ctx context.Context, accountRef domain.NamespacedName, claims *jwt.UserClaims) (*SignedUserJWT, error) { + args := m.Called(ctx, accountRef, claims) + return args.Get(0).(*SignedUserJWT), args.Error(1) +} + +func (m *UserJWTSignerMock) mockSignUserJWT(ctx context.Context, accountRef domain.NamespacedName, callback func(claims *jwt.UserClaims) *SignedUserJWT) { + call := m.On("SignUserJWT", ctx, accountRef, mock.Anything) + call.RunFn = func(args mock.Arguments) { + claims := args.Get(2).(*jwt.UserClaims) + call.Return(callback(claims), nil) + } +} + +var _ JWTSigner = &UserJWTSignerMock{} + /* **************************************************** * ports.SecretClient Mock *****************************************************/ @@ -39,20 +66,12 @@ func (s *SecretClientMock) Get(ctx context.Context, secretRef domain.NamespacedN return args.Get(0).(map[string]string), args.Error(1) } -func (s *SecretClientMock) mockGet(ctx context.Context, namespacedName domain.NamespacedName, result map[string]string) { - s.On("Get", ctx, namespacedName).Return(result, nil) -} - // GetByLabels implements ports.SecretStorer. func (s *SecretClientMock) GetByLabels(ctx context.Context, namespace domain.Namespace, labels map[string]string) (*corev1.SecretList, error) { args := s.Called(ctx, namespace, labels) return args.Get(0).(*corev1.SecretList), args.Error(1) } -func (s *SecretClientMock) mockGetByLabels(ctx context.Context, namespace domain.Namespace, labels interface{}, list *corev1.SecretList) { - s.On("GetByLabels", ctx, namespace, labels).Return(list, nil) -} - // DeleteSecret implements ports.SecretStorer. func (s *SecretClientMock) Delete(ctx context.Context, secretRef domain.NamespacedName) error { args := s.Called(ctx, secretRef) @@ -71,34 +90,5 @@ func (s *SecretClientMock) Label(ctx context.Context, secretRef domain.Namespace return args.Error(0) } -func (s *SecretClientMock) mockLabel(namespacedName domain.NamespacedName, labels map[string]string) { - s.On("Label", mock.Anything, namespacedName, labels).Return(nil) -} - // Compile-time assertion that implementation satisfies the ports interface var _ ports.SecretClient = &SecretClientMock{} - -/* **************************************************** -* ports.AccountReader Mock -*****************************************************/ - -type AccountReaderMock struct { - mock.Mock -} - -func NewAccountReaderMock() *AccountReaderMock { - return &AccountReaderMock{} -} - -func (a *AccountReaderMock) Get(ctx context.Context, accountRef domain.NamespacedName) (account *v1alpha1.Account, err error) { - args := a.Called(ctx, accountRef) - anAccount := args.Get(0).(v1alpha1.Account) - return &anAccount, args.Error(1) -} - -func (a *AccountReaderMock) mockGet(ctx context.Context, accountRef domain.NamespacedName, result v1alpha1.Account) *mock.Call { - return a.On("Get", ctx, accountRef).Return(result, nil) -} - -// Compile-time assertion that implementation satisfies the ports interface -var _ ports.AccountReader = &AccountReaderMock{} diff --git a/internal/user/user.go b/internal/user/user.go index 799fa1f..f7030d3 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -2,9 +2,7 @@ package user import ( "context" - "errors" "fmt" - "sync" "github.com/WirelessCar/nauth/api/v1alpha1" "github.com/WirelessCar/nauth/internal/domain" @@ -16,15 +14,25 @@ import ( logf "sigs.k8s.io/controller-runtime/pkg/log" ) +type SignedUserJWT struct { + UserJWT string + AccountID string + SignedBy string +} + +type JWTSigner interface { + SignUserJWT(ctx context.Context, accountRef domain.NamespacedName, claims *jwt.UserClaims) (*SignedUserJWT, error) +} + type Manager struct { - accountsReader ports.AccountReader - secretClient ports.SecretClient + userJWTSigner JWTSigner + secretClient ports.SecretClient } -func NewManager(accountsReader ports.AccountReader, secretClient ports.SecretClient) *Manager { +func NewManager(userJWTSigner JWTSigner, secretClient ports.SecretClient) *Manager { return &Manager{ - accountsReader: accountsReader, - secretClient: secretClient, + userJWTSigner: userJWTSigner, + secretClient: secretClient, } } @@ -34,23 +42,8 @@ func (u *Manager) CreateOrUpdate(ctx context.Context, state *v1alpha1.User) erro if err := accountRef.Validate(); err != nil { return fmt.Errorf("invalid account reference %q: %w", accountRef, err) } - account, err := u.accountsReader.Get(ctx, accountRef) - if err != nil { - return err - } - accountID := account.GetLabels()[k8s.LabelAccountID] - if accountID == "" { - return fmt.Errorf("account %s does not have an account ID yet", accountRef) - } - accountSigningKeyPair, err := u.getAccountSigningKeyPair(ctx, accountRef, accountID) - if err != nil { - return fmt.Errorf("failed to get signing key secret %s: %w", accountRef, err) - } - accountSigningKeyPublicKey, err := accountSigningKeyPair.PublicKey() - if err != nil { - return fmt.Errorf("failed to get account signing public key: %w", err) - } + existingUserAccountID := state.GetLabels()[k8s.LabelUserAccountID] userKeyPair, err := nkeys.CreateUser() if err != nil { @@ -65,14 +58,14 @@ func (u *Manager) CreateOrUpdate(ctx context.Context, state *v1alpha1.User) erro return fmt.Errorf("failed to get user seed: %w", err) } - natsClaims := newClaimsBuilder(getDisplayName(state), state.Spec, userPublicKey, accountID). + natsClaims := newClaimsBuilder(u.getDisplayName(state), state.Spec, userPublicKey, existingUserAccountID). build() - userJwt, err := natsClaims.Encode(accountSigningKeyPair) + signedUserJWT, err := u.userJWTSigner.SignUserJWT(ctx, accountRef, natsClaims) if err != nil { return fmt.Errorf("failed to sign user jwt for %s: %w", userRef, err) } - userCreds, err := jwt.FormatUserConfig(userJwt, userSeed) + userCreds, err := jwt.FormatUserConfig(signedUserJWT.UserJWT, userSeed) if err != nil { return fmt.Errorf("failed to format user credentials: %w", err) } @@ -93,7 +86,6 @@ func (u *Manager) CreateOrUpdate(ctx context.Context, state *v1alpha1.User) erro return err } - toNAuthUserClaims(natsClaims) state.Status.Claims = toNAuthUserClaims(natsClaims) if state.Labels == nil { @@ -101,8 +93,8 @@ func (u *Manager) CreateOrUpdate(ctx context.Context, state *v1alpha1.User) erro } state.GetLabels()[k8s.LabelUserID] = userPublicKey - state.GetLabels()[k8s.LabelUserAccountID] = account.GetLabels()[k8s.LabelAccountID] - state.GetLabels()[k8s.LabelUserSignedBy] = accountSigningKeyPublicKey + state.GetLabels()[k8s.LabelUserAccountID] = signedUserJWT.AccountID + state.GetLabels()[k8s.LabelUserSignedBy] = signedUserJWT.SignedBy state.Status.ObservedGeneration = state.Generation state.Status.ReconcileTimestamp = metav1.Now() @@ -126,135 +118,9 @@ func (u *Manager) Delete(ctx context.Context, state *v1alpha1.User) error { return nil } -func (u *Manager) getAccountSigningKeyPair(ctx context.Context, accountRef domain.NamespacedName, accountID string) (nkeys.KeyPair, error) { - if keyPair, err := u.getAccountSigningKeyPairByAccountID(ctx, accountRef, accountID); err == nil { - return keyPair, nil - } - - keyPair, err := u.getDeprecatedAccountSigningKeyPair(ctx, accountRef, accountID) - if err != nil { - return nil, err - } - - return keyPair, nil -} - -func (u *Manager) getAccountSigningKeyPairByAccountID(ctx context.Context, accountRef domain.NamespacedName, accountID string) (nkeys.KeyPair, error) { - labels := map[string]string{ - k8s.LabelAccountID: accountID, - k8s.LabelSecretType: k8s.SecretTypeAccountSign, - k8s.LabelManaged: k8s.LabelManagedValue, - } - secrets, err := u.secretClient.GetByLabels(ctx, accountRef.GetNamespace(), labels) - if err != nil { - return nil, fmt.Errorf("failed to get signing secret for account: %s due to %w", accountRef, err) - } - - if len(secrets.Items) < 1 { - return nil, fmt.Errorf("no signing secret found for account: %s", accountRef) - } - - if len(secrets.Items) > 1 { - return nil, fmt.Errorf("more than 1 signing secret found for account: %s", accountRef) - } - - seed, ok := secrets.Items[0].Data[k8s.DefaultSecretKeyName] - if !ok { - return nil, fmt.Errorf("secret for user credentials seed was malformed") - } - return nkeys.FromSeed(seed) -} - -func getDisplayName(user *v1alpha1.User) string { +func (u *Manager) getDisplayName(user *v1alpha1.User) string { if user.Spec.DisplayName != "" { return user.Spec.DisplayName } return fmt.Sprintf("%s/%s", user.GetNamespace(), user.GetName()) } - -// Todo: Almost identical to the one in account/account.go - refactor ? -func (u *Manager) getDeprecatedAccountSigningKeyPair(ctx context.Context, accountRef domain.NamespacedName, accountID string) (nkeys.KeyPair, error) { - logger := logf.FromContext(ctx) - - type goRoutineResult struct { - secret map[string]string - err error - } - var wg sync.WaitGroup - ch := make(chan goRoutineResult, 2) - - namespace := accountRef.GetNamespace() - for _, s := range []struct { - secretRef domain.NamespacedName - secretType string - }{ - { - secretRef: namespace.WithName(fmt.Sprintf(k8s.DeprecatedSecretNameAccountRootTemplate, accountRef.Name)), - secretType: k8s.SecretTypeAccountRoot, - }, - { - secretRef: namespace.WithName(fmt.Sprintf(k8s.DeprecatedSecretNameAccountSignTemplate, accountRef.Name)), - secretType: k8s.SecretTypeAccountSign, - }, - } { - wg.Add(1) - go func(secretRef domain.NamespacedName, secretType string) { - result := goRoutineResult{} - defer wg.Done() - defer func() { - if r := recover(); r != nil { - result.err = fmt.Errorf("recovered panicked go routine from trying to get secret %s of type %s: %v", secretRef, secretType, r) - ch <- result - } - }() - - accountSecret, err := u.secretClient.Get(ctx, secretRef) - if err != nil { - result.err = err - ch <- result - return - } - - labels := map[string]string{ - k8s.LabelAccountID: accountID, - k8s.LabelSecretType: secretType, - k8s.LabelManaged: k8s.LabelManagedValue, - } - if err := u.secretClient.Label(ctx, secretRef, labels); err != nil { - logger.Info("unable to label secret", "secretRef", secretRef, "secretType", secretType, "error", err) - } - accountSecret[k8s.LabelSecretType] = secretType - result.secret = accountSecret - ch <- result - }(s.secretRef, s.secretType) - } - - wg.Wait() - close(ch) - - var errs []error - secrets := make(map[string]map[string]string, 2) - - for res := range ch { - if res.err != nil { - errs = append(errs, res.err) - continue - } - secrets[res.secret[k8s.LabelSecretType]] = res.secret - } - - if len(errs) > 0 { - return nil, errors.Join(errs...) - } - - accountSignSecret, ok := secrets[k8s.SecretTypeAccountSign] - if !ok { - return nil, fmt.Errorf("no deprecated signing key found for account %s", accountRef) - } - - accountSignSecretSeed, ok := accountSignSecret[k8s.DefaultSecretKeyName] - if !ok { - return nil, fmt.Errorf("no deprecated signing key seed found for account %s", accountRef) - } - return nkeys.FromSeed([]byte(accountSignSecretSeed)) -} diff --git a/internal/user/user_test.go b/internal/user/user_test.go index e6ef47e..6697c22 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -2,18 +2,16 @@ package user import ( "context" - "fmt" "github.com/WirelessCar/nauth/api/v1alpha1" "github.com/WirelessCar/nauth/internal/domain" "github.com/WirelessCar/nauth/internal/k8s" + "github.com/nats-io/jwt/v2" "github.com/nats-io/nkeys" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/stretchr/testify/mock" - corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/utils/ptr" ) var userName = "user" @@ -21,7 +19,6 @@ var userName = "user" const ( accountName = "test-account" accountNamespace = "default" - unlimitedLimit = -1 ) var _ = Describe("User manager", func() { @@ -29,101 +26,48 @@ var _ = Describe("User manager", func() { var ( ctx = context.Background() userManager *Manager - accountReaderMock *AccountReaderMock + userJWTSignerMock *UserJWTSignerMock secretClientMock *SecretClientMock ) BeforeEach(func() { By("creating the user manager") secretClientMock = NewSecretClientMock() - accountReaderMock = NewAccountReaderMock() - userManager = NewManager(accountReaderMock, secretClientMock) + userJWTSignerMock = NewUserJWTSignerMock() + userManager = NewManager(userJWTSignerMock, secretClientMock) }) AfterEach(func() { secretClientMock.AssertExpectations(GinkgoT()) - accountReaderMock.AssertExpectations(GinkgoT()) + userJWTSignerMock.AssertExpectations(GinkgoT()) }) It("creates a new user belonging to the correct account", func() { - account := GetExistingAccount() - user := GetNewUser() - - By("providing a user specification without any specific configuration") - accountReaderMock.mockGet(ctx, domain.NewNamespacedName(accountNamespace, accountName), *account) - - By("mocking preexisting account keys & CR") - accountSigningKeyPair, _ := nkeys.CreateAccount() - accountSigningSeed, _ := accountSigningKeyPair.Seed() - secretsList := &corev1.SecretList{ - Items: []corev1.Secret{ - { - Data: map[string][]byte{ - k8s.DefaultSecretKeyName: accountSigningSeed, - }, - }, - }, - } - secretClientMock.mockGetByLabels(ctx, accountNamespace, mock.Anything, secretsList) - - By("User credentials are stored") - secretClientMock.mockApply(ctx, mock.Anything, mock.MatchedBy(func(s v1.ObjectMeta) bool { - return s.GetName() == user.GetUserSecretName() && s.GetNamespace() == accountNamespace - }), mock.AnythingOfType("map[string]string")) - - err := userManager.CreateOrUpdate(ctx, user) - - Expect(err).ToNot(HaveOccurred()) - Expect(user.GetLabels()).ToNot(BeNil()) - Expect(user.GetLabels()[k8s.LabelUserID]).Should(Satisfy(isUserPubKey)) - }) + By("providing a fake existing account and signing key") + accountRoot, _ := nkeys.CreateAccount() + accountID, _ := accountRoot.PublicKey() + accountSign, _ := nkeys.CreateAccount() + accountSignPub, _ := accountSign.PublicKey() - It("creates a new user from an account with legacy secrets", func() { By("providing a user specification") user := GetNewUser() + var subsLimit int64 = 43 + user.Spec.NatsLimits.Subs = &subsLimit + + By("mocking user signing") + userJWTSignerMock.mockSignUserJWT(ctx, domain.NewNamespacedName(accountNamespace, accountName), func(claims *jwt.UserClaims) *SignedUserJWT { + Expect(claims.IssuerAccount).To(BeEmpty()) + claims.IssuerAccount = accountID + userJWT, err := claims.Encode(accountSign) + Expect(err).NotTo(HaveOccurred()) + return &SignedUserJWT{ + UserJWT: userJWT, + AccountID: accountID, + SignedBy: accountSignPub, + } + }) - account := GetExistingAccount() - - By("mocking the secret storer") - secretClientMock.mockGetByLabels(ctx, domain.Namespace(account.GetNamespace()), mock.Anything, &corev1.SecretList{}) - - accountKeyPair, _ := nkeys.CreateAccount() - accountPublicKey, _ := accountKeyPair.PublicKey() - accountSeed, _ := accountKeyPair.Seed() - accountSecretValueMock := map[string]string{k8s.DefaultSecretKeyName: string(accountSeed)} - accountSecretNameMock := fmt.Sprintf(k8s.DeprecatedSecretNameAccountRootTemplate, account.GetName()) - secretClientMock.mockGet(ctx, domain.NewNamespacedName(account.GetNamespace(), accountSecretNameMock), accountSecretValueMock) - accountSecretLabelsMock := map[string]string{ - k8s.LabelAccountID: accountPublicKey, - k8s.LabelSecretType: k8s.SecretTypeAccountRoot, - k8s.LabelManaged: k8s.LabelManagedValue, - } - secretClientMock.mockLabel(domain.NewNamespacedName(account.GetNamespace(), accountSecretNameMock), accountSecretLabelsMock) - - accountSigningKeyPair, _ := nkeys.CreateAccount() - accountSigningPublicKey, _ := accountSigningKeyPair.PublicKey() - accountSigningSeed, _ := accountSigningKeyPair.Seed() - accountSigningSecretValueMock := map[string]string{k8s.DefaultSecretKeyName: string(accountSigningSeed)} - accountSigningSecretNameMock := fmt.Sprintf(k8s.DeprecatedSecretNameAccountSignTemplate, account.GetName()) - secretClientMock.mockGet(ctx, domain.NewNamespacedName(account.GetNamespace(), accountSigningSecretNameMock), accountSigningSecretValueMock) - accountSigningSecretLabelsMock := map[string]string{ - k8s.LabelAccountID: accountPublicKey, - k8s.LabelSecretType: k8s.SecretTypeAccountSign, - k8s.LabelManaged: k8s.LabelManagedValue, - } - secretClientMock.mockLabel(domain.NewNamespacedName(account.GetNamespace(), accountSigningSecretNameMock), accountSigningSecretLabelsMock) - - By("mocking existing account") - account.Status.SigningKey = v1alpha1.KeyInfo{ - Name: accountSigningPublicKey, - } - account.Labels = map[string]string{ - k8s.LabelAccountID: accountPublicKey, - } - accountReaderMock.mockGet(ctx, domain.NewNamespacedName(accountNamespace, accountName), *account) - - By("mock storing user credentials") - + By("User credentials are stored") secretClientMock.mockApply(ctx, mock.Anything, mock.MatchedBy(func(s v1.ObjectMeta) bool { return s.GetName() == user.GetUserSecretName() && s.GetNamespace() == accountNamespace }), mock.AnythingOfType("map[string]string")) @@ -133,28 +77,38 @@ var _ = Describe("User manager", func() { Expect(err).ToNot(HaveOccurred()) Expect(user.GetLabels()).ToNot(BeNil()) Expect(user.GetLabels()[k8s.LabelUserID]).Should(Satisfy(isUserPubKey)) + Expect(user.GetLabels()[k8s.LabelUserAccountID]).To(Equal(accountID)) + Expect(user.GetLabels()[k8s.LabelUserSignedBy]).To(Equal(accountSignPub)) + Expect(user.Status.Claims.NatsLimits.Subs).To(Equal(&subsLimit)) }) - It("creates a new user and update settigs", func() { - account := GetExistingAccount() - user := GetNewUser() + It("updates an existing user", func() { + By("providing a fake existing account and signing key") + accountRoot, _ := nkeys.CreateAccount() + accountID, _ := accountRoot.PublicKey() + accountSign, _ := nkeys.CreateAccount() + accountSignPub, _ := accountSign.PublicKey() - By("providing a user specification without any specific configuration") - accountReaderMock.mockGet(ctx, domain.NewNamespacedName(accountNamespace, accountName), *account).Twice() - - By("mocking preexisting account keys & CR") - accountSigningKeyPair, _ := nkeys.CreateAccount() - accountSigningSeed, _ := accountSigningKeyPair.Seed() - secretsList := &corev1.SecretList{ - Items: []corev1.Secret{ - { - Data: map[string][]byte{ - k8s.DefaultSecretKeyName: accountSigningSeed, - }, - }, - }, - } - secretClientMock.mockGetByLabels(ctx, accountNamespace, mock.Anything, secretsList) + By("providing a user specification bound to the existing account") + user := GetNewUser() + var subsLimit int64 = 43 + user.Spec.NatsLimits.Subs = &subsLimit + user.Labels = make(map[string]string) + user.Labels[k8s.LabelUserID] = "fake-prev-user-pub-key" + user.Labels[k8s.LabelUserAccountID] = accountID + user.Labels[k8s.LabelUserSignedBy] = "fake-prev-sign-pub-key" + + By("mocking user signing") + userJWTSignerMock.mockSignUserJWT(ctx, domain.NewNamespacedName(accountNamespace, accountName), func(claims *jwt.UserClaims) *SignedUserJWT { + Expect(claims.IssuerAccount).To(Equal(accountID)) + userJWT, err := claims.Encode(accountSign) + Expect(err).NotTo(HaveOccurred()) + return &SignedUserJWT{ + UserJWT: userJWT, + AccountID: accountID, + SignedBy: accountSignPub, + } + }) By("User credentials are stored") secretClientMock.mockApply(ctx, mock.Anything, mock.MatchedBy(func(s v1.ObjectMeta) bool { @@ -166,21 +120,9 @@ var _ = Describe("User manager", func() { Expect(err).ToNot(HaveOccurred()) Expect(user.GetLabels()).ToNot(BeNil()) Expect(user.GetLabels()[k8s.LabelUserID]).Should(Satisfy(isUserPubKey)) - - user.Spec.NatsLimits = &v1alpha1.NatsLimits{ - Subs: ptr.To[int64](100), - Data: ptr.To[int64](1024), - Payload: ptr.To[int64](256), - } - - err = userManager.CreateOrUpdate(ctx, user) - - Expect(err).ToNot(HaveOccurred()) - Expect(user.GetLabels()).ToNot(BeNil()) - Expect(user.GetLabels()[k8s.LabelUserID]).Should(Satisfy(isUserPubKey)) - Expect(user.Status.Claims.NatsLimits.Subs).Should(Equal(user.Spec.NatsLimits.Subs)) - Expect(user.Status.Claims.NatsLimits.Data).Should(Equal(user.Spec.NatsLimits.Data)) - Expect(user.Status.Claims.NatsLimits.Payload).Should(Equal(user.Spec.NatsLimits.Payload)) + Expect(user.GetLabels()[k8s.LabelUserAccountID]).To(Equal(accountID)) + Expect(user.GetLabels()[k8s.LabelUserSignedBy]).To(Equal(accountSignPub)) + Expect(user.Status.Claims.NatsLimits.Subs).To(Equal(&subsLimit)) }) }) }) @@ -202,39 +144,3 @@ func GetNewUser() *v1alpha1.User { }, } } - -func GetNewAccount() *v1alpha1.Account { - return &v1alpha1.Account{ - ObjectMeta: v1.ObjectMeta{ - Name: accountName, - Namespace: accountNamespace, - }, - Spec: v1alpha1.AccountSpec{ - JetStreamLimits: &v1alpha1.JetStreamLimits{ - MemoryStorage: ptr.To[int64](unlimitedLimit), - DiskStorage: ptr.To[int64](unlimitedLimit), - Consumer: ptr.To[int64](unlimitedLimit), - }, - }, - } -} - -func GetExistingAccount() *v1alpha1.Account { - const ControllerTypeReady = "Ready" - account := GetNewAccount() - account.Labels = map[string]string{ - k8s.LabelAccountID: "ACEXISTINGACCOUNTID", - } - account.Status = v1alpha1.AccountStatus{ - SigningKey: v1alpha1.KeyInfo{ - Name: "OPERATORSIGNPUBKEY", - }, - Conditions: []v1.Condition{ - { - Type: ControllerTypeReady, - Status: v1.ConditionTrue, - }, - }, - } - return account -}