Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions internal/database/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,16 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error {
UPDATE tokens
SET request_count = request_count + 1, last_used_at = ?
WHERE token = ?
AND is_active = TRUE
AND (expires_at IS NULL OR expires_at > ?)
AND (
max_requests IS NULL
OR max_requests <= 0
OR request_count < max_requests
)
`

result, err := d.ExecContextRebound(ctx, query, now, tokenID)
result, err := d.ExecContextRebound(ctx, query, now, tokenID, now)
if err != nil {
return fmt.Errorf("failed to increment token usage: %w", err)
}
Expand All @@ -283,20 +285,36 @@ func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error {
}

if rowsAffected == 0 {
// No rows updated means either the token doesn't exist, or it exists but is already
// at quota. We do a follow-up SELECT to return the correct, semantically meaningful
// error (404 vs 429). If this becomes a hot path, we can explore dialect-specific
// optimizations like UPDATE ... RETURNING.
var requestCount int
var maxRequests sql.NullInt32
checkQuery := `SELECT request_count, max_requests FROM tokens WHERE token = ?`
err := d.QueryRowContextRebound(ctx, checkQuery, tokenID).Scan(&requestCount, &maxRequests)
// No rows updated means either:
// - token doesn't exist
// - token is inactive
// - token is expired
// - token is already at quota
//
// We do a follow-up SELECT to return a semantically meaningful error.
var (
isActive bool
expiresAt sql.NullTime
requestCount int
maxRequests sql.NullInt32
)
checkQuery := `SELECT is_active, expires_at, request_count, max_requests FROM tokens WHERE token = ?`
err := d.QueryRowContextRebound(ctx, checkQuery, tokenID).Scan(&isActive, &expiresAt, &requestCount, &maxRequests)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrTokenNotFound
}
return fmt.Errorf("failed to check token usage for %s: %w", obfuscate.ObfuscateTokenGeneric(tokenID), err)
}
if !isActive {
return token.ErrTokenInactive
}
if expiresAt.Valid {
exp := expiresAt.Time.UTC()
if now.After(exp) {
return token.ErrTokenExpired
}
}
if maxRequests.Valid && maxRequests.Int32 > 0 {
if requestCount >= int(maxRequests.Int32) {
return token.ErrTokenRateLimit
Expand Down
26 changes: 26 additions & 0 deletions internal/database/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,32 @@ func TestTokenCRUD(t *testing.T) {
t.Fatalf("Expected ErrTokenNotFound, got %v", err)
}

// Test IncrementTokenUsage with inactive token
inactive := token1
inactive.IsActive = false
err = db.UpdateToken(ctx, inactive)
if err != nil {
t.Fatalf("Failed to update token to inactive: %v", err)
}
err = db.IncrementTokenUsage(ctx, token1.Token)
if err != token.ErrTokenInactive {
t.Fatalf("Expected ErrTokenInactive, got %v", err)
}

// Test IncrementTokenUsage with expired token
expired := token1
expired.IsActive = true
past := time.Now().Add(-1 * time.Hour)
expired.ExpiresAt = &past
err = db.UpdateToken(ctx, expired)
if err != nil {
t.Fatalf("Failed to update token to expired: %v", err)
}
err = db.IncrementTokenUsage(ctx, token1.Token)
if err != token.ErrTokenExpired {
t.Fatalf("Expected ErrTokenExpired, got %v", err)
}

// Test UpdateToken
updatedToken1.IsActive = false
updatedToken1.RequestCount = 10
Expand Down
59 changes: 47 additions & 12 deletions internal/token/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,12 @@ func (cv *CachedValidator) ValidateToken(ctx context.Context, tokenID string) (s
// ValidateTokenWithTracking validates a token and tracks usage.
//
// For cached *unlimited* tokens, we keep validation cheap by returning from the cache and enqueueing
// async tracking (when an aggregator is configured). For cached *limited* tokens we defer to the
// underlying validator and invalidate the cache afterwards, to avoid serving stale rate-limit state.
// async tracking (when an aggregator is configured).
//
// For cached *limited* tokens, we still want to avoid an extra DB read on every request. We do this
// by using the cached token metadata (active/expiry/project) and performing a synchronous usage
// increment (which is where max_requests is enforced).
func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenID string) (string, error) {
shouldInvalidateAfterTracking := false

cv.cacheMutex.RLock()
entry, found := cv.cache[tokenID]
cv.cacheMutex.RUnlock()
Expand All @@ -163,7 +164,7 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI
cv.misses++
cv.evictions++
cv.statsMutex.Unlock()
} else if entry.Data.IsValid() {
} else if isCacheableTokenValid(entry.Data) {
cv.statsMutex.Lock()
cv.hits++
cv.statsMutex.Unlock()
Expand All @@ -176,11 +177,20 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI
return entry.Data.ProjectID, nil
}
}
} else {
shouldInvalidateAfterTracking = true
} else if sv, ok := cv.validator.(*StandardValidator); ok && sv != nil && sv.store != nil {
// Limited token: enforce max_requests via a synchronous increment, but avoid a DB read.
if err := sv.store.IncrementTokenUsage(ctx, tokenID); err != nil {
// If the token is no longer usable (inactive/expired/quota), invalidate the cache entry
// so we avoid repeatedly hitting the cache and failing the same increment.
if err == ErrTokenRateLimit || err == ErrTokenInactive || err == ErrTokenExpired {
cv.invalidateCache(tokenID)
}
return "", err
}
return entry.Data.ProjectID, nil
}
} else {
// Should be unreachable (invalid tokens should never be cached), but be defensive.
// Token became invalid (inactive/expired). Be defensive and drop the cache entry.
cv.invalidateCache(tokenID)

cv.statsMutex.Lock()
Expand All @@ -195,9 +205,9 @@ func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenI
return "", err
}

if shouldInvalidateAfterTracking {
cv.invalidateCache(tokenID)
}
// Populate the cache after successful tracking so subsequent requests can avoid extra DB reads.
// (Only works for StandardValidator; others are safely ignored.)
cv.cacheToken(ctx, tokenID)

return projectID, nil
}
Expand Down Expand Up @@ -231,6 +241,16 @@ func (cv *CachedValidator) checkCache(tokenID string) (string, bool) {
return "", false
}

// In cache but token is no longer valid (inactive/expired)
if !isCacheableTokenValid(entry.Data) {
cv.invalidateCache(tokenID)
cv.statsMutex.Lock()
cv.misses++
cv.evictions++
cv.statsMutex.Unlock()
return "", false
}

// In cache and valid
cv.statsMutex.Lock()
cv.hits++
Expand All @@ -239,6 +259,21 @@ func (cv *CachedValidator) checkCache(tokenID string) (string, bool) {
return entry.Data.ProjectID, true
}

// isCacheableTokenValid reports whether a cached token entry may still be used
// for authentication: the token must be active and not expired.
//
// RequestCount/MaxRequests are deliberately ignored here: cache hits are not
// supposed to count against token quotas (quota enforcement happens in the
// usage increment), and a cached RequestCount is inherently stale under
// concurrency anyway.
func isCacheableTokenValid(td TokenData) bool {
	return td.IsActive && !IsExpired(td.ExpiresAt)
}

// cacheToken retrieves and caches a token
func (cv *CachedValidator) cacheToken(ctx context.Context, tokenID string) {
standardValidator, ok := cv.validator.(*StandardValidator)
Expand All @@ -252,7 +287,7 @@ func (cv *CachedValidator) cacheToken(ctx context.Context, tokenID string) {
if err != nil {
return
}
if !tokenData.IsValid() {
if !isCacheableTokenValid(tokenData) {
return
}

Expand Down
161 changes: 161 additions & 0 deletions internal/token/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package token
import (
"context"
"errors"
"sync"
"testing"
"time"
)
Expand Down Expand Up @@ -220,6 +221,166 @@ func TestCachedValidator_ValidateTokenWithTracking(t *testing.T) {
}
}

// countingStore is an in-memory token-store test double that records how many
// token reads and usage increments the code under test performs, so tests can
// assert on DB round-trip counts.
type countingStore struct {
	// mu guards tokens and both call counters.
	mu sync.Mutex

	// tokens holds token data keyed by the token string; GetTokenByID looks
	// up the same map by its id argument.
	tokens map[string]TokenData

	// getByTokenCalls counts GetTokenByToken invocations (store reads).
	getByTokenCalls int
	// incCalls counts IncrementTokenUsage invocations (store writes).
	incCalls int
}

// newCountingStore returns an empty countingStore ready for use; all counters
// start at zero.
func newCountingStore() *countingStore {
	cs := &countingStore{}
	cs.tokens = map[string]TokenData{}
	return cs
}

// GetTokenByID returns the token stored under id, or ErrTokenNotFound.
// It intentionally does not bump any call counter.
func (s *countingStore) GetTokenByID(ctx context.Context, id string) (TokenData, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if td, ok := s.tokens[id]; ok {
		return td, nil
	}
	return TokenData{}, ErrTokenNotFound
}

// GetTokenByToken returns the token stored under tokenString, or
// ErrTokenNotFound. Every call bumps getByTokenCalls so tests can count
// store reads.
func (s *countingStore) GetTokenByToken(ctx context.Context, tokenString string) (TokenData, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.getByTokenCalls++
	if td, ok := s.tokens[tokenString]; ok {
		return td, nil
	}
	return TokenData{}, ErrTokenNotFound
}

// IncrementTokenUsage mimics the database-side usage increment: it checks
// existence, active state, expiry, and max_requests quota (in that order)
// before bumping RequestCount and stamping LastUsedAt. Every call — including
// failing ones — bumps incCalls so tests can count store writes.
func (s *countingStore) IncrementTokenUsage(ctx context.Context, tokenString string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.incCalls++

	td, ok := s.tokens[tokenString]
	switch {
	case !ok:
		return ErrTokenNotFound
	case !td.IsActive:
		return ErrTokenInactive
	case IsExpired(td.ExpiresAt):
		return ErrTokenExpired
	case td.MaxRequests != nil && *td.MaxRequests > 0 && td.RequestCount >= *td.MaxRequests:
		return ErrTokenRateLimit
	}

	ts := time.Now()
	td.RequestCount++
	td.LastUsedAt = &ts
	s.tokens[tokenString] = td
	return nil
}

// The remaining store methods are not exercised by these tests; no-op stubs
// keep countingStore usable wherever the full store interface is expected.
func (s *countingStore) CreateToken(ctx context.Context, token TokenData) error { return nil }
func (s *countingStore) UpdateToken(ctx context.Context, token TokenData) error { return nil }
func (s *countingStore) ListTokens(ctx context.Context) ([]TokenData, error) { return nil, nil }
func (s *countingStore) GetTokensByProjectID(ctx context.Context, projectID string) ([]TokenData, error) {
	return nil, nil
}

// TestCachedValidator_ValidateTokenWithTracking_LimitedToken_UsesCacheAndAvoidsExtraReads
// verifies the limited-token fast path: after the first (cache-populating)
// validation, subsequent validations of a quota-limited token must perform
// only a usage increment — no additional GetTokenByToken store reads.
// The exact call counts asserted below are coupled to the cache's internal
// read/populate sequence.
func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_UsesCacheAndAvoidsExtraReads(t *testing.T) {
	ctx := context.Background()
	store := newCountingStore()
	validator := &StandardValidator{store: store}
	cv := NewCachedValidator(validator, CacheOptions{TTL: time.Minute, MaxSize: 10, EnableCleanup: false})

	now := time.Now()
	future := now.Add(1 * time.Hour)
	maxReq := 100 // large enough that quota is never hit in this test
	tok, _ := GenerateToken()

	// Seed an active, unexpired, quota-limited token.
	store.tokens[tok] = TokenData{
		Token:        tok,
		ProjectID:    "p1",
		ExpiresAt:    &future,
		IsActive:     true,
		RequestCount: 0,
		MaxRequests:  &maxReq,
		CreatedAt:    now,
	}

	// First call: expect underlying validation+tracking and then cache population (may re-read once).
	got, err := cv.ValidateTokenWithTracking(ctx, tok)
	if err != nil {
		t.Fatalf("ValidateTokenWithTracking() error = %v", err)
	}
	if got != "p1" {
		t.Fatalf("ValidateTokenWithTracking() = %q, want %q", got, "p1")
	}

	// Second call: should hit cache and only do a usage increment (no GetTokenByToken read).
	got, err = cv.ValidateTokenWithTracking(ctx, tok)
	if err != nil {
		t.Fatalf("ValidateTokenWithTracking() second error = %v", err)
	}
	if got != "p1" {
		t.Fatalf("ValidateTokenWithTracking() second = %q, want %q", got, "p1")
	}

	// Lock the store before inspecting counters (no concurrency here, but it
	// keeps the access pattern consistent with the store's contract).
	store.mu.Lock()
	defer store.mu.Unlock()

	// One synchronous increment per validation call.
	if store.incCalls != 2 {
		t.Fatalf("IncrementTokenUsage calls = %d, want 2", store.incCalls)
	}
	// First call: 2x reads (1x via validateTokenData + 1x via cacheToken population) and 1x write.
	// Second call (cache hit): 0x reads and 1x write (only IncrementTokenUsage).
	if store.getByTokenCalls != 2 {
		t.Fatalf("GetTokenByToken calls = %d, want 2", store.getByTokenCalls)
	}
}

// TestCachedValidator_ValidateTokenWithTracking_LimitedToken_InvalidatesOnRateLimit
// verifies that when a cached limited token's usage increment fails with
// ErrTokenRateLimit, the error is surfaced to the caller and the cache entry
// is dropped (so the stale entry isn't served again).
func TestCachedValidator_ValidateTokenWithTracking_LimitedToken_InvalidatesOnRateLimit(t *testing.T) {
	ctx := context.Background()
	store := newCountingStore()
	validator := &StandardValidator{store: store}
	cv := NewCachedValidator(validator, CacheOptions{TTL: time.Minute, MaxSize: 10, EnableCleanup: false})

	now := time.Now()
	future := now.Add(1 * time.Hour)
	maxReq := 1 // quota of exactly one request: second use must rate-limit
	tok, _ := GenerateToken()

	// Seed an active, unexpired token with a single-request quota.
	store.tokens[tok] = TokenData{
		Token:        tok,
		ProjectID:    "p1",
		ExpiresAt:    &future,
		IsActive:     true,
		RequestCount: 0,
		MaxRequests:  &maxReq,
		CreatedAt:    now,
	}

	// First use should succeed and populate cache (initial validation read + cacheToken read = 2 GetTokenByToken calls).
	_, err := cv.ValidateTokenWithTracking(ctx, tok)
	if err != nil {
		t.Fatalf("ValidateTokenWithTracking() first error = %v", err)
	}

	// Second use should hit cache, attempt increment, fail with rate limit, and invalidate cache entry.
	_, err = cv.ValidateTokenWithTracking(ctx, tok)
	if err == nil || !errors.Is(err, ErrTokenRateLimit) {
		t.Fatalf("ValidateTokenWithTracking() second error = %v, want ErrTokenRateLimit", err)
	}

	// Cache should be invalidated after the failed increment. (Direct map
	// inspection is fine here: same package, and no goroutines are running.)
	cv.cacheMutex.RLock()
	_, ok := cv.cache[tok]
	cv.cacheMutex.RUnlock()
	if ok {
		t.Fatalf("expected cache entry to be invalidated after ErrTokenRateLimit")
	}
}

func TestCachedValidator_CacheEviction(t *testing.T) {
ctx := context.Background()
mockValidator := NewMockValidator()
Expand Down
10 changes: 10 additions & 0 deletions internal/token/token_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ func (m *MockStore) IncrementTokenUsage(ctx context.Context, tokenID string) err
return ErrTokenNotFound
}

if !token.IsActive {
return ErrTokenInactive
}
if IsExpired(token.ExpiresAt) {
return ErrTokenExpired
}
if token.MaxRequests != nil && *token.MaxRequests > 0 && token.RequestCount >= *token.MaxRequests {
return ErrTokenRateLimit
}

token.RequestCount++
now := time.Now()
token.LastUsedAt = &now
Expand Down