Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions cmd/aegisflow/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,28 @@ func healthHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"status":"ok","requests":%d}`, n)
}

// buildKeyRotator constructs a KeyRotator from a provider config.
// It supports the new api_keys list (with key or key_env per entry) as well as
// the legacy single api_key_env field for backward compatibility.
func buildKeyRotator(pc config.ProviderConfig) *provider.KeyRotator {
var keys []string
for _, entry := range pc.APIKeys {
switch {
case entry.Key != "":
keys = append(keys, entry.Key)
case entry.KeyEnv != "":
if v := os.Getenv(entry.KeyEnv); v != "" {
keys = append(keys, v)
}
}
}
// Fall back to the legacy single-key field if no api_keys were resolved.
if len(keys) == 0 && pc.APIKeyEnv != "" {
keys = append(keys, os.Getenv(pc.APIKeyEnv))
}
return provider.NewKeyRotator(keys, pc.KeySelection, 0)
}

func initProviders(cfg *config.Config, registry *provider.Registry) {
for _, pc := range cfg.Providers {
if !pc.Enabled {
Expand All @@ -779,10 +801,11 @@ func initProviders(cfg *config.Config, registry *provider.Registry) {
registry.Register(provider.NewMockProvider(pc.Name, latency))
log.Printf("registered provider: %s (type: mock, latency: %s)", pc.Name, latency)
case "openai":
p := provider.NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKeyEnv, pc.Models, pc.Timeout, pc.MaxRetries)
kr := buildKeyRotator(pc)
p := provider.NewOpenAIProviderWithKeys(pc.Name, pc.BaseURL, kr, pc.Models, pc.Timeout, pc.MaxRetries)
p.ConfigureRetry(pc.Retry)
registry.Register(p)
log.Printf("registered provider: %s (type: openai, base_url: %s)", pc.Name, pc.BaseURL)
log.Printf("registered provider: %s (type: openai, base_url: %s, keys: %d)", pc.Name, pc.BaseURL, kr.Len())
case "anthropic":
registry.Register(provider.NewAnthropicProvider(pc.Name, pc.BaseURL, pc.APIKeyEnv, pc.Models, pc.Timeout))
log.Printf("registered provider: %s (type: anthropic, base_url: %s)", pc.Name, pc.BaseURL)
Expand Down
35 changes: 22 additions & 13 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -481,20 +481,29 @@ type CORSConfig struct {
MaxAge int `yaml:"max_age"`
}

// ProviderAPIKey represents one entry in the api_keys list for a provider.
// Use Key for a literal value (dev/testing) or KeyEnv for a secret from an env var (production).
type ProviderAPIKey struct {
Key string `yaml:"key"`
KeyEnv string `yaml:"key_env"`
}

type ProviderConfig struct {
Name string `yaml:"name"`
Type string `yaml:"type"`
Enabled bool `yaml:"enabled"`
Default bool `yaml:"default"`
BaseURL string `yaml:"base_url"`
APIKeyEnv string `yaml:"api_key_env"`
Models []string `yaml:"models"`
Timeout time.Duration `yaml:"timeout"`
MaxRetries int `yaml:"max_retries"`
Retry RetryConfig `yaml:"retry"`
APIVersion string `yaml:"api_version"`
Config map[string]string `yaml:"config"`
Region string `yaml:"region"`
Name string `yaml:"name"`
Type string `yaml:"type"`
Enabled bool `yaml:"enabled"`
Default bool `yaml:"default"`
BaseURL string `yaml:"base_url"`
APIKeyEnv string `yaml:"api_key_env"` // backward compat: single key from env
APIKeys []ProviderAPIKey `yaml:"api_keys"` // multi-key rotation
KeySelection string `yaml:"key_selection"` // "round-robin" (default; only supported strategy)
Models []string `yaml:"models"`
Timeout time.Duration `yaml:"timeout"`
MaxRetries int `yaml:"max_retries"`
Retry RetryConfig `yaml:"retry"`
APIVersion string `yaml:"api_version"`
Config map[string]string `yaml:"config"`
Region string `yaml:"region"`
}

type RetryConfig struct {
Expand Down
123 changes: 123 additions & 0 deletions internal/provider/keyrotator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package provider

import (
"sync"
"time"
)

type keyState int

const (
keyStateActive keyState = iota
keyStateRateLimited
keyStateFailed
)

type managedKey struct {
value string
state keyState
cooldown time.Time
}

// KeyRotator selects API keys using round-robin and automatically excludes
// keys that are rate-limited or permanently failed.
type KeyRotator struct {
mu sync.Mutex
keys []*managedKey
strategy string
counter uint64 // always accessed under mu; plain uint64 avoids drift when active-set length changes
cooldown time.Duration
}

// NewKeyRotator creates a rotator from the given key values.
// strategy must be "round-robin" (the only supported strategy; defaults to it if empty).
// rateLimitCooldown controls how long a 429-hit key is excluded before being re-admitted.
// TODO: make rateLimitCooldown configurable per-provider via YAML (currently defaults to 60s).
func NewKeyRotator(keys []string, strategy string, rateLimitCooldown time.Duration) *KeyRotator {
if strategy == "" {
strategy = "round-robin"
}
if rateLimitCooldown == 0 {
rateLimitCooldown = 60 * time.Second
}
managed := make([]*managedKey, 0, len(keys))
for _, k := range keys {
if k != "" {
managed = append(managed, &managedKey{value: k, state: keyStateActive})
}
}
return &KeyRotator{keys: managed, strategy: strategy, cooldown: rateLimitCooldown}
}

// Pick returns the next available API key. Returns ("", false) if no keys are usable.
func (r *KeyRotator) Pick() (string, bool) {
r.mu.Lock()
defer r.mu.Unlock()
active := r.activeUnlocked()
if len(active) == 0 {
return "", false
}
idx := r.counter % uint64(len(active))
r.counter++
return active[idx].value, true
}

// MarkFailed permanently excludes a key from rotation (call on HTTP 401 Unauthorized).
func (r *KeyRotator) MarkFailed(key string) {
r.mu.Lock()
defer r.mu.Unlock()
for _, k := range r.keys {
if k.value == key {
k.state = keyStateFailed
return
}
}
}

// MarkRateLimited temporarily excludes a key (call on HTTP 429 Too Many Requests).
// The key is re-admitted automatically after the configured cooldown duration.
func (r *KeyRotator) MarkRateLimited(key string) {
r.mu.Lock()
defer r.mu.Unlock()
for _, k := range r.keys {
if k.value == key {
k.state = keyStateRateLimited
k.cooldown = time.Now().Add(r.cooldown)
return
}
}
}

// Available reports whether at least one key is usable right now.
func (r *KeyRotator) Available() bool {
r.mu.Lock()
defer r.mu.Unlock()
return len(r.activeUnlocked()) > 0
}

// Len returns the total number of configured keys regardless of state.
func (r *KeyRotator) Len() int {
r.mu.Lock()
defer r.mu.Unlock()
return len(r.keys)
}

// activeUnlocked returns currently usable keys. Caller must hold mu.
// Rate-limited keys whose cooldown has expired are re-admitted automatically.
func (r *KeyRotator) activeUnlocked() []*managedKey {
now := time.Now()
var active []*managedKey
for _, k := range r.keys {
switch k.state {
case keyStateActive:
active = append(active, k)
case keyStateRateLimited:
if now.After(k.cooldown) {
k.state = keyStateActive
active = append(active, k)
}
// keyStateFailed: never re-admitted
}
}
return active
}
159 changes: 159 additions & 0 deletions internal/provider/keyrotator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package provider

import (
"testing"
"time"
)

func TestKeyRotatorRoundRobin(t *testing.T) {
kr := NewKeyRotator([]string{"key-1", "key-2", "key-3"}, "round-robin", 0)

seen := map[string]int{}
for i := 0; i < 9; i++ {
k, ok := kr.Pick()
if !ok {
t.Fatal("expected key, got none")
}
seen[k]++
}

// All three keys should have been picked an equal number of times.
for _, key := range []string{"key-1", "key-2", "key-3"} {
if seen[key] != 3 {
t.Errorf("expected key %s to be picked 3 times, got %d", key, seen[key])
}
}
}

func TestKeyRotatorMarkFailed(t *testing.T) {
kr := NewKeyRotator([]string{"key-1", "key-2"}, "round-robin", 0)

kr.MarkFailed("key-1")

for i := 0; i < 5; i++ {
k, ok := kr.Pick()
if !ok {
t.Fatal("expected key, got none")
}
if k == "key-1" {
t.Error("permanently failed key should never be picked")
}
}
}

func TestKeyRotatorMarkRateLimited(t *testing.T) {
// Use a very short cooldown so the test doesn't need to sleep long.
cooldown := 50 * time.Millisecond
kr := NewKeyRotator([]string{"key-1", "key-2"}, "round-robin", cooldown)

kr.MarkRateLimited("key-1")

// key-1 should not be picked during cooldown.
for i := 0; i < 5; i++ {
k, ok := kr.Pick()
if !ok {
t.Fatal("expected key, got none")
}
if k == "key-1" {
t.Error("rate-limited key should not be picked during cooldown")
}
}

// After the cooldown, key-1 should be re-admitted.
time.Sleep(cooldown + 10*time.Millisecond)

seen := map[string]bool{}
for i := 0; i < 10; i++ {
k, ok := kr.Pick()
if !ok {
t.Fatal("expected key after cooldown, got none")
}
seen[k] = true
}
if !seen["key-1"] {
t.Error("key-1 should be picked again after cooldown expires")
}
}

func TestKeyRotatorAllKeysFailed(t *testing.T) {
kr := NewKeyRotator([]string{"key-1", "key-2"}, "round-robin", 0)

kr.MarkFailed("key-1")
kr.MarkFailed("key-2")

_, ok := kr.Pick()
if ok {
t.Error("Pick should return false when all keys are permanently failed")
}
if kr.Available() {
t.Error("Available should return false when all keys are failed")
}
}

func TestKeyRotatorAllKeysRateLimited(t *testing.T) {
kr := NewKeyRotator([]string{"key-1", "key-2"}, "round-robin", time.Hour)

kr.MarkRateLimited("key-1")
kr.MarkRateLimited("key-2")

_, ok := kr.Pick()
if ok {
t.Error("Pick should return false when all keys are rate-limited")
}
if kr.Available() {
t.Error("Available should return false when all keys are rate-limited")
}
}

func TestKeyRotatorSingleKey(t *testing.T) {
kr := NewKeyRotator([]string{"only-key"}, "round-robin", 0)

for i := 0; i < 3; i++ {
k, ok := kr.Pick()
if !ok {
t.Fatal("expected key, got none")
}
if k != "only-key" {
t.Errorf("expected only-key, got %s", k)
}
}
}

func TestKeyRotatorEmptyKeys(t *testing.T) {
kr := NewKeyRotator([]string{}, "round-robin", 0)

_, ok := kr.Pick()
if ok {
t.Error("Pick should return false for empty rotator")
}
if kr.Available() {
t.Error("Available should return false for empty rotator")
}
if kr.Len() != 0 {
t.Errorf("Len should be 0, got %d", kr.Len())
}
}

func TestKeyRotatorEmptyStringKeysIgnored(t *testing.T) {
kr := NewKeyRotator([]string{"", "real-key", ""}, "round-robin", 0)

if kr.Len() != 1 {
t.Errorf("empty string keys should be ignored, expected Len 1 got %d", kr.Len())
}
k, ok := kr.Pick()
if !ok || k != "real-key" {
t.Errorf("expected real-key, got %q ok=%v", k, ok)
}
}

func TestKeyRotatorLen(t *testing.T) {
kr := NewKeyRotator([]string{"key-1", "key-2", "key-3"}, "round-robin", 0)
if kr.Len() != 3 {
t.Errorf("expected Len 3, got %d", kr.Len())
}
kr.MarkFailed("key-1")
// Len counts all keys regardless of state.
if kr.Len() != 3 {
t.Errorf("Len should still be 3 after marking one failed, got %d", kr.Len())
}
}
Loading
Loading