diff --git a/.gitignore b/.gitignore index 8033f7f..9bf8e12 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,8 @@ configs/development* # SQLite database files *.db *.sqlite -*.sqlite3 \ No newline at end of file +*.sqlite3 + +# Build artifacts +identra-gateway +identra-grpc \ No newline at end of file diff --git a/cmd/identra-gateway/main.go b/cmd/identra-gateway/main.go index 04485c0..6c9c612 100644 --- a/cmd/identra-gateway/main.go +++ b/cmd/identra-gateway/main.go @@ -54,7 +54,15 @@ func NewGateway(grpcEndpoint, staticDir, apiPrefix string) (*Gateway, error) { } }), runtime.WithOutgoingHeaderMatcher(func(key string) (string, bool) { - return "", false + // Allow cache-related headers to pass through to HTTP response + switch strings.ToLower(key) { + case "cache-control": + return "Cache-Control", true + case "etag": + return "ETag", true + default: + return "", false + } }), ) diff --git a/docs/KEY_ROTATION.md b/docs/KEY_ROTATION.md new file mode 100644 index 0000000..45bde93 --- /dev/null +++ b/docs/KEY_ROTATION.md @@ -0,0 +1,224 @@ +# JWKS Key Rotation Guide + +## Overview + +Identra's JWT signing key infrastructure supports robust key rotation to enable secure, zero-downtime key updates. This guide explains how key rotation works and provides recommended procedures for operators. + +## Key Lifecycle States + +Keys in the Identra KeyManager can be in one of three states: + +- **ACTIVE**: The key currently used for signing new JWT tokens. Only one key can be ACTIVE at a time. +- **PASSIVE**: Keys published in the JWKS endpoint for token verification but not used for signing. Multiple keys can be PASSIVE simultaneously. +- **RETIRED**: Keys removed from the system entirely. They are no longer published in JWKS and cannot verify tokens. 
+ +## Key Rotation Strategy + +Identra implements **Option 1: Short-lived JWT with Key Rotation** from the JWKS rotation strategy: + +- Access tokens are short-lived (typically 5-15 minutes) +- Refresh tokens are long-lived (typically 7 days) +- Both token types include a `kid` (key ID) in the JWT header +- The JWKS endpoint exposes all ACTIVE and PASSIVE keys +- During rotation, both old and new keys are published simultaneously to ensure continuous token validity + +## Rotation Procedure + +### Recommended Timeline + +For access tokens with a 15-minute lifetime: + +1. **T+0**: Add new key in PASSIVE state +2. **T+1 hour**: Promote new key to ACTIVE (old key becomes PASSIVE) +3. **T+2 hours**: Retire old key (after all tokens signed with it have expired) + +The 1-hour delay before promotion allows JWKS caches to refresh and clients to discover the new key before it's used for signing. + +### Step-by-Step Process + +#### 1. Add New Key (PASSIVE) + +```go +km := security.GetKeyManager() +newKeyID, err := km.AddKeyPassive() +if err != nil { + log.Fatalf("Failed to add passive key: %v", err) +} +log.Printf("Added new passive key: %s", newKeyID) +``` + +At this point: +- New key is published in JWKS but not used for signing +- Existing ACTIVE key continues signing tokens +- All previously issued tokens remain valid + +#### 2. Wait for Cache Propagation + +**Recommended wait time**: 1 hour (longer than the JWKS cache max-age) + +This ensures: +- All clients have refreshed their cached JWKS +- The new key is known to all relying parties +- No verification failures when the key becomes ACTIVE + +#### 3. 
Promote New Key to ACTIVE + +```go +km := security.GetKeyManager() +err := km.PromoteKey(newKeyID) +if err != nil { + log.Fatalf("Failed to promote key: %v", err) +} +log.Printf("Promoted key %s to ACTIVE", newKeyID) +``` + +At this point: +- New key is now used for signing all new tokens +- Old key is automatically demoted to PASSIVE +- Both keys remain in JWKS for verification +- Tokens signed with either key are valid + +#### 4. Wait for Old Tokens to Expire + +**Recommended wait time**: 2x access token lifetime (e.g., 30 minutes for 15-minute tokens) + +This ensures: +- All access tokens signed with the old key have expired (refresh tokens live much longer; if they must remain verifiable against this key, keep it PASSIVE until they have expired too) +- No valid access tokens depend on the old key for verification + +#### 5. Retire Old Key + +```go +km := security.GetKeyManager() +err := km.RetireKey(oldKeyID) +if err != nil { + log.Fatalf("Failed to retire key: %v", err) +} +log.Printf("Retired key %s", oldKeyID) +``` + +At this point: +- Old key is completely removed from the system +- Only the new ACTIVE key appears in JWKS +- System is back to single-key state + +## Key Management API + +### List All Keys + +```go +km := security.GetKeyManager() +keys := km.ListKeys() +for _, key := range keys { + fmt.Printf("Key ID: %s, State: %s\n", key.KeyID, key.State) +} +``` + +### Add Key in PASSIVE State + +```go +km := security.GetKeyManager() +keyID, err := km.AddKeyPassive() +``` + +### Promote PASSIVE Key to ACTIVE + +```go +km := security.GetKeyManager() +err := km.PromoteKey(keyID) +``` + +### Demote ACTIVE Key to PASSIVE + +```go +km := security.GetKeyManager() +err := km.DemoteKey(keyID) +``` + +### Retire PASSIVE Key + +```go +km := security.GetKeyManager() +err := km.RetireKey(keyID) +``` + +Note: You cannot retire an ACTIVE key directly. Promote its replacement first (`PromoteKey` demotes the old key automatically); calling `DemoteKey` on its own leaves the KeyManager with no ACTIVE signing key until another key is promoted. + +## HTTP Cache Headers + +The JWKS endpoint includes the following cache headers: + +- `Cache-Control: public, max-age=3600` (1 hour cache) +- `ETag: "..."` (SHA-256 hash of the published key IDs, for efficient cache validation) + +Clients should: +1. 
Cache the JWKS response for up to 1 hour +2. Use `If-None-Match` with the ETag for cache revalidation +3. Implement a fallback to re-fetch if verification fails + +## Emergency Key Rotation + +If a key is compromised: + +1. **Immediately** add a new PASSIVE key +2. **Immediately** promote it to ACTIVE (skip the propagation wait) +3. Monitor for verification failures and investigate +4. Consider revoking all active sessions (out of scope for current implementation) +5. After investigation, retire the compromised key + +The 1-hour propagation delay is optional in emergencies, but expect some verification failures until caches refresh. + +## Monitoring + +Monitor the following metrics: + +- Number of keys in each state (should normally be 1 ACTIVE, 0-1 PASSIVE) +- JWKS cache hit/miss rates +- Token verification success/failure rates +- Key age (rotate keys periodically, e.g., every 90 days) + +## Best Practices + +1. **Schedule rotations during low-traffic periods** to minimize impact +2. **Automate the rotation process** to reduce human error +3. **Keep rotation windows generous** (at least 2x token lifetime) +4. **Test rotation in staging** before production +5. **Document each rotation** with timestamps and key IDs +6. 
**Never skip the PASSIVE phase** unless it's an emergency + +## Troubleshooting + +### "Token verification failed" after rotation + +- Check that both keys are in JWKS +- Verify the token's `kid` matches a published key +- Check client's JWKS cache TTL +- Ensure promotion happened after cache refresh + +### "Cannot retire ACTIVE key" + +- Promote the replacement key with `PromoteKey(newKeyID)` first; this demotes the old key to PASSIVE automatically (calling `DemoteKey(keyID)` alone leaves no ACTIVE signing key) +- Then call `RetireKey(keyID)` + +### "Key not found" error + +- Verify key ID with `ListKeys()` +- Check for typos in the key ID +- Ensure key hasn't already been retired + +## Future Enhancements + +Potential improvements (out of scope for current implementation): + +- CLI tool for key rotation operations +- Admin API endpoints for key management +- Automated periodic rotation +- Key rotation audit log +- Metrics and alerting integration +- Support for multiple key algorithms (ES256, EdDSA) + +## References + +- RFC 7517: JSON Web Key (JWK) +- RFC 7519: JSON Web Token (JWT) +- [JWKS Best Practices](https://auth0.com/docs/secure/tokens/json-web-tokens/json-web-key-sets) diff --git a/internal/application/identra/config.go b/internal/application/identra/config.go index 268db78..ad98352 100644 --- a/internal/application/identra/config.go +++ b/internal/application/identra/config.go @@ -14,7 +14,7 @@ type Config struct { RSAPrivateKey string GithubClientID string GithubClientSecret string - OAuthFetchEmailIfMissing bool + OAuthFetchEmailIfMissing bool OAuthStateExpirationDuration time.Duration AccessTokenExpirationDuration time.Duration RefreshTokenExpirationDuration time.Duration diff --git a/internal/application/identra/service.go b/internal/application/identra/service.go index 494dafe..3df4f6f 100644 --- a/internal/application/identra/service.go +++ b/internal/application/identra/service.go @@ -3,6 +3,7 @@ package identra import ( "context" "crypto/rand" + "crypto/sha256" "encoding/hex" "errors" "fmt" @@ -24,7 +25,9 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/options" "golang.org/x/oauth2" 
"golang.org/x/oauth2/github" + "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -127,7 +130,44 @@ func (s *Service) Close(ctx context.Context) error { } func (s *Service) GetJWKS(ctx context.Context, _ *identra_v1_pb.GetJWKSRequest) (*identra_v1_pb.GetJWKSResponse, error) { - return s.keyManager.GetJWKS(), nil + response := s.keyManager.GetJWKS() + + // Generate ETag based on hash of key IDs in the response + // This allows clients to efficiently check if keys have changed + etag := generateJWKSETag(response) + + // Set HTTP cache headers via gRPC metadata + // Cache-Control: public, max-age=3600 (1 hour) + // This allows clients to cache the JWKS and reduces load on the server + md := metadata.Pairs( + "Cache-Control", "public, max-age=3600", + "ETag", etag, + ) + if err := grpc.SetHeader(ctx, md); err != nil { + // Log error but don't fail the request + slog.Warn("failed to set JWKS cache headers", "error", err) + } + + return response, nil +} + +// generateJWKSETag creates an ETag based on the key IDs in the JWKS response. +// This allows clients to efficiently check if the key set has changed. 
+func generateJWKSETag(jwks *identra_v1_pb.GetJWKSResponse) string { + if jwks == nil || len(jwks.Keys) == 0 { + return `"empty"` + } + + // Join all key IDs with a delimiter to avoid ambiguous concatenations, then hash them + keyIDs := make([]string, 0, len(jwks.Keys)) + for _, key := range jwks.Keys { + keyIDs = append(keyIDs, key.Kid) + } + + hash := sha256.Sum256([]byte(strings.Join(keyIDs, ","))) + // Use the full 32 bytes (256 bits) of the SHA-256 hash to minimize collision risk + // Quoted per HTTP ETag specification (RFC 7232) + return fmt.Sprintf(`"%x"`, hash[:]) } func (s *Service) GetOAuthAuthorizationURL( diff --git a/internal/examples/key_rotation.go b/internal/examples/key_rotation.go new file mode 100644 index 0000000..1464f80 --- /dev/null +++ b/internal/examples/key_rotation.go @@ -0,0 +1,130 @@ +package main + +import ( + "fmt" + "log" + "time" + + "github.com/poly-workshop/identra/internal/infrastructure/security" +) + +// This example demonstrates the key rotation workflow for JWKS. +// +// NOTE: This example is located in the internal/ tree because it demonstrates +// internal API usage. It uses packages under internal/infrastructure/security +// which are not accessible to external packages per Go's internal package rules. +// +// This example is intended for: +// - Internal operators managing Identra deployments +// - Understanding the key rotation workflow +// - Testing rotation procedures in development +// +// In production, these steps would be automated or executed via CLI/API. 
+func main() { + fmt.Println("=== JWKS Key Rotation Example ===") + fmt.Println() + + // Step 1: Initialize KeyManager with first key + km := security.GetKeyManager() + if err := km.GenerateKeyPair(); err != nil { + log.Fatalf("Failed to generate initial key: %v", err) + } + + initialKeyID := km.GetKeyID() + fmt.Printf("Step 1: Generated initial ACTIVE key: %s\n", initialKeyID) + printKeyStatus(km) + + // Step 2: Add new key in PASSIVE state + fmt.Println("\nStep 2: Adding new key in PASSIVE state...") + newKeyID, err := km.AddKeyPassive() + if err != nil { + log.Fatalf("Failed to add passive key: %v", err) + } + fmt.Printf("Added new PASSIVE key: %s\n", newKeyID) + printKeyStatus(km) + + // Step 3: Wait for JWKS cache propagation (simulated) + fmt.Println("\nStep 3: Waiting for JWKS cache propagation (1 hour in production)...") + fmt.Println("(Skipping wait in this example)") + + // Step 4: Promote new key to ACTIVE + fmt.Println("\nStep 4: Promoting new key to ACTIVE...") + if err := km.PromoteKey(newKeyID); err != nil { + log.Fatalf("Failed to promote key: %v", err) + } + fmt.Printf("Promoted key %s to ACTIVE\n", newKeyID) + fmt.Printf("Previous key %s automatically demoted to PASSIVE\n", initialKeyID) + printKeyStatus(km) + + // Step 5: Wait for old tokens to expire (simulated) + fmt.Println("\nStep 5: Waiting for old tokens to expire (30 minutes in production)...") + fmt.Println("(Skipping wait in this example)") + + // Step 6: Retire old key + fmt.Println("\nStep 6: Retiring old key...") + if err := km.RetireKey(initialKeyID); err != nil { + log.Fatalf("Failed to retire key: %v", err) + } + fmt.Printf("Retired key %s\n", initialKeyID) + printKeyStatus(km) + + fmt.Println("\n=== Key Rotation Complete ===") + + // Demonstrate token signing and verification + fmt.Println("\n=== Token Operations ===") + demonstrateTokenOperations(km) +} + +func printKeyStatus(km *security.KeyManager) { + keys := km.ListKeys() + fmt.Printf("\nCurrent key status (%d keys):\n", 
len(keys)) + for _, key := range keys { + status := "" + if km.GetKeyID() == key.KeyID { + status = " (current signing key)" + } + fmt.Printf(" - %s: %s%s\n", key.KeyID, key.State, status) + } + + // Show JWKS content + jwks := km.GetJWKS() + fmt.Printf("\nKeys published in JWKS: %d\n", len(jwks.Keys)) + for _, jwk := range jwks.Keys { + fmt.Printf(" - %s (%s, %s)\n", jwk.Kid, jwk.Kty, jwk.Alg) + } +} + +func demonstrateTokenOperations(km *security.KeyManager) { + // Create token configuration + config := security.TokenConfig{ + PrivateKey: km.GetPrivateKey(), + PublicKey: km.GetPublicKey(), + KeyID: km.GetKeyID(), + Issuer: "identra-example", + AccessTokenExpiration: 15 * time.Minute, + RefreshTokenExpiration: 7 * 24 * time.Hour, + } + + // Generate token pair + userID := "user-12345" + tokenPair, err := security.NewTokenPair(userID, config) + if err != nil { + log.Fatalf("Failed to create token pair: %v", err) + } + + fmt.Printf("\nGenerated token pair for user: %s\n", userID) + fmt.Printf("Access token expires at: %v\n", time.Unix(tokenPair.AccessToken.ExpiresAt, 0)) + fmt.Printf("Refresh token expires at: %v\n", time.Unix(tokenPair.RefreshToken.ExpiresAt, 0)) + + // Validate access token + claims, err := security.ValidateAccessToken(tokenPair.AccessToken.Token, config.PublicKey) + if err != nil { + log.Fatalf("Failed to validate access token: %v", err) + } + + fmt.Printf("\nValidated access token:\n") + fmt.Printf(" User ID: %s\n", claims.UserID) + fmt.Printf(" Token Type: %s\n", claims.TokenType) + fmt.Printf(" Token ID: %s\n", claims.TokenID) + fmt.Printf(" Issuer: %s\n", claims.Issuer) +} diff --git a/internal/infrastructure/configs/keys.go b/internal/infrastructure/configs/keys.go index 0832636..c231d2b 100644 --- a/internal/infrastructure/configs/keys.go +++ b/internal/infrastructure/configs/keys.go @@ -7,14 +7,14 @@ const ( HTTPPortKey = "http_port" // Auth configuration keys - AuthRSAPrivateKeyKey = "auth.rsa_private_key" - AuthOAuthStateExpirationKey = 
"auth.oauth_state_expiration" - AuthAccessTokenExpirationKey = "auth.access_token_expiration" - AuthRefreshTokenExpirationKey = "auth.refresh_token_expiration" - AuthTokenIssuerKey = "auth.token_issuer" + AuthRSAPrivateKeyKey = "auth.rsa_private_key" + AuthOAuthStateExpirationKey = "auth.oauth_state_expiration" + AuthAccessTokenExpirationKey = "auth.access_token_expiration" + AuthRefreshTokenExpirationKey = "auth.refresh_token_expiration" + AuthTokenIssuerKey = "auth.token_issuer" AuthOAuthFetchEmailIfMissingKey = "auth.oauth.fetch_email_if_missing" - AuthGithubClientIDKey = "auth.github.client_id" - AuthGithubClientSecretKey = "auth.github.client_secret" + AuthGithubClientIDKey = "auth.github.client_id" + AuthGithubClientSecretKey = "auth.github.client_secret" // Persistence configuration keys PersistenceTypeKey = "persistence.type" diff --git a/internal/infrastructure/security/jkws.go b/internal/infrastructure/security/jkws.go deleted file mode 100644 index 3f0ff71..0000000 --- a/internal/infrastructure/security/jkws.go +++ /dev/null @@ -1,210 +0,0 @@ -package security - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/x509" - "encoding/base64" - "encoding/pem" - "fmt" - "math/big" - "sync" - - identra_v1_pb "github.com/poly-workshop/identra/gen/go/identra/v1" -) - -const ( - // RSAKeySize is the size of RSA keys in bits - RSAKeySize = 2048 - // KeyAlgorithm is the algorithm used for signing - KeyAlgorithm = "RS256" - // KeyUsage indicates the key is used for signing - KeyUsage = "sig" -) - -// KeyManager manages RSA key pairs for JWT signing and verification -// and can expose them as a JWKS document. -type KeyManager struct { - privateKey *rsa.PrivateKey - publicKey *rsa.PublicKey - keyID string - mu sync.RWMutex -} - -var ( - globalKeyManager *KeyManager - keyManagerOnce sync.Once -) - -// GetKeyManager returns the global KeyManager instance. 
-func GetKeyManager() *KeyManager { - keyManagerOnce.Do(func() { - globalKeyManager = &KeyManager{} - }) - return globalKeyManager -} - -// InitializeFromPEM initializes the key manager from a PEM-encoded private key. -func (km *KeyManager) InitializeFromPEM(privateKeyPEM string) error { - km.mu.Lock() - defer km.mu.Unlock() - - block, _ := pem.Decode([]byte(privateKeyPEM)) - if block == nil { - return fmt.Errorf("failed to decode PEM block") - } - - var privateKey *rsa.PrivateKey - var err error - - switch block.Type { - case "RSA PRIVATE KEY": - privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) - case "PRIVATE KEY": - key, parseErr := x509.ParsePKCS8PrivateKey(block.Bytes) - if parseErr != nil { - return fmt.Errorf("failed to parse PKCS8 private key: %w", parseErr) - } - var ok bool - privateKey, ok = key.(*rsa.PrivateKey) - if !ok { - return fmt.Errorf("private key is not RSA") - } - default: - return fmt.Errorf("unsupported PEM block type: %s", block.Type) - } - - if err != nil { - return fmt.Errorf("failed to parse private key: %w", err) - } - - km.privateKey = privateKey - km.publicKey = &privateKey.PublicKey - km.keyID = km.generateKeyID() - - return nil -} - -// GenerateKeyPair generates a new RSA key pair. -func (km *KeyManager) GenerateKeyPair() error { - km.mu.Lock() - defer km.mu.Unlock() - - privateKey, err := rsa.GenerateKey(rand.Reader, RSAKeySize) - if err != nil { - return fmt.Errorf("failed to generate RSA key pair: %w", err) - } - - km.privateKey = privateKey - km.publicKey = &privateKey.PublicKey - km.keyID = km.generateKeyID() - - return nil -} - -// generateKeyID creates a unique key ID based on the public key. -func (km *KeyManager) generateKeyID() string { - if km.publicKey == nil { - return "" - } - - hash := sha256.Sum256(km.publicKey.N.Bytes()) - return base64.RawURLEncoding.EncodeToString(hash[:8]) -} - -// GetPrivateKey returns the RSA private key for signing. 
-func (km *KeyManager) GetPrivateKey() *rsa.PrivateKey { - km.mu.RLock() - defer km.mu.RUnlock() - return km.privateKey -} - -// GetPublicKey returns the RSA public key for verification. -func (km *KeyManager) GetPublicKey() *rsa.PublicKey { - km.mu.RLock() - defer km.mu.RUnlock() - return km.publicKey -} - -// GetKeyID returns the key ID. -func (km *KeyManager) GetKeyID() string { - km.mu.RLock() - defer km.mu.RUnlock() - return km.keyID -} - -// IsInitialized checks if the key manager has been initialized. -func (km *KeyManager) IsInitialized() bool { - km.mu.RLock() - defer km.mu.RUnlock() - return km.privateKey != nil -} - -// GetJWKS returns the JSON Web Key Set containing the public key. -func (km *KeyManager) GetJWKS() *identra_v1_pb.GetJWKSResponse { - km.mu.RLock() - defer km.mu.RUnlock() - - if km.publicKey == nil { - return &identra_v1_pb.GetJWKSResponse{ - Keys: []*identra_v1_pb.JSONWebKey{}, - } - } - - n := base64.RawURLEncoding.EncodeToString(km.publicKey.N.Bytes()) - e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(km.publicKey.E)).Bytes()) - - return &identra_v1_pb.GetJWKSResponse{ - Keys: []*identra_v1_pb.JSONWebKey{ - { - Kty: "RSA", - Alg: KeyAlgorithm, - Use: KeyUsage, - Kid: km.keyID, - N: &n, - E: &e, - }, - }, - } -} - -// ExportPrivateKeyPEM exports the private key in PEM format. -func (km *KeyManager) ExportPrivateKeyPEM() (string, error) { - km.mu.RLock() - defer km.mu.RUnlock() - - if km.privateKey == nil { - return "", fmt.Errorf("no private key available") - } - - privateKeyBytes := x509.MarshalPKCS1PrivateKey(km.privateKey) - privateKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: privateKeyBytes, - }) - - return string(privateKeyPEM), nil -} - -// ExportPublicKeyPEM exports the public key in PEM format. 
-func (km *KeyManager) ExportPublicKeyPEM() (string, error) { - km.mu.RLock() - defer km.mu.RUnlock() - - if km.publicKey == nil { - return "", fmt.Errorf("no public key available") - } - - publicKeyBytes, err := x509.MarshalPKIXPublicKey(km.publicKey) - if err != nil { - return "", fmt.Errorf("failed to marshal public key: %w", err) - } - - publicKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "PUBLIC KEY", - Bytes: publicKeyBytes, - }) - - return string(publicKeyPEM), nil -} diff --git a/internal/infrastructure/security/jwks.go b/internal/infrastructure/security/jwks.go new file mode 100644 index 0000000..ebc4420 --- /dev/null +++ b/internal/infrastructure/security/jwks.go @@ -0,0 +1,435 @@ +package security + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "math/big" + "sort" + "sync" + + identra_v1_pb "github.com/poly-workshop/identra/gen/go/identra/v1" +) + +const ( + // RSAKeySize is the size of RSA keys in bits + RSAKeySize = 2048 + // KeyAlgorithm is the algorithm used for signing + KeyAlgorithm = "RS256" + // KeyUsage indicates the key is used for signing + KeyUsage = "sig" +) + +// KeyState represents the lifecycle state of a signing key +type KeyState string + +const ( + // KeyStateActive indicates the key is currently used for signing new tokens + KeyStateActive KeyState = "ACTIVE" + // KeyStatePassive indicates the key is published in JWKS for verification but not used for signing + KeyStatePassive KeyState = "PASSIVE" + // KeyStateRetired indicates the key is no longer published and should be removed + KeyStateRetired KeyState = "RETIRED" +) + +// KeyEntry represents a single key in the key ring with its lifecycle state +type KeyEntry struct { + privateKey *rsa.PrivateKey + publicKey *rsa.PublicKey + keyID string + state KeyState +} + +// KeyManager manages RSA key pairs for JWT signing and verification +// with support for key rotation. 
It maintains a key ring where: +// - Exactly one key is ACTIVE for signing new tokens +// - Zero or more keys are PASSIVE for verification only +// - RETIRED keys are removed from the ring +type KeyManager struct { + // Legacy single-key fields for backwards compatibility + privateKey *rsa.PrivateKey + publicKey *rsa.PublicKey + keyID string + + // Key ring for rotation support + keys map[string]*KeyEntry // keyed by keyID + mu sync.RWMutex +} + +var ( + globalKeyManager *KeyManager + keyManagerOnce sync.Once +) + +// GetKeyManager returns the global KeyManager instance. +func GetKeyManager() *KeyManager { + keyManagerOnce.Do(func() { + globalKeyManager = &KeyManager{ + keys: make(map[string]*KeyEntry), + } + }) + return globalKeyManager +} + +// InitializeFromPEM initializes the key manager from a PEM-encoded private key. +// The key is added to the key ring in ACTIVE state. +// If an ACTIVE key already exists, it is demoted to PASSIVE. +func (km *KeyManager) InitializeFromPEM(privateKeyPEM string) error { + km.mu.Lock() + defer km.mu.Unlock() + + // Initialize keys map if not already done + if km.keys == nil { + km.keys = make(map[string]*KeyEntry) + } + + block, _ := pem.Decode([]byte(privateKeyPEM)) + if block == nil { + return fmt.Errorf("failed to decode PEM block") + } + + var privateKey *rsa.PrivateKey + var err error + + switch block.Type { + case "RSA PRIVATE KEY": + privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + case "PRIVATE KEY": + key, parseErr := x509.ParsePKCS8PrivateKey(block.Bytes) + if parseErr != nil { + return fmt.Errorf("failed to parse PKCS8 private key: %w", parseErr) + } + var ok bool + privateKey, ok = key.(*rsa.PrivateKey) + if !ok { + return fmt.Errorf("private key is not RSA") + } + default: + return fmt.Errorf("unsupported PEM block type: %s", block.Type) + } + + if err != nil { + return fmt.Errorf("failed to parse private key: %w", err) + } + + publicKey := &privateKey.PublicKey + keyID := generateKeyID(publicKey) + + // 
Demote any existing ACTIVE key to PASSIVE + for _, entry := range km.keys { + if entry.state == KeyStateActive { + entry.state = KeyStatePassive + } + } + + // Add to key ring as ACTIVE + km.keys[keyID] = &KeyEntry{ + privateKey: privateKey, + publicKey: publicKey, + keyID: keyID, + state: KeyStateActive, + } + + // Maintain backwards compatibility + km.privateKey = privateKey + km.publicKey = publicKey + km.keyID = keyID + + return nil +} + +// GenerateKeyPair generates a new RSA key pair and adds it to the key ring in ACTIVE state. +// If an ACTIVE key already exists, it is demoted to PASSIVE. +func (km *KeyManager) GenerateKeyPair() error { + km.mu.Lock() + defer km.mu.Unlock() + + // Initialize keys map if not already done + if km.keys == nil { + km.keys = make(map[string]*KeyEntry) + } + + privateKey, err := rsa.GenerateKey(rand.Reader, RSAKeySize) + if err != nil { + return fmt.Errorf("failed to generate RSA key pair: %w", err) + } + + publicKey := &privateKey.PublicKey + keyID := generateKeyID(publicKey) + + // Demote any existing ACTIVE key to PASSIVE + for _, entry := range km.keys { + if entry.state == KeyStateActive { + entry.state = KeyStatePassive + } + } + + // Add to key ring as ACTIVE + km.keys[keyID] = &KeyEntry{ + privateKey: privateKey, + publicKey: publicKey, + keyID: keyID, + state: KeyStateActive, + } + + // Maintain backwards compatibility + km.privateKey = privateKey + km.publicKey = publicKey + km.keyID = keyID + + return nil +} + +// generateKeyID creates a unique key ID based on the public key. +func generateKeyID(publicKey *rsa.PublicKey) string { + if publicKey == nil { + return "" + } + + hash := sha256.Sum256(publicKey.N.Bytes()) + return base64.RawURLEncoding.EncodeToString(hash[:8]) +} + +// GetPrivateKey returns the RSA private key for signing. +func (km *KeyManager) GetPrivateKey() *rsa.PrivateKey { + km.mu.RLock() + defer km.mu.RUnlock() + return km.privateKey +} + +// GetPublicKey returns the RSA public key for verification. 
+func (km *KeyManager) GetPublicKey() *rsa.PublicKey { + km.mu.RLock() + defer km.mu.RUnlock() + return km.publicKey +} + +// GetKeyID returns the key ID. +func (km *KeyManager) GetKeyID() string { + km.mu.RLock() + defer km.mu.RUnlock() + return km.keyID +} + +// IsInitialized checks if the key manager has been initialized. +func (km *KeyManager) IsInitialized() bool { + km.mu.RLock() + defer km.mu.RUnlock() + return km.privateKey != nil +} + +// GetJWKS returns the JSON Web Key Set containing all ACTIVE and PASSIVE public keys. +// This enables smooth key rotation as both old and new keys are published during the transition. +// Keys are sorted by KeyID to ensure deterministic output and stable ETags. +func (km *KeyManager) GetJWKS() *identra_v1_pb.GetJWKSResponse { + km.mu.RLock() + defer km.mu.RUnlock() + + var keys []*identra_v1_pb.JSONWebKey + + // Collect key IDs first to enable sorting for deterministic output + var keyIDs []string + for keyID, entry := range km.keys { + if entry.state == KeyStateActive || entry.state == KeyStatePassive { + keyIDs = append(keyIDs, keyID) + } + } + + // Sort key IDs to ensure deterministic order + sort.Strings(keyIDs) + + // Build the JWKS response in sorted order + for _, keyID := range keyIDs { + entry := km.keys[keyID] + n := base64.RawURLEncoding.EncodeToString(entry.publicKey.N.Bytes()) + e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(entry.publicKey.E)).Bytes()) + + keys = append(keys, &identra_v1_pb.JSONWebKey{ + Kty: "RSA", + Alg: KeyAlgorithm, + Use: KeyUsage, + Kid: entry.keyID, + N: &n, + E: &e, + }) + } + + return &identra_v1_pb.GetJWKSResponse{ + Keys: keys, + } +} + +// AddKeyPassive adds a new key to the key ring in PASSIVE state. +// This allows the key to be published in JWKS for verification before promoting it to ACTIVE. 
+func (km *KeyManager) AddKeyPassive() (string, error) {
+	// Generate the key before taking the write lock: RSA key generation is
+	// comparatively slow, and holding km.mu here would stall every reader
+	// (GetJWKS, GetPrivateKey, ...) for its whole duration.
+	privateKey, err := rsa.GenerateKey(rand.Reader, RSAKeySize)
+	if err != nil {
+		return "", fmt.Errorf("failed to generate RSA key pair: %w", err)
+	}
+	publicKey := &privateKey.PublicKey
+	keyID := generateKeyID(publicKey)
+
+	km.mu.Lock()
+	defer km.mu.Unlock()
+
+	// A zero-value KeyManager has a nil map; writing to it would panic.
+	if km.keys == nil {
+		km.keys = make(map[string]*KeyEntry)
+	}
+	km.keys[keyID] = &KeyEntry{
+		privateKey: privateKey,
+		publicKey:  publicKey,
+		keyID:      keyID,
+		state:      KeyStatePassive,
+	}
+
+	return keyID, nil
+}
+
+// PromoteKey makes the PASSIVE key keyID the ACTIVE signing key and demotes whichever key
+// was ACTIVE to PASSIVE. It returns an error if keyID is unknown or not PASSIVE.
+func (km *KeyManager) PromoteKey(keyID string) error {
+	km.mu.Lock()
+	defer km.mu.Unlock()
+
+	entry, exists := km.keys[keyID]
+	if !exists {
+		return fmt.Errorf("key not found: %s", keyID)
+	}
+	if entry.state != KeyStatePassive {
+		return fmt.Errorf("key %s is not in PASSIVE state (current: %s)", keyID, entry.state)
+	}
+
+	// Demote current ACTIVE key to PASSIVE
+	for _, e := range km.keys {
+		if e.state == KeyStateActive {
+			e.state = KeyStatePassive
+		}
+	}
+
+	// Promote the specified key to ACTIVE
+	entry.state = KeyStateActive
+
+	// Update backwards compatibility fields
+	km.privateKey = entry.privateKey
+	km.publicKey = entry.publicKey
+	km.keyID = entry.keyID
+
+	return nil
+}
+
+// DemoteKey demotes an ACTIVE key to PASSIVE state.
+// Use this if you need to temporarily stop signing with a key while keeping it for verification.
+func (km *KeyManager) DemoteKey(keyID string) error { + km.mu.Lock() + defer km.mu.Unlock() + + entry, exists := km.keys[keyID] + if !exists { + return fmt.Errorf("key not found: %s", keyID) + } + + if entry.state != KeyStateActive { + return fmt.Errorf("key %s is not in ACTIVE state (current: %s)", keyID, entry.state) + } + + entry.state = KeyStatePassive + + // Clear backwards compatibility fields if this was the active key + if km.keyID == keyID { + km.privateKey = nil + km.publicKey = nil + km.keyID = "" + } + + return nil +} + +// RetireKey removes a key from the key ring. +// Only PASSIVE keys can be retired. ACTIVE keys must be demoted first. +func (km *KeyManager) RetireKey(keyID string) error { + km.mu.Lock() + defer km.mu.Unlock() + + entry, exists := km.keys[keyID] + if !exists { + return fmt.Errorf("key not found: %s", keyID) + } + + if entry.state == KeyStateActive { + return fmt.Errorf("cannot retire ACTIVE key %s; demote it first", keyID) + } + + delete(km.keys, keyID) + return nil +} + +// KeyInfo contains information about a key in the key ring. +type KeyInfo struct { + KeyID string + State KeyState +} + +// ListKeys returns information about all keys in the key ring. +func (km *KeyManager) ListKeys() []KeyInfo { + km.mu.RLock() + defer km.mu.RUnlock() + + result := make([]KeyInfo, 0, len(km.keys)) + + for _, entry := range km.keys { + result = append(result, KeyInfo{ + KeyID: entry.keyID, + State: entry.state, + }) + } + + return result +} + +// ExportPrivateKeyPEM exports the private key in PEM format. 
+func (km *KeyManager) ExportPrivateKeyPEM() (string, error) {
+	km.mu.RLock()
+	defer km.mu.RUnlock()
+
+	if km.privateKey == nil {
+		return "", fmt.Errorf("no private key available")
+	}
+
+	privateKeyBytes := x509.MarshalPKCS1PrivateKey(km.privateKey)
+	privateKeyPEM := pem.EncodeToMemory(&pem.Block{
+		Type:  "RSA PRIVATE KEY",
+		Bytes: privateKeyBytes,
+	})
+
+	return string(privateKeyPEM), nil
+}
+
+// ExportPublicKeyPEM returns km.publicKey as a PKIX "PUBLIC KEY" PEM block, or an error if it is unset or cannot be marshalled.
+func (km *KeyManager) ExportPublicKeyPEM() (string, error) {
+	km.mu.RLock()
+	defer km.mu.RUnlock()
+
+	if km.publicKey == nil {
+		return "", fmt.Errorf("no public key available")
+	}
+
+	publicKeyBytes, err := x509.MarshalPKIXPublicKey(km.publicKey)
+	if err != nil {
+		return "", fmt.Errorf("failed to marshal public key: %w", err)
+	}
+
+	publicKeyPEM := pem.EncodeToMemory(&pem.Block{
+		Type:  "PUBLIC KEY",
+		Bytes: publicKeyBytes,
+	})
+
+	return string(publicKeyPEM), nil
+}
diff --git a/internal/infrastructure/security/tokens_test.go b/internal/infrastructure/security/tokens_test.go
index b8ea6cf..85d1eba 100644
--- a/internal/infrastructure/security/tokens_test.go
+++ b/internal/infrastructure/security/tokens_test.go
@@ -300,3 +300,416 @@ func TestKeyManager(t *testing.T) {
 		t.Error("Expected new key manager to be initialized from PEM")
 	}
 }
+
+func TestKeyRotation(t *testing.T) {
+	km := &KeyManager{}
+
+	// Initialize with first key
+	if err := km.GenerateKeyPair(); err != nil {
+		t.Fatalf("Failed to generate initial key pair: %v", err)
+	}
+
+	firstKeyID := km.GetKeyID()
+	if firstKeyID == "" {
+		t.Error("Expected initial key ID to be set")
+	}
+
+	// Verify only one key in JWKS (fatal: the next check indexes Keys[0])
+	jwks := km.GetJWKS()
+	if len(jwks.Keys) != 1 {
+		t.Fatalf("Expected 1 key in JWKS, got %d", len(jwks.Keys))
+	}
+	if jwks.Keys[0].Kid != firstKeyID {
+		t.Errorf("Expected key ID %s, got %s", firstKeyID, jwks.Keys[0].Kid)
+	}
+
+	// Add a new key in PASSIVE state
+	secondKeyID, err := km.AddKeyPassive()
+	if err != nil {
+		t.Fatalf("Failed to add passive key: %v", err)
+	}
+	if secondKeyID == "" {
+		t.Error("Expected second key ID to be set")
+	}
+	if secondKeyID == firstKeyID {
+		t.Error("Expected second key ID to be different from first")
+	}
+
+	// Verify both keys are in JWKS
+	jwks = km.GetJWKS()
+	if len(jwks.Keys) != 2 {
+		t.Errorf("Expected 2 keys in JWKS after adding passive key, got %d", len(jwks.Keys))
+	}
+
+	// Verify active key is still the first one
+	if km.GetKeyID() != firstKeyID {
+		t.Errorf("Expected active key to still be %s, got %s", firstKeyID, km.GetKeyID())
+	}
+
+	// Promote the second key to ACTIVE
+	if err := km.PromoteKey(secondKeyID); err != nil {
+		t.Fatalf("Failed to promote key: %v", err)
+	}
+
+	// Verify active key is now the second one
+	if km.GetKeyID() != secondKeyID {
+		t.Errorf("Expected active key to be %s after promotion, got %s", secondKeyID, km.GetKeyID())
+	}
+
+	// Verify both keys are still in JWKS (old key should be PASSIVE now)
+	jwks = km.GetJWKS()
+	if len(jwks.Keys) != 2 {
+		t.Errorf("Expected 2 keys in JWKS after promotion, got %d", len(jwks.Keys))
+	}
+
+	// Retire the first key
+	if err := km.RetireKey(firstKeyID); err != nil {
+		t.Fatalf("Failed to retire key: %v", err)
+	}
+
+	// Verify only the second key is in JWKS (fatal: the next check indexes Keys[0])
+	jwks = km.GetJWKS()
+	if len(jwks.Keys) != 1 {
+		t.Fatalf("Expected 1 key in JWKS after retiring first key, got %d", len(jwks.Keys))
+	}
+	if jwks.Keys[0].Kid != secondKeyID {
+		t.Errorf("Expected remaining key to be %s, got %s", secondKeyID, jwks.Keys[0].Kid)
+	}
+}
+
+func TestKeyRotationWithTokenValidation(t *testing.T) {
+	km := &KeyManager{}
+
+	// Initialize with first key
+	if err := km.GenerateKeyPair(); err != nil {
+		t.Fatalf("Failed to generate initial key pair: %v", err)
+	}
+
+	// Create a token with the first key
+	userID := uuid.New().String()
+	config := TokenConfig{
+		PrivateKey:             km.GetPrivateKey(),
+		PublicKey:              km.GetPublicKey(),
+		KeyID:                  km.GetKeyID(),
+		Issuer:                 "test-issuer",
+		AccessTokenExpiration:  15 * time.Minute,
+		RefreshTokenExpiration: 7 * 24 * time.Hour,
+	}
+
+	tokenPair, err := NewTokenPair(userID, config)
+	if err != nil {
+		t.Fatalf("Failed to create token pair: %v", err)
+	}
+
+	// Verify token can be validated with first key
+	claims, err := ValidateAccessToken(tokenPair.AccessToken.Token, km.GetPublicKey())
+	if err != nil {
+		t.Fatalf("Failed to validate token with first key: %v", err)
+	}
+	if claims.UserID != userID {
+		t.Errorf("Expected user ID %s, got %s", userID, claims.UserID)
+	}
+
+	// Add a new key in PASSIVE state
+	secondKeyID, err := km.AddKeyPassive()
+	if err != nil {
+		t.Fatalf("Failed to add passive key: %v", err)
+	}
+
+	// Promote the second key to ACTIVE
+	if err := km.PromoteKey(secondKeyID); err != nil {
+		t.Fatalf("Failed to promote key: %v", err)
+	}
+
+	// Token signed with first key should still be valid (key is now PASSIVE)
+	// In a real scenario, we would extract kid from JWT header and use it to look up the key
+	// For this test, we just verify that both keys are still in JWKS
+	jwks := km.GetJWKS()
+	if len(jwks.Keys) != 2 {
+		t.Errorf("Expected 2 keys in JWKS after promotion, got %d", len(jwks.Keys))
+	}
+
+	// Create a new token with the second (now active) key
+	config.PrivateKey = km.GetPrivateKey()
+	config.PublicKey = km.GetPublicKey()
+	config.KeyID = km.GetKeyID()
+
+	newTokenPair, err := NewTokenPair(userID, config)
+	if err != nil {
+		t.Fatalf("Failed to create token pair with second key: %v", err)
+	}
+
+	// New token should be valid with second key
+	newClaims, err := ValidateAccessToken(newTokenPair.AccessToken.Token, km.GetPublicKey())
+	if err != nil {
+		t.Fatalf("Failed to validate token with second key: %v", err)
+	}
+	if newClaims.UserID != userID {
+		t.Errorf("Expected user ID %s in new token, got %s", userID, newClaims.UserID)
+	}
+}
+
+func TestKeyLifecycleErrors(t *testing.T) {
+	km := &KeyManager{}
+
+	// Try to promote a non-existent key
+	if err := km.PromoteKey("nonexistent"); err == nil {
+		t.Error("Expected error when promoting non-existent key")
+	}
+
+	// Try to retire a non-existent key
+	if err := km.RetireKey("nonexistent"); err == nil {
+		t.Error("Expected error when retiring non-existent key")
+	}
+
+	// Generate initial key
+	if err := km.GenerateKeyPair(); err != nil {
+		t.Fatalf("Failed to generate key pair: %v", err)
+	}
+	activeKeyID := km.GetKeyID()
+
+	// Try to promote an already ACTIVE key
+	if err := km.PromoteKey(activeKeyID); err == nil {
+		t.Error("Expected error when promoting already ACTIVE key")
+	}
+
+	// Try to retire an ACTIVE key
+	if err := km.RetireKey(activeKeyID); err == nil {
+		t.Error("Expected error when retiring ACTIVE key")
+	}
+
+	// Add passive key
+	passiveKeyID, err := km.AddKeyPassive()
+	if err != nil {
+		t.Fatalf("Failed to add passive key: %v", err)
+	}
+
+	// Retire passive key should succeed
+	if err := km.RetireKey(passiveKeyID); err != nil {
+		t.Errorf("Failed to retire passive key: %v", err)
+	}
+}
+
+func TestListKeys(t *testing.T) {
+	km := &KeyManager{}
+
+	// Initially empty
+	keys := km.ListKeys()
+	if len(keys) != 0 {
+		t.Errorf("Expected 0 keys initially, got %d", len(keys))
+	}
+
+	// Add first key
+	if err := km.GenerateKeyPair(); err != nil {
+		t.Fatalf("Failed to generate key pair: %v", err)
+	}
+
+	keys = km.ListKeys()
+	if len(keys) != 1 {
+		t.Fatalf("Expected 1 key after generation, got %d", len(keys))
+	}
+	if keys[0].State != KeyStateActive {
+		t.Errorf("Expected first key to be ACTIVE, got %s", keys[0].State)
+	}
+
+	// Add passive key
+	passiveKeyID, err := km.AddKeyPassive()
+	if err != nil {
+		t.Fatalf("Failed to add passive key: %v", err)
+	}
+
+	keys = km.ListKeys()
+	if len(keys) != 2 {
+		t.Errorf("Expected 2 keys after adding passive key, got %d", len(keys))
+	}
+
+	// Verify states
+	activeCount := 0
+	passiveCount := 0
+	for _, k := range keys {
+		switch k.State {
+		case KeyStateActive:
+			activeCount++
+		case KeyStatePassive:
+			passiveCount++
+		}
+	}
+	if activeCount != 1 {
+		t.Errorf("Expected 1 ACTIVE key, got %d", activeCount)
+	}
+	if passiveCount != 1 {
+		t.Errorf("Expected 1 PASSIVE key, got %d", passiveCount)
+	}
+
+	// Promote passive key
+	if err := km.PromoteKey(passiveKeyID); err != nil {
+		t.Fatalf("Failed to promote key: %v", err)
+	}
+
+	keys = km.ListKeys()
+	activeCount = 0
+	passiveCount = 0
+	for _, k := range keys {
+		switch k.State {
+		case KeyStateActive:
+			activeCount++
+		case KeyStatePassive:
+			passiveCount++
+		}
+	}
+	if activeCount != 1 {
+		t.Errorf("Expected 1 ACTIVE key after promotion, got %d", activeCount)
+	}
+	if passiveCount != 1 {
+		t.Errorf("Expected 1 PASSIVE key after promotion, got %d", passiveCount)
+	}
+}
+
+func TestMultipleActiveKeyPrevention(t *testing.T) {
+	km := &KeyManager{}
+
+	// Generate first key
+	if err := km.GenerateKeyPair(); err != nil {
+		t.Fatalf("Failed to generate first key: %v", err)
+	}
+	firstKeyID := km.GetKeyID()
+
+	// Generate second key - should demote first
+	if err := km.GenerateKeyPair(); err != nil {
+		t.Fatalf("Failed to generate second key: %v", err)
+	}
+	secondKeyID := km.GetKeyID()
+
+	// Verify only one ACTIVE key
+	keys := km.ListKeys()
+	activeCount := 0
+	for _, k := range keys {
+		if k.State == KeyStateActive {
+			activeCount++
+			if k.KeyID != secondKeyID {
+				t.Errorf("Expected ACTIVE key to be %s, got %s", secondKeyID, k.KeyID)
+			}
+		}
+	}
+	if activeCount != 1 {
+		t.Errorf("Expected exactly 1 ACTIVE key, got %d", activeCount)
+	}
+
+	// Verify first key was demoted to PASSIVE
+	found := false
+	for _, k := range keys {
+		if k.KeyID == firstKeyID {
+			found = true
+			if k.State != KeyStatePassive {
+				t.Errorf("Expected first key to be PASSIVE, got %s", k.State)
+			}
+		}
+	}
+	if !found {
+		t.Error("First key not found in key ring")
+	}
+}
+
+func TestInitializeFromPEMWithExistingKey(t *testing.T) {
+	km := &KeyManager{}
+
+	// Generate initial key
+	if err := km.GenerateKeyPair(); err != nil {
+		t.Fatalf("Failed to generate initial key: %v", err)
+	}
+	firstKeyID := km.GetKeyID()
+
+	// Export and re-import (simulating loading from config)
+	pemData, err := km.ExportPrivateKeyPEM()
+	if err != nil {
+		t.Fatalf("Failed to export PEM: %v", err)
+	}
+
+	// Create a new key and then initialize from PEM
+	newKm := &KeyManager{}
+	if err := newKm.GenerateKeyPair(); err != nil {
+		t.Fatalf("Failed to generate key in new manager: %v", err)
+	}
+	tempKeyID := newKm.GetKeyID()
+
+	// Initialize from PEM - should demote the temp key
+	if err := newKm.InitializeFromPEM(pemData); err != nil {
+		t.Fatalf("Failed to initialize from PEM: %v", err)
+	}
+
+	// Verify the PEM key is now ACTIVE
+	if newKm.GetKeyID() != firstKeyID {
+		t.Errorf("Expected ACTIVE key to be from PEM (%s), got %s", firstKeyID, newKm.GetKeyID())
+	}
+
+	// Verify temp key was demoted
+	keys := newKm.ListKeys()
+	activeCount := 0
+	for _, k := range keys {
+		if k.State == KeyStateActive {
+			activeCount++
+		}
+		if k.KeyID == tempKeyID && k.State != KeyStatePassive {
+			t.Errorf("Expected temp key to be PASSIVE, got %s", k.State)
+		}
+	}
+	if activeCount != 1 {
+		t.Errorf("Expected exactly 1 ACTIVE key after PEM init, got %d", activeCount)
+	}
+}
+
+func TestJWKSDeterministicOrder(t *testing.T) {
+	km := &KeyManager{}
+
+	// Add several keys so the JWKS holds more than one entry to order
+	if err := km.GenerateKeyPair(); err != nil {
+		t.Fatalf("Failed to generate first key: %v", err)
+	}
+
+	keyID2, err := km.AddKeyPassive()
+	if err != nil {
+		t.Fatalf("Failed to add second passive key: %v", err)
+	}
+
+	keyID3, err := km.AddKeyPassive()
+	if err != nil {
+		t.Fatalf("Failed to add third passive key: %v", err)
+	}
+
+	// Get JWKS multiple times and verify order is consistent
+	jwks1 := km.GetJWKS()
+	jwks2 := km.GetJWKS()
+	jwks3 := km.GetJWKS()
+
+	// Verify all responses have the same number of keys (fatal: the loop below indexes all three)
+	if len(jwks1.Keys) != len(jwks2.Keys) || len(jwks1.Keys) != len(jwks3.Keys) {
+		t.Fatal("Inconsistent number of keys across JWKS responses")
+	}
+
+	// Verify the order is the same in all responses
+	for i := range jwks1.Keys {
+		if jwks1.Keys[i].Kid != jwks2.Keys[i].Kid {
+			t.Errorf("Key order differs between responses 1 and 2 at index %d: %s vs %s",
+				i, jwks1.Keys[i].Kid, jwks2.Keys[i].Kid)
+		}
+		if jwks1.Keys[i].Kid != jwks3.Keys[i].Kid {
+			t.Errorf("Key order differs between responses 1 and 3 at index %d: %s vs %s",
+				i, jwks1.Keys[i].Kid, jwks3.Keys[i].Kid)
+		}
+	}
+
+	// Verify keys are sorted (alphanumeric order)
+	for i := 1; i < len(jwks1.Keys); i++ {
+		if jwks1.Keys[i-1].Kid > jwks1.Keys[i].Kid {
+			t.Errorf("Keys are not sorted: %s > %s", jwks1.Keys[i-1].Kid, jwks1.Keys[i].Kid)
+		}
+	}
+
+	// Clean up - retire the passive keys
+	if err := km.RetireKey(keyID2); err != nil {
+		t.Fatalf("Failed to retire key: %v", err)
+	}
+	if err := km.RetireKey(keyID3); err != nil {
+		t.Fatalf("Failed to retire key: %v", err)
+	}
+}