diff --git a/internal/database/token.go b/internal/database/token.go index e13085ed..352ff83a 100644 --- a/internal/database/token.go +++ b/internal/database/token.go @@ -265,6 +265,8 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error { UPDATE tokens SET request_count = request_count + 1, last_used_at = ? WHERE token = ? + AND is_active = TRUE + AND (expires_at IS NULL OR expires_at > ?) AND ( max_requests IS NULL OR max_requests <= 0 @@ -272,7 +274,7 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error { ) ` - result, err := d.ExecContextRebound(ctx, query, now, tokenID) + result, err := d.ExecContextRebound(ctx, query, now, tokenID, now) if err != nil { return fmt.Errorf("failed to increment token usage: %w", err) } @@ -283,20 +285,36 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error { } if rowsAffected == 0 { - // No rows updated means either the token doesn't exist, or it exists but is already - // at quota. We do a follow-up SELECT to return the correct, semantically meaningful - // error (404 vs 429). If this becomes a hot path, we can explore dialect-specific - // optimizations like UPDATE ... RETURNING. - var requestCount int - var maxRequests sql.NullInt32 - checkQuery := `SELECT request_count, max_requests FROM tokens WHERE token = ?` - err := d.QueryRowContextRebound(ctx, checkQuery, tokenID).Scan(&requestCount, &maxRequests) + // No rows updated means either: + // - token doesn't exist + // - token is inactive + // - token is expired + // - token is already at quota + // + // We do a follow-up SELECT to return a semantically meaningful error. + var ( + isActive bool + expiresAt sql.NullTime + requestCount int + maxRequests sql.NullInt32 + ) + checkQuery := `SELECT is_active, expires_at, request_count, max_requests FROM tokens WHERE token = ?` + err := d.QueryRowContextRebound(ctx, checkQuery, tokenID).Scan(&isActive, &expiresAt, &requestCount, &maxRequests) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrTokenNotFound } return fmt.Errorf("failed to check token usage for %s: %w", obfuscate.ObfuscateTokenGeneric(tokenID), err) } + if !isActive { + return token.ErrTokenInactive + } + if expiresAt.Valid { + exp := expiresAt.Time.UTC() + if now.After(exp) { + return token.ErrTokenExpired + } + } if maxRequests.Valid && maxRequests.Int32 > 0 { if requestCount >= int(maxRequests.Int32) { return token.ErrTokenRateLimit diff --git a/internal/database/token_test.go b/internal/database/token_test.go index dfd0e3ec..e6df7cfc 100644 --- a/internal/database/token_test.go +++ b/internal/database/token_test.go @@ -130,6 +130,32 @@ func TestTokenCRUD(t *testing.T) { t.Fatalf("Expected ErrTokenNotFound, got %v", err) } + // Test IncrementTokenUsage with inactive token + inactive := token1 + inactive.IsActive = false + err = db.UpdateToken(ctx, inactive) + if err != nil { + t.Fatalf("Failed to update token to inactive: %v", err) + } + err = db.IncrementTokenUsage(ctx, token1.Token) + if err != token.ErrTokenInactive { + t.Fatalf("Expected ErrTokenInactive, got %v", err) + } + + // Test IncrementTokenUsage with expired token + expired := token1 + expired.IsActive = true + past := time.Now().Add(-1 * time.Hour) + expired.ExpiresAt = &past + err = db.UpdateToken(ctx, expired) + if err != nil { + t.Fatalf("Failed to update token to expired: %v", err) + } + err = db.IncrementTokenUsage(ctx, token1.Token) + if err != token.ErrTokenExpired { + t.Fatalf("Expected ErrTokenExpired, got %v", err) + } + // Test UpdateToken updatedToken1.IsActive = false updatedToken1.RequestCount = 10 diff --git a/internal/token/cache.go b/internal/token/cache.go index 06ebed97..c10619b6 100644 --- a/internal/token/cache.go +++ b/internal/token/cache.go @@ -141,11 +141,12 @@ func (cv *CachedValidator) ValidateToken(ctx context.Context, tokenID string) (s // ValidateTokenWithTracking validates a token and tracks usage. // // For cached *unlimited* tokens, we keep validation cheap by returning from the cache and enqueueing -// async tracking (when an aggregator is configured). For cached *limited* tokens we defer to the -// underlying validator and invalidate the cache afterwards, to avoid serving stale rate-limit state. +// async tracking (when an aggregator is configured). +// +// For cached *limited* tokens, we still want to avoid an extra DB read on every request. We do this +// by using the cached token metadata (active/expiry/project) and performing a synchronous usage +// increment (which is where max_requests is enforced). func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenID string) (string, error) { - shouldInvalidateAfterTracking := false - cv.cacheMutex.RLock() entry, found := cv.cache[tokenID] cv.cacheMutex.RUnlock() @@ -163,7 +164,7 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI cv.misses++ cv.evictions++ cv.statsMutex.Unlock() - } else if entry.Data.IsValid() { + } else if isCacheableTokenValid(entry.Data) { cv.statsMutex.Lock() cv.hits++ cv.statsMutex.Unlock() @@ -176,11 +177,20 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI return entry.Data.ProjectID, nil } } - } else { - shouldInvalidateAfterTracking = true + } else if sv, ok := cv.validator.(*StandardValidator); ok && sv != nil && sv.store != nil { + // Limited token: enforce max_requests via a synchronous increment, but avoid a DB read. + if err := sv.store.IncrementTokenUsage(ctx, tokenID); err != nil { + // If the token is no longer usable (inactive/expired/quota), invalidate the cache entry + // so we avoid repeatedly hitting the cache and failing the same increment. + if err == ErrTokenRateLimit || err == ErrTokenInactive || err == ErrTokenExpired { + cv.invalidateCache(tokenID) + } + return "", err + } + return entry.Data.ProjectID, nil } } else { - // Should be unreachable (invalid tokens should never be cached), but be defensive. + // Token became invalid (inactive/expired). Be defensive and drop the cache entry. cv.invalidateCache(tokenID) cv.statsMutex.Lock() @@ -195,9 +205,9 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI return "", err } - if shouldInvalidateAfterTracking { - cv.invalidateCache(tokenID) - } + // Populate the cache after successful tracking so subsequent requests can avoid extra DB reads. + // (Only works for StandardValidator; others are safely ignored.) + cv.cacheToken(ctx, tokenID) return projectID, nil } @@ -231,6 +241,16 @@ func (cv *CachedValidator) checkCache(tokenID string) (string, bool) { return "", false } + // In cache but token is no longer valid (inactive/expired) + if !isCacheableTokenValid(entry.Data) { + cv.invalidateCache(tokenID) + cv.statsMutex.Lock() + cv.misses++ + cv.evictions++ + cv.statsMutex.Unlock() + return "", false + } + // In cache and valid cv.statsMutex.Lock() cv.hits++ @@ -239,6 +259,21 @@ func (cv *CachedValidator) checkCache(tokenID string) (string, bool) { return entry.Data.ProjectID, true } +// isCacheableTokenValid determines whether a cached token can be used for authentication. +// +// Important: We intentionally do NOT use RequestCount/MaxRequests here. +// Cache hits are not supposed to count against token quotas (see cache-hit fast path), +// and cached RequestCount is inherently stale under concurrency. +func isCacheableTokenValid(td TokenData) bool { + if !td.IsActive { + return false + } + if IsExpired(td.ExpiresAt) { + return false + } + return true +} + // cacheToken retrieves and caches a token func (cv *CachedValidator) cacheToken(ctx context.Context, tokenID string) { standardValidator, ok := cv.validator.(*StandardValidator) @@ -252,7 +287,7 @@ func (cv *CachedValidator) cacheToken(ctx context.Context, tokenID string) { if err != nil { return } - if !tokenData.IsValid() { + if !isCacheableTokenValid(tokenData) { return } diff --git a/internal/token/cache_test.go b/internal/token/cache_test.go index 270b1a18..a347b55a 100644 --- a/internal/token/cache_test.go +++ b/internal/token/cache_test.go @@ -3,6 +3,7 @@ package token import ( "context" "errors" + "sync" "testing" "time" ) @@ -220,6 +221,166 @@ func TestCachedValidator_ValidateTokenWithTracking(t *testing.T) { } } +type countingStore struct { + mu sync.Mutex + + tokens map[string]TokenData + + getByTokenCalls int + incCalls int +} + +func newCountingStore() *countingStore { + return &countingStore{tokens: make(map[string]TokenData)} +} + +func (s *countingStore) GetTokenByID(ctx context.Context, id string) (TokenData, error) { + s.mu.Lock() + defer s.mu.Unlock() + td, ok := s.tokens[id] + if !ok { + return TokenData{}, ErrTokenNotFound + } + return td, nil +} + +func (s *countingStore) GetTokenByToken(ctx context.Context, tokenString string) (TokenData, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.getByTokenCalls++ + td, ok := s.tokens[tokenString] + if !ok { + return TokenData{}, ErrTokenNotFound + } + return td, nil +} + +func (s *countingStore) IncrementTokenUsage(ctx context.Context, tokenString string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.incCalls++ + + td, ok := s.tokens[tokenString] + if !ok { + return ErrTokenNotFound + } + if !td.IsActive { + return ErrTokenInactive + } + if IsExpired(td.ExpiresAt) { + return ErrTokenExpired + } + if td.MaxRequests != nil && *td.MaxRequests > 0 && td.RequestCount >= *td.MaxRequests { + return ErrTokenRateLimit + } + td.RequestCount++ + now := time.Now() + td.LastUsedAt = &now + s.tokens[tokenString] = td + return nil +} + +func (s *countingStore) CreateToken(ctx context.Context, token TokenData) error { return nil } +func (s *countingStore) UpdateToken(ctx context.Context, token TokenData) error { return nil } +func (s *countingStore) ListTokens(ctx context.Context) ([]TokenData, error) { return nil, nil } +func (s *countingStore) GetTokensByProjectID(ctx context.Context, projectID string) ([]TokenData, error) { + return nil, nil +} + +func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_UsesCacheAndAvoidsExtraReads(t *testing.T) { + ctx := context.Background() + store := newCountingStore() + validator := &StandardValidator{store: store} + cv := NewCachedValidator(validator, CacheOptions{TTL: time.Minute, MaxSize: 10, EnableCleanup: false}) + + now := time.Now() + future := now.Add(1 * time.Hour) + maxReq := 100 + tok, _ := GenerateToken() + + store.tokens[tok] = TokenData{ + Token: tok, + ProjectID: "p1", + ExpiresAt: &future, + IsActive: true, + RequestCount: 0, + MaxRequests: &maxReq, + CreatedAt: now, + } + + // First call: expect underlying validation+tracking and then cache population (may re-read once). + got, err := cv.ValidateTokenWithTracking(ctx, tok) + if err != nil { + t.Fatalf("ValidateTokenWithTracking() error = %v", err) + } + if got != "p1" { + t.Fatalf("ValidateTokenWithTracking() = %q, want %q", got, "p1") + } + + // Second call: should hit cache and only do a usage increment (no GetTokenByToken read). + got, err = cv.ValidateTokenWithTracking(ctx, tok) + if err != nil { + t.Fatalf("ValidateTokenWithTracking() second error = %v", err) + } + if got != "p1" { + t.Fatalf("ValidateTokenWithTracking() second = %q, want %q", got, "p1") + } + + store.mu.Lock() + defer store.mu.Unlock() + + if store.incCalls != 2 { + t.Fatalf("IncrementTokenUsage calls = %d, want 2", store.incCalls) + } + // First call: 2x reads (1x via validateTokenData + 1x via cacheToken population) and 1x write. + // Second call (cache hit): 0x reads and 1x write (only IncrementTokenUsage). + if store.getByTokenCalls != 2 { + t.Fatalf("GetTokenByToken calls = %d, want 2", store.getByTokenCalls) + } +} + +func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_InvalidatesOnRateLimit(t *testing.T) { + ctx := context.Background() + store := newCountingStore() + validator := &StandardValidator{store: store} + cv := NewCachedValidator(validator, CacheOptions{TTL: time.Minute, MaxSize: 10, EnableCleanup: false}) + + now := time.Now() + future := now.Add(1 * time.Hour) + maxReq := 1 + tok, _ := GenerateToken() + + store.tokens[tok] = TokenData{ + Token: tok, + ProjectID: "p1", + ExpiresAt: &future, + IsActive: true, + RequestCount: 0, + MaxRequests: &maxReq, + CreatedAt: now, + } + + // First use should succeed and populate cache (initial validation read + cacheToken read = 2 GetTokenByToken calls). + _, err := cv.ValidateTokenWithTracking(ctx, tok) + if err != nil { + t.Fatalf("ValidateTokenWithTracking() first error = %v", err) + } + + // Second use should hit cache, attempt increment, fail with rate limit, and invalidate cache entry. + _, err = cv.ValidateTokenWithTracking(ctx, tok) + if err == nil || !errors.Is(err, ErrTokenRateLimit) { + t.Fatalf("ValidateTokenWithTracking() second error = %v, want ErrTokenRateLimit", err) + } + + // Cache should be invalidated after the failed increment. + cv.cacheMutex.RLock() + _, ok := cv.cache[tok] + cv.cacheMutex.RUnlock() + if ok { + t.Fatalf("expected cache entry to be invalidated after ErrTokenRateLimit") + } +} + func TestCachedValidator_CacheEviction(t *testing.T) { ctx := context.Background() mockValidator := NewMockValidator() diff --git a/internal/token/token_integration_test.go b/internal/token/token_integration_test.go index d82cba81..e97ac6e3 100644 --- a/internal/token/token_integration_test.go +++ b/internal/token/token_integration_test.go @@ -80,6 +80,16 @@ func (m *MockStore) IncrementTokenUsage(ctx context.Context, tokenID string) err return ErrTokenNotFound } + if !token.IsActive { + return ErrTokenInactive + } + if IsExpired(token.ExpiresAt) { + return ErrTokenExpired + } + if token.MaxRequests != nil && *token.MaxRequests > 0 && token.RequestCount >= *token.MaxRequests { + return ErrTokenRateLimit + } + token.RequestCount++ now := time.Now() token.LastUsedAt = &now