forked from Gitlawb/openclaude
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsmart_router.py
More file actions
361 lines (315 loc) · 13.7 KB
/
smart_router.py
File metadata and controls
361 lines (315 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
"""
smart_router.py
---------------
Intelligent auto-router for openclaude.
Instead of always using one fixed provider, the smart router:
- Pings all configured providers on startup
- Scores them by latency, cost, and health
- Routes each request to the optimal provider
- Supports automatic fallback: after a failure, route() can be called again with the failed provider listed in exclude_providers
- Learns from real request timings over time
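Usage in server.py:
    from smart_router import SmartRouter
    router = SmartRouter()
    await router.initialize()
    decision = await router.route(messages, model)
    # decision is a dict with "provider", "model", "api_key" and "provider_object";
    # report each request's outcome with router.record_result(...)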
.env config:
ROUTER_MODE=smart # or: fixed (default behaviour)
ROUTER_STRATEGY=latency # or: cost, balanced
ROUTER_FALLBACK=true # auto-retry on failure
Contribution to: https://github.com/Gitlawb/openclaude
"""
import asyncio
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Optional
import httpx
# Module-level logger; SmartRouter emits its benchmark, routing and health messages through it.
logger = logging.getLogger(__name__)
# ── Provider definitions ──────────────────────────────────────────────────────
@dataclass
class Provider:
    """
    One routable LLM backend plus the live statistics used to score it.

    The first six fields describe how to reach the provider and which models
    to use; the remaining fields are runtime statistics (latency, health,
    request/error counters) that are updated as pings and requests complete.
    """

    name: str  # e.g. "openai", "gemini", "ollama"
    ping_url: str  # URL used to check health
    api_key_env: str  # env var name for API key ("" = provider needs no key)
    cost_per_1k_tokens: float  # estimated cost USD per 1k tokens
    big_model: str  # model for sonnet/large requests
    small_model: str  # model for haiku/small requests
    latency_ms: float = 9999.0  # updated by benchmark
    healthy: bool = True  # updated by health checks
    request_count: int = 0  # total requests routed here
    error_count: int = 0  # total errors from this provider
    avg_latency_ms: float = 9999.0  # rolling average from real requests

    @property
    def api_key(self) -> Optional[str]:
        """API key read from the ``api_key_env`` variable; None if unset or not applicable."""
        if not self.api_key_env:
            return None
        return os.getenv(self.api_key_env)

    @property
    def is_configured(self) -> bool:
        """
        True if the provider is usable from a credentials standpoint.

        Keyless providers (Ollama, or any provider declared with an empty
        ``api_key_env``) are always configured; others need their key set.
        """
        if self.name == "ollama" or not self.api_key_env:
            return True  # local/keyless providers need no API key
        return bool(self.api_key)

    @property
    def error_rate(self) -> float:
        """Lifetime fraction of recorded requests that failed (0.0 when none recorded)."""
        if self.request_count == 0:
            return 0.0
        return self.error_count / self.request_count

    def score(self, strategy: str = "balanced") -> float:
        """
        Lower score = better provider.

        strategy: 'latency' | 'cost' | 'balanced'; any other value is treated
        as 'balanced'. Unhealthy or unconfigured providers score ``inf`` so
        they are never preferred.
        """
        if not self.healthy or not self.is_configured:
            return float("inf")
        latency_score = self.avg_latency_ms / 1000.0  # milliseconds -> seconds
        cost_score = self.cost_per_1k_tokens * 100  # bring cost onto a similar scale
        # Error rate (0..1) scaled by 500 so failures dominate latency/cost terms.
        error_penalty = self.error_rate * 500
        if strategy == "latency":
            return latency_score + error_penalty
        if strategy == "cost":
            return cost_score + error_penalty
        # balanced (default)
        return (latency_score * 0.5) + (cost_score * 0.5) + error_penalty
# ── Default provider catalogue ────────────────────────────────────────────────
def build_default_providers() -> list[Provider]:
    """
    Build the default provider catalogue: OpenAI, Gemini and a local Ollama.

    BIG_MODEL / SMALL_MODEL (defaults "gpt-4.1" / "gpt-4.1-mini") are given to
    the provider whose family name ("gpt" / "gemini") appears in the model
    string, compared case-insensitively; every other provider falls back to
    its own default model. Ollama takes the configured model only when it
    names neither a GPT nor a Gemini model, otherwise "llama3:8b".
    OLLAMA_BASE_URL defaults to http://localhost:11434; a trailing slash is
    ignored.
    """
    big = os.getenv("BIG_MODEL", "gpt-4.1")
    small = os.getenv("SMALL_MODEL", "gpt-4.1-mini")
    # rstrip so "http://host:11434/" does not produce "...//api/tags".
    ollama_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434").rstrip("/")

    def in_family(model: str, family: str) -> bool:
        # Case-insensitive so names like "GPT-4.1" or "Gemini-2.5-Pro" match.
        return family in model.lower()

    def pick(model: str, family: str, fallback: str) -> str:
        # Use the configured model only if it belongs to this provider's family.
        return model if in_family(model, family) else fallback

    def pick_local(model: str, fallback: str) -> str:
        # Ollama gets the configured name unless it is a GPT or Gemini model.
        if in_family(model, "gpt") or in_family(model, "gemini"):
            return fallback
        return model

    return [
        Provider(
            name="openai",
            ping_url="https://api.openai.com/v1/models",
            api_key_env="OPENAI_API_KEY",
            cost_per_1k_tokens=0.002,
            big_model=pick(big, "gpt", "gpt-4.1"),
            small_model=pick(small, "gpt", "gpt-4.1-mini"),
        ),
        Provider(
            name="gemini",
            ping_url="https://generativelanguage.googleapis.com/v1/models",
            api_key_env="GEMINI_API_KEY",
            cost_per_1k_tokens=0.0005,
            big_model=pick(big, "gemini", "gemini-2.5-pro"),
            small_model=pick(small, "gemini", "gemini-2.0-flash"),
        ),
        Provider(
            name="ollama",
            ping_url=f"{ollama_url}/api/tags",
            api_key_env="",
            cost_per_1k_tokens=0.0,  # free — local
            big_model=pick_local(big, "llama3:8b"),
            small_model=pick_local(small, "llama3:8b"),
        ),
    ]
# ── Smart Router ──────────────────────────────────────────────────────────────
class SmartRouter:
    """
    Intelligently routes Claude Code API requests to the best
    available LLM provider based on latency, cost, and health.

    The router only selects a provider; the caller performs the request and
    reports the outcome through ``record_result`` so latency averages and
    health state stay current.
    """

    def __init__(
        self,
        providers: Optional[list[Provider]] = None,
        strategy: Optional[str] = None,
        fallback_enabled: Optional[bool] = None,
    ):
        """
        Args:
            providers: providers to route between. ``None`` uses
                ``build_default_providers()``; an explicit list (even an empty
                one) is used as given.
            strategy: 'latency' | 'cost' | 'balanced'. Falsy values fall back
                to the ROUTER_STRATEGY env var, then 'balanced'.
            fallback_enabled: falls back to ROUTER_FALLBACK ("true",
                case-insensitive, enables it; default "true"). Stored for
                callers — this class does not consult it itself.
        """
        # `is not None` instead of `or`: an explicit empty list must not be
        # silently replaced with the default catalogue.
        self.providers = (
            providers if providers is not None else build_default_providers()
        )
        self.strategy = strategy or os.getenv("ROUTER_STRATEGY", "balanced")
        self.fallback_enabled = (
            fallback_enabled
            if fallback_enabled is not None
            else os.getenv("ROUTER_FALLBACK", "true").lower() == "true"
        )
        self._initialized = False
        # Strong references to pending re-check tasks. The event loop keeps only
        # weak references to tasks, so an unreferenced task may be
        # garbage-collected before it runs to completion.
        self._recheck_tasks: set[asyncio.Task] = set()

    # ── Initialization ────────────────────────────────────────────────────────

    async def initialize(self) -> None:
        """Ping all providers concurrently and build initial latency scores."""
        logger.info("SmartRouter: benchmarking providers...")
        await asyncio.gather(
            *[self._ping_provider(p) for p in self.providers],
            return_exceptions=True,
        )
        available = self._available_providers()
        logger.info(
            "SmartRouter ready. Available providers: %s",
            [p.name for p in available],
        )
        if not available:
            logger.warning(
                "SmartRouter: no providers available! "
                "Check your API keys in .env"
            )
        self._initialized = True

    async def _ping_provider(self, provider: Provider) -> None:
        """
        Measure latency to a provider's health endpoint and set its health.

        Unconfigured providers are marked unhealthy without a request. HTTP
        200/400/401/403 count as reachable (healthy); any other status or a
        request exception marks the provider unhealthy. Never raises.
        """
        if not provider.is_configured:
            provider.healthy = False
            logger.debug("SmartRouter: %s skipped — no API key", provider.name)
            return
        headers: dict[str, str] = {}
        if provider.api_key:
            headers["Authorization"] = f"Bearer {provider.api_key}"
        start = time.monotonic()
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                resp = await client.get(provider.ping_url, headers=headers)
            elapsed_ms = (time.monotonic() - start) * 1000
            if resp.status_code in (200, 400, 401, 403):
                # 400/401/403 means reachable, just possibly bad key.
                # We still mark healthy for routing purposes.
                provider.healthy = True
                provider.latency_ms = elapsed_ms
                # Resets the rolling average to the fresh measurement.
                provider.avg_latency_ms = elapsed_ms
                logger.info(
                    "SmartRouter: %s OK (%.0fms, status=%s)",
                    provider.name, elapsed_ms, resp.status_code,
                )
            else:
                provider.healthy = False
                logger.warning(
                    "SmartRouter: %s unhealthy (status=%s)",
                    provider.name, resp.status_code,
                )
        except Exception as e:  # deliberate best-effort health check boundary
            provider.healthy = False
            logger.warning("SmartRouter: %s unreachable — %s", provider.name, e)

    # ── Routing logic ─────────────────────────────────────────────────────────

    def _available_providers(
        self, exclude: Optional[set[str]] = None
    ) -> list[Provider]:
        """Return healthy, configured providers whose name is not in *exclude*."""
        skip = exclude or set()
        return [
            p for p in self.providers
            if p.healthy and p.is_configured and p.name not in skip
        ]

    def select_provider(self, is_large_request: bool = False) -> Optional[Provider]:
        """
        Pick the best available provider for this request.

        ``is_large_request`` is accepted for API compatibility but does not
        affect the choice (model size is chosen in get_model_for_provider).
        Returns None if no providers are available.
        """
        available = self._available_providers()
        if not available:
            return None
        return min(available, key=lambda p: p.score(self.strategy))

    def get_model_for_provider(
        self, provider: Provider, claude_model: str
    ) -> str:
        """
        Map a Claude model name to the provider's actual model.

        Names containing "opus", "sonnet", "large" or "big" (case-insensitive)
        map to ``provider.big_model``; everything else to ``small_model``.
        """
        is_large = any(
            keyword in claude_model.lower()
            for keyword in ["opus", "sonnet", "large", "big"]
        )
        return provider.big_model if is_large else provider.small_model

    def is_large_request(self, messages: list[dict]) -> bool:
        """Estimate if this is a large request: more than 2000 chars of content in total."""
        total_chars = sum(
            len(str(m.get("content", ""))) for m in messages
        )
        return total_chars > 2000

    def _update_latency(self, provider: Provider, duration_ms: float) -> None:
        """Exponential moving average update for latency tracking."""
        alpha = 0.3  # weight given to the newest observation
        provider.avg_latency_ms = (
            alpha * duration_ms + (1 - alpha) * provider.avg_latency_ms
        )

    # ── Main routing entry point ──────────────────────────────────────────────

    async def route(
        self,
        messages: list[dict],
        claude_model: str = "claude-sonnet",
        attempt: int = 0,
        exclude_providers: Optional[list[str]] = None,
    ) -> dict:
        """
        Route a request to the best provider.

        Initializes the router on first use. ``attempt`` is only logged;
        ``exclude_providers`` lets a caller retry while skipping providers
        that already failed.

        Returns a dict with the routing decision:
            {
                "provider": provider name,
                "model": actual model to use,
                "api_key": API key for the provider ("none" if unset),
                "provider_object": the selected Provider instance,
            }
        Raises RuntimeError if no providers are available.
        """
        if not self._initialized:
            await self.initialize()
        large = self.is_large_request(messages)
        available = self._available_providers(set(exclude_providers or []))
        if not available:
            raise RuntimeError(
                "SmartRouter: no providers available. "
                "Check your API keys and provider health."
            )
        provider = min(available, key=lambda p: p.score(self.strategy))
        model = self.get_model_for_provider(provider, claude_model)
        # Lazy %-formatting: this runs per request and debug is usually disabled.
        logger.debug(
            "SmartRouter: routing to %s/%s (strategy=%s, large=%s, attempt=%s)",
            provider.name, model, self.strategy, large, attempt,
        )
        return {
            "provider": provider.name,
            "model": model,
            "api_key": provider.api_key or "none",
            "provider_object": provider,
        }

    async def record_result(
        self,
        provider_name: str,
        success: bool,
        duration_ms: float,
    ) -> None:
        """
        Record the outcome of a request.

        Called after each proxied request to update provider scores. Unknown
        provider names are ignored. Must be awaited inside a running event
        loop, since it may schedule a background re-check.
        """
        provider = next(
            (p for p in self.providers if p.name == provider_name), None
        )
        if provider is None:
            return
        provider.request_count += 1
        if success:
            self._update_latency(provider, duration_ms)
            return
        provider.error_count += 1
        # Mark unhealthy once at least 3 requests are recorded and the lifetime
        # error rate exceeds 70%. Skip providers already unhealthy so failures of
        # in-flight requests don't schedule duplicate re-checks.
        if (
            provider.healthy
            and provider.request_count >= 3
            and provider.error_rate > 0.7
        ):
            logger.warning(
                "SmartRouter: %s error rate high (%.0f%%), marking unhealthy",
                provider_name, provider.error_rate * 100,
            )
            provider.healthy = False
            # Schedule a re-check after 60s and keep a reference until it ends.
            task = asyncio.create_task(self._recheck_provider(provider, delay=60))
            self._recheck_tasks.add(task)
            task.add_done_callback(self._recheck_tasks.discard)

    async def _recheck_provider(
        self, provider: Provider, delay: float = 60
    ) -> None:
        """
        Re-ping a provider after *delay* seconds and restore it if healthy.

        NOTE(review): if this ping fails the provider stays unhealthy and
        nothing schedules another re-check; its lifetime error counters are
        also kept, so it keeps a large score penalty after recovery.
        """
        await asyncio.sleep(delay)
        await self._ping_provider(provider)
        if provider.healthy:
            logger.info(
                "SmartRouter: %s recovered, re-adding to pool", provider.name
            )

    # ── Status report ─────────────────────────────────────────────────────────

    def status(self) -> list[dict]:
        """Return current provider status for monitoring (one dict per provider)."""
        return [
            {
                "provider": p.name,
                "healthy": p.healthy,
                "configured": p.is_configured,
                "latency_ms": round(p.avg_latency_ms, 1),
                "cost_per_1k": p.cost_per_1k_tokens,
                "requests": p.request_count,
                "errors": p.error_count,
                "error_rate": f"{p.error_rate:.1%}",
                "score": round(p.score(self.strategy), 3)
                if p.healthy and p.is_configured
                else "N/A",
            }
            for p in self.providers
        ]