From 417b21feafbb8415bd8341d5583e978fec62a339 Mon Sep 17 00:00:00 2001 From: Manuel Fittko Date: Tue, 6 Jan 2026 21:29:34 +0100 Subject: [PATCH 1/4] [token] Reduce DB reads on tracked validations - Populate token cache after ValidateTokenWithTracking - For cached limited tokens: increment usage without extra DB read - Tighten DB IncrementTokenUsage gating (inactive/expired/quota) and return semantic errors Testing: make test, make lint --- internal/database/token.go | 36 +++++-- internal/database/token_test.go | 26 +++++ internal/token/cache.go | 54 +++++++--- internal/token/cache_test.go | 119 +++++++++++++++++++++++ internal/token/token_integration_test.go | 10 ++ 5 files changed, 224 insertions(+), 21 deletions(-) diff --git a/internal/database/token.go b/internal/database/token.go index e13085ed..7871523e 100644 --- a/internal/database/token.go +++ b/internal/database/token.go @@ -265,6 +265,8 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error { UPDATE tokens SET request_count = request_count + 1, last_used_at = ? WHERE token = ? + AND is_active = 1 + AND (expires_at IS NULL OR expires_at > ?) AND ( max_requests IS NULL OR max_requests <= 0 @@ -272,7 +274,7 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error { ) ` - result, err := d.ExecContextRebound(ctx, query, now, tokenID) + result, err := d.ExecContextRebound(ctx, query, now, tokenID, now) if err != nil { return fmt.Errorf("failed to increment token usage: %w", err) } @@ -283,20 +285,36 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error { } if rowsAffected == 0 { - // No rows updated means either the token doesn't exist, or it exists but is already - // at quota. We do a follow-up SELECT to return the correct, semantically meaningful - // error (404 vs 429). If this becomes a hot path, we can explore dialect-specific - // optimizations like UPDATE ... RETURNING. 
- var requestCount int - var maxRequests sql.NullInt32 - checkQuery := `SELECT request_count, max_requests FROM tokens WHERE token = ?` - err := d.QueryRowContextRebound(ctx, checkQuery, tokenID).Scan(&requestCount, &maxRequests) + // No rows updated means either: + // - token doesn't exist + // - token is inactive + // - token is expired + // - token is already at quota + // + // We do a follow-up SELECT to return a semantically meaningful error. + var ( + isActive bool + expiresAt sql.NullTime + requestCount int + maxRequests sql.NullInt32 + ) + checkQuery := `SELECT is_active, expires_at, request_count, max_requests FROM tokens WHERE token = ?` + err := d.QueryRowContextRebound(ctx, checkQuery, tokenID).Scan(&isActive, &expiresAt, &requestCount, &maxRequests) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrTokenNotFound } return fmt.Errorf("failed to check token usage for %s: %w", obfuscate.ObfuscateTokenGeneric(tokenID), err) } + if !isActive { + return token.ErrTokenInactive + } + if expiresAt.Valid { + exp := expiresAt.Time.UTC() + if !exp.IsZero() && now.After(exp) { + return token.ErrTokenExpired + } + } if maxRequests.Valid && maxRequests.Int32 > 0 { if requestCount >= int(maxRequests.Int32) { return token.ErrTokenRateLimit diff --git a/internal/database/token_test.go b/internal/database/token_test.go index dfd0e3ec..e6df7cfc 100644 --- a/internal/database/token_test.go +++ b/internal/database/token_test.go @@ -130,6 +130,32 @@ func TestTokenCRUD(t *testing.T) { t.Fatalf("Expected ErrTokenNotFound, got %v", err) } + // Test IncrementTokenUsage with inactive token + inactive := token1 + inactive.IsActive = false + err = db.UpdateToken(ctx, inactive) + if err != nil { + t.Fatalf("Failed to update token to inactive: %v", err) + } + err = db.IncrementTokenUsage(ctx, token1.Token) + if err != token.ErrTokenInactive { + t.Fatalf("Expected ErrTokenInactive, got %v", err) + } + + // Test IncrementTokenUsage with expired token + expired := token1 + 
expired.IsActive = true + past := time.Now().Add(-1 * time.Hour) + expired.ExpiresAt = &past + err = db.UpdateToken(ctx, expired) + if err != nil { + t.Fatalf("Failed to update token to expired: %v", err) + } + err = db.IncrementTokenUsage(ctx, token1.Token) + if err != token.ErrTokenExpired { + t.Fatalf("Expected ErrTokenExpired, got %v", err) + } + // Test UpdateToken updatedToken1.IsActive = false updatedToken1.RequestCount = 10 diff --git a/internal/token/cache.go b/internal/token/cache.go index 06ebed97..f893f08a 100644 --- a/internal/token/cache.go +++ b/internal/token/cache.go @@ -141,11 +141,12 @@ func (cv *CachedValidator) ValidateToken(ctx context.Context, tokenID string) (s // ValidateTokenWithTracking validates a token and tracks usage. // // For cached *unlimited* tokens, we keep validation cheap by returning from the cache and enqueueing -// async tracking (when an aggregator is configured). For cached *limited* tokens we defer to the -// underlying validator and invalidate the cache afterwards, to avoid serving stale rate-limit state. +// async tracking (when an aggregator is configured). +// +// For cached *limited* tokens, we still want to avoid an extra DB read on every request. We do this +// by using the cached token metadata (active/expiry/project) and performing a synchronous usage +// increment (which is where max_requests is enforced). 
func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenID string) (string, error) { - shouldInvalidateAfterTracking := false - cv.cacheMutex.RLock() entry, found := cv.cache[tokenID] cv.cacheMutex.RUnlock() @@ -163,7 +164,7 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI cv.misses++ cv.evictions++ cv.statsMutex.Unlock() - } else if entry.Data.IsValid() { + } else if isCacheableTokenValid(entry.Data) { cv.statsMutex.Lock() cv.hits++ cv.statsMutex.Unlock() @@ -176,11 +177,15 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI return entry.Data.ProjectID, nil } } - } else { - shouldInvalidateAfterTracking = true + } else if sv, ok := cv.validator.(*StandardValidator); ok && sv != nil && sv.store != nil { + // Limited token: enforce max_requests via a synchronous increment, but avoid a DB read. + if err := sv.store.IncrementTokenUsage(ctx, tokenID); err != nil { + return "", err + } + return entry.Data.ProjectID, nil } } else { - // Should be unreachable (invalid tokens should never be cached), but be defensive. + // Token became invalid (inactive/expired). Be defensive and drop the cache entry. cv.invalidateCache(tokenID) cv.statsMutex.Lock() @@ -195,9 +200,9 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI return "", err } - if shouldInvalidateAfterTracking { - cv.invalidateCache(tokenID) - } + // Populate the cache after successful tracking so subsequent requests can avoid extra DB reads. + // (Only works for StandardValidator; others are safely ignored.) 
+ cv.cacheToken(ctx, tokenID) return projectID, nil } @@ -231,6 +236,16 @@ func (cv *CachedValidator) checkCache(tokenID string) (string, bool) { return "", false } + // In cache but token is no longer valid (inactive/expired) + if !isCacheableTokenValid(entry.Data) { + cv.invalidateCache(tokenID) + cv.statsMutex.Lock() + cv.misses++ + cv.evictions++ + cv.statsMutex.Unlock() + return "", false + } + // In cache and valid cv.statsMutex.Lock() cv.hits++ @@ -239,6 +254,21 @@ func (cv *CachedValidator) checkCache(tokenID string) (string, bool) { return entry.Data.ProjectID, true } +// isCacheableTokenValid determines whether a cached token can be used for authentication. +// +// Important: We intentionally do NOT use RequestCount/MaxRequests here. +// Quota is enforced by the store's IncrementTokenUsage on tracked validations, not from cached counters, +// and cached RequestCount is inherently stale under concurrency. +func isCacheableTokenValid(td TokenData) bool { + if !td.IsActive { + return false + } + if IsExpired(td.ExpiresAt) { + return false + } + return true +} + // cacheToken retrieves and caches a token func (cv *CachedValidator) cacheToken(ctx context.Context, tokenID string) { standardValidator, ok := cv.validator.(*StandardValidator) @@ -252,7 +282,7 @@ func (cv *CachedValidator) cacheToken(ctx context.Context, tokenID string) { if err != nil { return } - if !tokenData.IsValid() { + if !isCacheableTokenValid(tokenData) { return } diff --git a/internal/token/cache_test.go b/internal/token/cache_test.go index 270b1a18..47237ca2 100644 --- a/internal/token/cache_test.go +++ b/internal/token/cache_test.go @@ -3,6 +3,7 @@ package token import ( "context" "errors" + "sync" "testing" "time" ) @@ -220,6 +221,124 @@ func TestCachedValidator_ValidateTokenWithTracking(t *testing.T) { } } +type countingStore struct { + mu sync.Mutex + + tokens map[string]TokenData + + getByTokenCalls int + incCalls int +} + +func newCountingStore() *countingStore { + return &countingStore{tokens:
make(map[string]TokenData)} +} + +func (s *countingStore) GetTokenByID(ctx context.Context, id string) (TokenData, error) { + s.mu.Lock() + defer s.mu.Unlock() + td, ok := s.tokens[id] + if !ok { + return TokenData{}, ErrTokenNotFound + } + return td, nil +} + +func (s *countingStore) GetTokenByToken(ctx context.Context, tokenString string) (TokenData, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.getByTokenCalls++ + td, ok := s.tokens[tokenString] + if !ok { + return TokenData{}, ErrTokenNotFound + } + return td, nil +} + +func (s *countingStore) IncrementTokenUsage(ctx context.Context, tokenString string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.incCalls++ + + td, ok := s.tokens[tokenString] + if !ok { + return ErrTokenNotFound + } + if !td.IsActive { + return ErrTokenInactive + } + if IsExpired(td.ExpiresAt) { + return ErrTokenExpired + } + if td.MaxRequests != nil && *td.MaxRequests > 0 && td.RequestCount >= *td.MaxRequests { + return ErrTokenRateLimit + } + td.RequestCount++ + now := time.Now() + td.LastUsedAt = &now + s.tokens[tokenString] = td + return nil +} + +func (s *countingStore) CreateToken(ctx context.Context, token TokenData) error { return nil } +func (s *countingStore) UpdateToken(ctx context.Context, token TokenData) error { return nil } +func (s *countingStore) ListTokens(ctx context.Context) ([]TokenData, error) { return nil, nil } +func (s *countingStore) GetTokensByProjectID(ctx context.Context, projectID string) ([]TokenData, error) { + return nil, nil +} + +func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_UsesCacheAndAvoidsExtraReads(t *testing.T) { + ctx := context.Background() + store := newCountingStore() + validator := &StandardValidator{store: store} + cv := NewCachedValidator(validator, CacheOptions{TTL: time.Minute, MaxSize: 10, EnableCleanup: false}) + + now := time.Now() + future := now.Add(1 * time.Hour) + maxReq := 100 + tok, _ := GenerateToken() + + store.tokens[tok] = TokenData{ + Token: tok, + 
ProjectID: "p1", + ExpiresAt: &future, + IsActive: true, + RequestCount: 0, + MaxRequests: &maxReq, + CreatedAt: now, + } + + // First call: expect underlying validation+tracking and then cache population (may re-read once). + got, err := cv.ValidateTokenWithTracking(ctx, tok) + if err != nil { + t.Fatalf("ValidateTokenWithTracking() error = %v", err) + } + if got != "p1" { + t.Fatalf("ValidateTokenWithTracking() = %q, want %q", got, "p1") + } + + // Second call: should hit cache and only do a usage increment (no GetTokenByToken read). + got, err = cv.ValidateTokenWithTracking(ctx, tok) + if err != nil { + t.Fatalf("ValidateTokenWithTracking() second error = %v", err) + } + if got != "p1" { + t.Fatalf("ValidateTokenWithTracking() second = %q, want %q", got, "p1") + } + + store.mu.Lock() + defer store.mu.Unlock() + + if store.incCalls != 2 { + t.Fatalf("IncrementTokenUsage calls = %d, want 2", store.incCalls) + } + // Cache population after first tracking currently performs one extra GetTokenByToken read. + // The key property we assert: it does not keep reading on every subsequent request. 
+ if store.getByTokenCalls > 2 { + t.Fatalf("GetTokenByToken calls = %d, want <= 2", store.getByTokenCalls) + } +} + func TestCachedValidator_CacheEviction(t *testing.T) { ctx := context.Background() mockValidator := NewMockValidator() diff --git a/internal/token/token_integration_test.go b/internal/token/token_integration_test.go index d82cba81..e97ac6e3 100644 --- a/internal/token/token_integration_test.go +++ b/internal/token/token_integration_test.go @@ -80,6 +80,16 @@ func (m *MockStore) IncrementTokenUsage(ctx context.Context, tokenID string) err return ErrTokenNotFound } + if !token.IsActive { + return ErrTokenInactive + } + if IsExpired(token.ExpiresAt) { + return ErrTokenExpired + } + if token.MaxRequests != nil && *token.MaxRequests > 0 && token.RequestCount >= *token.MaxRequests { + return ErrTokenRateLimit + } + token.RequestCount++ now := time.Now() token.LastUsedAt = &now From 19f2c4583bfa67e109f214c579b378940bc150ce Mon Sep 17 00:00:00 2001 From: Manuel Fittko Date: Tue, 6 Jan 2026 22:17:44 +0100 Subject: [PATCH 2/4] [db] Fix postgres IncrementTokenUsage active check - Use SQL boolean literal TRUE for is_active check (works across sqlite/mysql/postgres) - Tighten cache test assertion per review --- internal/database/token.go | 2 +- internal/token/cache_test.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/database/token.go b/internal/database/token.go index 7871523e..1328c5ed 100644 --- a/internal/database/token.go +++ b/internal/database/token.go @@ -265,7 +265,7 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error { UPDATE tokens SET request_count = request_count + 1, last_used_at = ? WHERE token = ? - AND is_active = 1 + AND is_active = TRUE AND (expires_at IS NULL OR expires_at > ?) 
AND ( max_requests IS NULL diff --git a/internal/token/cache_test.go b/internal/token/cache_test.go index 47237ca2..13b57e29 100644 --- a/internal/token/cache_test.go +++ b/internal/token/cache_test.go @@ -332,10 +332,10 @@ func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_UsesCacheAndAvoi if store.incCalls != 2 { t.Fatalf("IncrementTokenUsage calls = %d, want 2", store.incCalls) } - // Cache population after first tracking currently performs one extra GetTokenByToken read. - // The key property we assert: it does not keep reading on every subsequent request. - if store.getByTokenCalls > 2 { - t.Fatalf("GetTokenByToken calls = %d, want <= 2", store.getByTokenCalls) + // First call: 1x read via validateTokenData + 1x read via cacheToken population. + // Second call (cache hit): 0x reads. + if store.getByTokenCalls != 2 { + t.Fatalf("GetTokenByToken calls = %d, want 2", store.getByTokenCalls) } } From f32634e6ca9a164150091189920ca65c75bbfb9f Mon Sep 17 00:00:00 2001 From: Manuel Fittko Date: Tue, 6 Jan 2026 22:29:59 +0100 Subject: [PATCH 3/4] [token] Invalidate cache on terminal tracking errors - Invalidate cached limited-token entries when IncrementTokenUsage returns rate-limit/inactive/expired - Treat any non-NULL expires_at as authoritative (remove IsZero guard) Testing: go test ./internal/token ./internal/database --- internal/database/token.go | 2 +- internal/token/cache.go | 5 +++++ internal/token/cache_test.go | 42 ++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/internal/database/token.go b/internal/database/token.go index 1328c5ed..352ff83a 100644 --- a/internal/database/token.go +++ b/internal/database/token.go @@ -311,7 +311,7 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error { } if expiresAt.Valid { exp := expiresAt.Time.UTC() - if !exp.IsZero() && now.After(exp) { + if now.After(exp) { return token.ErrTokenExpired } } diff --git a/internal/token/cache.go 
b/internal/token/cache.go index f893f08a..c10619b6 100644 --- a/internal/token/cache.go +++ b/internal/token/cache.go @@ -180,6 +180,11 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI } else if sv, ok := cv.validator.(*StandardValidator); ok && sv != nil && sv.store != nil { // Limited token: enforce max_requests via a synchronous increment, but avoid a DB read. if err := sv.store.IncrementTokenUsage(ctx, tokenID); err != nil { + // If the token is no longer usable (inactive/expired/quota), invalidate the cache entry + // so we avoid repeatedly hitting the cache and failing the same increment. + if err == ErrTokenRateLimit || err == ErrTokenInactive || err == ErrTokenExpired { + cv.invalidateCache(tokenID) + } return "", err } return entry.Data.ProjectID, nil diff --git a/internal/token/cache_test.go b/internal/token/cache_test.go index 13b57e29..56e47dd0 100644 --- a/internal/token/cache_test.go +++ b/internal/token/cache_test.go @@ -339,6 +339,48 @@ func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_UsesCacheAndAvoi } } +func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_InvalidatesOnRateLimit(t *testing.T) { + ctx := context.Background() + store := newCountingStore() + validator := &StandardValidator{store: store} + cv := NewCachedValidator(validator, CacheOptions{TTL: time.Minute, MaxSize: 10, EnableCleanup: false}) + + now := time.Now() + future := now.Add(1 * time.Hour) + maxReq := 1 + tok, _ := GenerateToken() + + store.tokens[tok] = TokenData{ + Token: tok, + ProjectID: "p1", + ExpiresAt: &future, + IsActive: true, + RequestCount: 0, + MaxRequests: &maxReq, + CreatedAt: now, + } + + // First use should succeed and populate cache (plus cacheToken read). + _, err := cv.ValidateTokenWithTracking(ctx, tok) + if err != nil { + t.Fatalf("ValidateTokenWithTracking() first error = %v", err) + } + + // Second use should hit cache, attempt increment, fail with rate limit, and invalidate cache entry. 
+ _, err = cv.ValidateTokenWithTracking(ctx, tok) + if err == nil || !errors.Is(err, ErrTokenRateLimit) { + t.Fatalf("ValidateTokenWithTracking() second error = %v, want ErrTokenRateLimit", err) + } + + // Cache should be invalidated after the failed increment. + cv.cacheMutex.RLock() + _, ok := cv.cache[tok] + cv.cacheMutex.RUnlock() + if ok { + t.Fatalf("expected cache entry to be invalidated after ErrTokenRateLimit") + } +} + func TestCachedValidator_CacheEviction(t *testing.T) { ctx := context.Background() mockValidator := NewMockValidator() From 22e1f260bb533aea6d6e98c0f01bb1eff5a12ba9 Mon Sep 17 00:00:00 2001 From: Manuel Fittko Date: Tue, 6 Jan 2026 22:36:57 +0100 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/token/cache_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/token/cache_test.go b/internal/token/cache_test.go index 56e47dd0..a347b55a 100644 --- a/internal/token/cache_test.go +++ b/internal/token/cache_test.go @@ -332,8 +332,8 @@ func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_UsesCacheAndAvoi if store.incCalls != 2 { t.Fatalf("IncrementTokenUsage calls = %d, want 2", store.incCalls) } - // First call: 1x read via validateTokenData + 1x read via cacheToken population. - // Second call (cache hit): 0x reads. + // First call: 2x reads (1x via validateTokenData + 1x via cacheToken population) and 1x write. + // Second call (cache hit): 0x reads and 1x write (only IncrementTokenUsage). if store.getByTokenCalls != 2 { t.Fatalf("GetTokenByToken calls = %d, want 2", store.getByTokenCalls) } @@ -360,7 +360,7 @@ func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_InvalidatesOnRat CreatedAt: now, } - // First use should succeed and populate cache (plus cacheToken read). + // First use should succeed and populate cache (initial validation read + cacheToken read = 2 GetTokenByToken calls). 
_, err := cv.ValidateTokenWithTracking(ctx, tok) if err != nil { t.Fatalf("ValidateTokenWithTracking() first error = %v", err)