From 1032fa24c7aec8529ff6cdd892c1d2ef4642abf0 Mon Sep 17 00:00:00 2001
From: GitHub Actions
Date: Mon, 24 Nov 2025 21:19:33 +0000
Subject: [PATCH] Release 2025-11-24-21-19

---
 .github/scripts/sync-public.sh              |   2 +-
 README.md                                   |   2 +-
 neurons/miner/agents/example.py             | 410 --------------------
 neurons/miner/agents/hello_world.py         |   6 -
 neurons/miner/agents/invalid_agent.py       |  15 -
 neurons/miner/gateway/app.py                |  79 ++--
 neurons/validator/models/chutes.py          | 145 +++++++
 neurons/validator/models/desearch.py        |  47 +++
 neurons/validator/models/numinous_client.py |  30 +-
 9 files changed, 272 insertions(+), 464 deletions(-)
 delete mode 100644 neurons/miner/agents/example.py
 delete mode 100644 neurons/miner/agents/hello_world.py
 delete mode 100644 neurons/miner/agents/invalid_agent.py

diff --git a/.github/scripts/sync-public.sh b/.github/scripts/sync-public.sh
index 0294657..e22f996 100755
--- a/.github/scripts/sync-public.sh
+++ b/.github/scripts/sync-public.sh
@@ -4,7 +4,7 @@ set -o nounset # Treat unset variables as an error and exit immediately
 set -o pipefail
 
 # Settings
-PUBLIC_REPO="https://github.com/amedeo-gigaver/infinite_games.git"
+PUBLIC_REPO="https://github.com/numinouslabs/numinous.git"
 PRIVATE_REPO="https://github.com/infinite-mech/infinite_games.git"
 RELEASE_DATE=$(date +"%Y-%m-%d-%H-%M")
 NEW_BRANCH="sync-main-$RELEASE_DATE"
diff --git a/README.md b/README.md
index 40fe819..ebe3935 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 
 
 
-[Leaderboard](https://app.hex.tech/1644b22a-abe5-4113-9d5f-3ad05e4a8de7/app/Numinous-031erYRYSssIrH3W3KcyHg/latest) • [Website](https://numinouslabs.io/) • [X](https://x.com/numinous_ai) •
+[Discord](https://discord.gg/qKPeYPc3) • [Dashboard](https://app.hex.tech/1644b22a-abe5-4113-9d5f-3ad05e4a8de7/app/Numinous-031erYRYSssIrH3W3KcyHg/latest) • [Website](https://numinouslabs.io/) • [Twitter](https://x.com/numinous_ai) •
 [Network](https://taostats.io/subnets/6/chart)
 
 ---
diff --git a/neurons/miner/agents/example.py b/neurons/miner/agents/example.py
deleted file mode 100644
index 2eb1129..0000000
--- a/neurons/miner/agents/example.py
+++ /dev/null
@@ -1,410 +0,0 @@
-import asyncio
-import os
-import time
-from datetime import datetime
-
-import httpx
-from pydantic import BaseModel
-
-# =============================================================================
-# ENVIRONMENT SETUP
-# =============================================================================
-
-# Fetch run ID from environment
-RUN_ID = os.getenv("RUN_ID")
-if not RUN_ID:
-    raise ValueError("RUN_ID environment variable is required but not set")
-
-# Fetch proxy URL from environment
-PROXY_URL = os.getenv("SANDBOX_PROXY_URL", "http://sandbox_proxy")
-CHUTES_URL = f"{PROXY_URL}/api/gateway/chutes"
-DESEARCH_URL = f"{PROXY_URL}/api/gateway/desearch"
-
-
-# =============================================================================
-# CONSTANTS
-# =============================================================================
-
-MIN_INSTANCES = 5
-LLMS = [
-    "tngtech/DeepSeek-TNG-R1T2-Chimera",  # 671B Tri-Mind (V3+R1+R1-0528 hybrid), fixed token
-    "deepseek-ai/DeepSeek-V3.1",  # 685B MoE, general-purpose powerhouse
-    "zai-org/GLM-4.5",  # Faster tool-calling specialist
-    "openai/gpt-oss-120b",  # 120B open-source fallback
-]
-
-# Retry configuration
-MAX_RETRIES = 3
-BASE_BACKOFF = 1.5  # seconds
-
-
-# =============================================================================
-# MODELS
-# =============================================================================
-
-
-class AgentData(BaseModel):
-    event_id: str
-    title: str
-    description: str
-    cutoff: datetime
-    metadata: dict
-
-
-class ChuteModelStatus(BaseModel):
-    chute_id: str
-    name: str
-    active_instance_count: int
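
For reference, a minimal sketch (illustrative, not part of this patch) of the event payload the deleted example agent validated. Pydantic v2 coerces ISO-8601 strings into the cutoff datetime field, so a raw JSON dict validates directly; field values here are invented:

    event = AgentData.model_validate({
        "event_id": "evt-123",
        "title": "Sample event",
        "description": "Resolves YES if ...",
        "cutoff": "2025-12-31T00:00:00Z",  # ISO string coerced to datetime
        "metadata": {},
    })
    assert event.cutoff.year == 2025
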
-
-
-# =============================================================================
-# TOOL PROMPTS
-# =============================================================================
-
-
-def build_research_prompt(event: AgentData) -> str:
-    """Build targeted research prompt for Desearch."""
-    return f"""Search for recent information to help forecast this event:
-"{event.title}"
-
-Focus on:
-- Latest news, announcements, or developments related to this topic
-- Historical patterns or precedents
-- Expert opinions or market sentiment
-- Any relevant data, statistics, or indicators
-
-Event description: {event.description}
-Forecast deadline: {event.cutoff.strftime('%Y-%m-%d')}"""
-
-
-def build_forecast_messages(event: AgentData, context: str) -> list[dict]:
-    """Build LLM messages for forecasting."""
-    cutoff_date = event.cutoff.strftime("%Y-%m-%d %H:%M UTC")
-
-    system_prompt = """You are an expert forecaster specializing in probabilistic predictions.
-Your task is to estimate the likelihood of binary events (YES/NO outcomes).
-
-Key principles:
-- Consider base rates and historical precedents
-- Weigh evidence quality and recency
-- Account for uncertainty and missing information
-- Avoid extreme predictions (0 or 1) unless evidence is overwhelming
-- Use the full probability range: 0.0 (impossible) to 1.0 (certain)"""
-
-    user_prompt = f"""**Event to Forecast:**
-{event.title}
-
-**Full Description:**
-{event.description}
-
-**Forecast Deadline:** {cutoff_date}
-
-**Research Context:**
-{context if context else "No additional research context available. Base your forecast on the event description and general knowledge."}
-
-**Your Task:**
-Estimate the probability (0.0 to 1.0) that this event will occur or resolve as YES by the deadline.
-
-Consider:
-1. What is the base rate for similar events?
-2. What specific evidence supports or contradicts this outcome?
-3. What uncertainties or unknowns remain?
-4. How confident are you in available information?
-
-**Required Output Format:**
-PREDICTION: [number between 0.0 and 1.0]
-REASONING: [2-4 sentences explaining your probability estimate, key factors considered, and main uncertainties]"""
-
-    return [
-        {"role": "system", "content": system_prompt},
-        {"role": "user", "content": user_prompt},
-    ]
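
The strict PREDICTION/REASONING output contract is what keeps the later parsing step line-based. A well-formed model reply (content invented for illustration) looks like:

    PREDICTION: 0.72
    REASONING: Recent announcements make the outcome likely; the main uncertainty is the tight deadline.

parse_llm_response, further down in this file, extracts exactly these two prefixed lines.
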
-
-
-# =============================================================================
-# HELPER FUNCTIONS
-# =============================================================================
-
-
-async def fetch_chutes_active_models(min_instances: int = 5) -> list[ChuteModelStatus]:
-    try:
-        async with httpx.AsyncClient(timeout=10.0) as client:
-            response = await client.get(f"{CHUTES_URL}/status")
-            response.raise_for_status()
-
-            chutes_statuses = [ChuteModelStatus.model_validate(item) for item in response.json()]
-            filtered_chutes = [
-                chute for chute in chutes_statuses if chute.active_instance_count >= min_instances
-            ]
-            return filtered_chutes
-
-    except Exception as e:
-        print(f"[WARNING] Failed to fetch chutes status: {e}")
-        return []
-
-
-def get_available_models(all_models: list[str], active_chutes: list[ChuteModelStatus]) -> list[str]:
-    active_names = {chute.name for chute in active_chutes}
-    available = [model for model in all_models if model in active_names]
-
-    if available:
-        print(f"[INFO] Available models: {available}")
-    else:
-        print("[WARNING] No preferred models available. Will try all from list.")
-        available = all_models  # Fallback: try anyway
-
-    return available
-
-
-async def retry_with_backoff(func, max_retries: int = MAX_RETRIES):
-    for attempt in range(max_retries):
-        try:
-            return await func()
-        except httpx.TimeoutException as e:
-            if attempt < max_retries - 1:
-                delay = BASE_BACKOFF ** (attempt + 1)
-                print(f"[RETRY] Timeout, retrying in {delay}s...")
-                await asyncio.sleep(delay)
-            else:
-                raise Exception(f"Max retries exceeded: {e}")
-        except httpx.HTTPStatusError as e:
-            try:
-                error_detail = e.response.json().get("detail", str(e))
-            except Exception:
-                error_detail = e.response.text if hasattr(e.response, "text") else str(e)
-
-            if e.response.status_code == 429:  # Rate limit
-                if attempt < max_retries - 1:
-                    delay = BASE_BACKOFF ** (attempt + 1)
-                    print(f"[RETRY] Rate limited (429), retrying in {delay}s...")
-                    await asyncio.sleep(delay)
-                else:
-                    raise Exception(
-                        f"Rate limit exceeded after {max_retries} retries: {error_detail}"
-                    )
-            else:
-                # Don't retry other HTTP errors - re-raise with detail
-                raise Exception(f"HTTP {e.response.status_code}: {error_detail}")
-        except Exception:
-            # Unknown error - don't retry
-            raise
-
-
-def clip_probability(prediction: float) -> float:
-    return max(0.0, min(1.0, prediction))
-
-
-# =============================================================================
-# PHASE 1: RESEARCH WITH DESEARCH
-# =============================================================================
-
-
-async def research_event(event: AgentData) -> str:
-    print("[PHASE 1] Researching event via Desearch...")
-
-    try:
-
-        async def desearch_call():
-            async with httpx.AsyncClient(timeout=30.0) as client:
-                payload = {
-                    "prompt": build_research_prompt(event),
-                    "model": "NOVA",
-                    "tools": ["web", "reddit", "wikipedia"],
-                    "count": 5,
-                    "run_id": RUN_ID,
-                }
-                response = await client.post(f"{DESEARCH_URL}/ai/search", json=payload)
-                response.raise_for_status()
-                return response.json()
-
-        result = await retry_with_backoff(desearch_call)
-
-        # Extract context from response
-        context = result.get("completion", "")
-        if context:
-            context = context[:5000]  # Limit size
-            # Show preview of research findings
-            preview = context[:300].replace("\n", " ")
-            print(f"[PHASE 1] Research complete. Context length: {len(context)}")
-            print(f"[PHASE 1] Preview: {preview}...")
-        else:
-            print("[PHASE 1] No context in response")
-
-        return context
-
-    except Exception as e:
-        print(f"[PHASE 1] Research failed: {e}. Continuing without context.")
-        return ""
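
With MAX_RETRIES = 3 and BASE_BACKOFF = 1.5, the exponential backoff in retry_with_backoff sleeps 1.5 ** 1 = 1.5s after the first failure and 1.5 ** 2 = 2.25s after the second; the third failure raises. A quick check (illustrative only):

    for attempt in range(MAX_RETRIES - 1):
        print(BASE_BACKOFF ** (attempt + 1))  # 1.5, then 2.25 (seconds)
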
-
-
-# =============================================================================
-# PHASE 2: FORECAST WITH LLM (CHUTES)
-# =============================================================================
-#
-# Error Handling Strategy for 503 (Service Unavailable):
-#
-# 503 can mean two things:
-# 1. No instances available (cold model) - retrying won't help immediately
-# 2. Model overloaded/restarting - short retry might work
-#
-# Solution: Try SHORT retry (2 attempts, 1.5s backoff) before swapping models
-# This handles overloaded scenarios without wasting too much time on cold models
-
-
-async def call_llm(model: str, messages: list[dict]) -> str:
-    async with httpx.AsyncClient(timeout=45.0) as client:
-        payload = {
-            "model": model,
-            "messages": messages,
-            "run_id": RUN_ID,
-        }  # Must ALWAYS include RUN_ID on body
-        response = await client.post(
-            f"{CHUTES_URL}/chat/completions",
-            json=payload,
-        )
-        response.raise_for_status()
-
-        data = response.json()
-        return data["choices"][0]["message"]["content"]
-
-
-def parse_llm_response(response_text: str) -> tuple[float, str]:
-    try:
-        lines = response_text.strip().split("\n")
-        prediction = 0.5
-        reasoning = "No reasoning provided."
-
-        for line in lines:
-            if line.startswith("PREDICTION:"):
-                # Clip probability to [0.0, 1.0]
-                prediction = clip_probability(float(line.replace("PREDICTION:", "").strip()))
-            elif line.startswith("REASONING:"):
-                reasoning = line.replace("REASONING:", "").strip()
-
-        return prediction, reasoning
-
-    except Exception as e:
-        print(f"[WARNING] Failed to parse LLM response: {e}")
-        return 0.5, "Failed to parse LLM response."
-
-
-async def forecast_with_llm(event: AgentData, context: str, available_models: list[str]) -> dict:
-    print("[PHASE 2] Generating forecast with LLM...")
-
-    messages = build_forecast_messages(event, context)
-
-    # Try each available model in order
-    for i, model in enumerate(available_models):
-        print(f"[PHASE 2] Trying model {i+1}/{len(available_models)}: {model}")
-
-        max_503_retries = 2
-        backoff_503 = 1.5
-
-        for attempt in range(max_503_retries):
-            try:
-
-                async def llm_call():
-                    return await call_llm(model, messages)
-
-                response_text = await retry_with_backoff(llm_call)
-                prediction, reasoning = parse_llm_response(response_text)
-
-                print(f"[PHASE 2] Success with {model}: prediction={prediction}")
-                return {
-                    "event_id": event.event_id,
-                    "prediction": prediction,
-                    "reasoning": reasoning,
-                }
-
-            except httpx.HTTPStatusError as e:
-                try:
-                    error_detail = e.response.json().get("detail", "")
-                except Exception:
-                    error_detail = ""
-
-                if e.response.status_code == 503:
-                    if attempt < max_503_retries - 1:
-                        delay = backoff_503 ** (attempt + 1)
-                        print(f"[PHASE 2] Model {model} unavailable (503). Retrying in {delay}s...")
-                        await asyncio.sleep(delay)
-                        continue  # Retry same model
-                    else:
-                        # After retries, swap to next model
-                        detail_msg = f": {error_detail}" if error_detail else ""
-                        print(
-                            f"[PHASE 2] Model {model} still unavailable after {max_503_retries} retries{detail_msg}. Trying next model..."
-                        )
-                        break
-                else:
-                    detail_msg = f": {error_detail}" if error_detail else ""
-                    print(
-                        f"[PHASE 2] HTTP error {e.response.status_code} with {model}{detail_msg}. Trying next model..."
-                    )
-                    break
-
-            except Exception as e:
-                print(f"[PHASE 2] Error with {model}: {e}. Trying next model...")
-                break
-
-    print("[PHASE 2] All models failed. Returning fallback prediction.")
-    return {
-        "event_id": event.event_id,
-        "prediction": 0.5,
-        "reasoning": "Unable to generate forecast due to model availability issues. Returning neutral prediction.",
-    }
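
For instance (reply text invented), parse_llm_response reads the two prefixed lines and falls back to the neutral 0.5 prior whenever the PREDICTION line is missing or malformed:

    text = "PREDICTION: 0.72\nREASONING: Base rates favor YES; evidence is recent."
    prediction, reasoning = parse_llm_response(text)
    # prediction == 0.72, reasoning == "Base rates favor YES; evidence is recent."
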
-
-
-# =============================================================================
-# MAIN AGENT
-# =============================================================================
-
-
-async def run_agent(event: AgentData) -> dict:
-    """
-    Two-phase forecasting agent:
-    1. Research: Gather context using Desearch
-    2. Forecast: Generate prediction using LLM (Chutes)
-
-    Demonstrates:
-    - Model availability checking
-    - Retry with exponential backoff
-    - Model swapping on 503 errors
-    - Graceful fallback
-    """
-    start_time = time.time()
-
-    # Check which models are available
-    active_chutes = await fetch_chutes_active_models(min_instances=MIN_INSTANCES)
-    available_models = get_available_models(LLMS, active_chutes)
-
-    if not available_models:
-        print("[WARNING] No models available. Will attempt with preferred list anyway.")
-        available_models = LLMS
-
-    # Phase 1: Research
-    context = await research_event(event)
-
-    # Phase 2: Forecast
-    result = await forecast_with_llm(event, context, available_models)
-
-    elapsed = time.time() - start_time
-    print(f"[AGENT] Complete in {elapsed:.2f}s")
-
-    return result
-
-
-def agent_main(event_data: dict) -> dict:
-    """
-    Entry point for the forecasting agent.
-
-    Args:
-        event_data: Event information dict
-
-    Returns:
-        dict with keys: event_id, prediction, reasoning
-    """
-    event = AgentData.model_validate(event_data)
-    print(f"\n[AGENT] Running forecast for event: {event.event_id}")
-    print(f"[AGENT] Title: {event.title}")
-
-    return asyncio.run(run_agent(event))
diff --git a/neurons/miner/agents/hello_world.py b/neurons/miner/agents/hello_world.py
deleted file mode 100644
index f4eeac4..0000000
--- a/neurons/miner/agents/hello_world.py
+++ /dev/null
@@ -1,6 +0,0 @@
-def agent_main(event_data: dict) -> dict:
-    return {
-        "event_id": event_data["event_id"],
-        "probability": 0.5,
-        "reasoning": "I think the probability of the event is 0.5",
-    }
diff --git a/neurons/miner/agents/invalid_agent.py b/neurons/miner/agents/invalid_agent.py
deleted file mode 100644
index 6f48ef9..0000000
--- a/neurons/miner/agents/invalid_agent.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import random
-
-import bittensor as bt  # Non-supported library
-
-
-def agent(event_data: dict) -> dict:  # Wrong entrypoint name -> "agent_main"
-    # Dummy usage
-    keypair = bt.Keypair.create_from_uri("//Alice")
-    forecast = random.random()
-    reasoning = f"I used random.random() to predict {forecast}. Keypair: {keypair.ss58_address}"
-    return {
-        "event_id": event_data["event_id"],
-        "probability": forecast,
-        "reasoning": reasoning,
-    }
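
Taken together, the deleted examples documented the contract the sandbox runner expects: a module-level agent_main that takes the event dict and returns a forecast keyed by event_id. An illustrative call (all values invented):

    result = agent_main({
        "event_id": "evt-123",
        "title": "Will X ship before 2026?",
        "description": "Resolves YES if ...",
        "cutoff": "2025-12-31T23:59:00Z",
        "metadata": {},
    })
    # -> {"event_id": "evt-123", "prediction": 0.63, "reasoning": "..."}
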
diff --git a/neurons/miner/gateway/app.py b/neurons/miner/gateway/app.py
index 45e6ef7..ae93a7d 100644
--- a/neurons/miner/gateway/app.py
+++ b/neurons/miner/gateway/app.py
@@ -9,20 +9,11 @@
 from neurons.miner.gateway.error_handler import handle_provider_errors
 from neurons.miner.gateway.providers.chutes import ChutesClient
 from neurons.miner.gateway.providers.desearch import DesearchClient
-from neurons.validator.models.chutes import ChutesCompletion, ChuteStatus
-from neurons.validator.models.desearch import (
-    AISearchResponse,
-    WebCrawlResponse,
-    WebLinksResponse,
-    WebSearchResponse,
-)
-from neurons.validator.models.numinous_client import (
-    ChutesInferenceRequest,
-    DesearchAISearchRequest,
-    DesearchWebCrawlRequest,
-    DesearchWebLinksRequest,
-    DesearchWebSearchRequest,
-)
+from neurons.validator.models import numinous_client as models
+from neurons.validator.models.chutes import ChuteStatus
+from neurons.validator.models.chutes import calculate_cost as calculate_chutes_cost
+from neurons.validator.models.desearch import DesearchEndpoint
+from neurons.validator.models.desearch import calculate_cost as calculate_desearch_cost
 
 logger = logging.getLogger(__name__)
 
@@ -45,10 +36,10 @@ async def health_check():
     return {"status": "healthy", "service": "API Gateway"}
 
 
-@gateway_router.post("/chutes/chat/completions", response_model=ChutesCompletion)
+@gateway_router.post("/chutes/chat/completions", response_model=models.GatewayChutesCompletion)
 @cached_gateway_call
 @handle_provider_errors("Chutes")
-async def chutes_chat_completion(request: ChutesInferenceRequest) -> ChutesCompletion:
+async def chutes_chat_completion(request: models.ChutesInferenceRequest) -> models.GatewayChutesCompletion:
     api_key = os.getenv("CHUTES_API_KEY")
     if not api_key:
         raise HTTPException(
@@ -58,7 +49,7 @@ async def chutes_chat_completion(request: ChutesInferenceRequest) -> ChutesCompl
     client = ChutesClient(api_key=api_key)
 
     messages = [msg.model_dump() for msg in request.messages]
-    return await client.chat_completion(
+    result = await client.chat_completion(
         model=request.model,
         messages=messages,
         temperature=request.temperature,
@@ -68,6 +59,10 @@ async def chutes_chat_completion(request: ChutesInferenceRequest) -> ChutesCompl
         **(request.model_extra or {}),
     )
 
+    return models.GatewayChutesCompletion(
+        **result.model_dump(), cost=calculate_chutes_cost(request.model, result)
+    )
+
 
 @gateway_router.get("/chutes/status", response_model=list[ChuteStatus])
 @handle_provider_errors("Chutes")
@@ -83,10 +78,12 @@ async def get_chutes_status() -> list[ChuteStatus]:
     return await client.get_chutes_status()
 
 
-@gateway_router.post("/desearch/ai/search", response_model=AISearchResponse)
+@gateway_router.post("/desearch/ai/search", response_model=models.GatewayDesearchAISearchResponse)
 @cached_gateway_call
 @handle_provider_errors("Desearch")
-async def desearch_ai_search(request: DesearchAISearchRequest) -> AISearchResponse:
+async def desearch_ai_search(
+    request: models.DesearchAISearchRequest,
+) -> models.GatewayDesearchAISearchResponse:
     api_key = os.getenv("DESEARCH_API_KEY")
     if not api_key:
         raise HTTPException(
@@ -95,7 +92,7 @@ async def desearch_ai_search(request: DesearchAISearchRequest) -> AISearchRespon
         )
 
     client = DesearchClient(api_key=api_key)
-    return await client.ai_search(
+    result = await client.ai_search(
         prompt=request.prompt,
         model=request.model,
         tools=request.tools,
@@ -105,13 +102,18 @@ async def desearch_ai_search(request: DesearchAISearchRequest) -> AISearchRespon
         count=request.count,
     )
 
+    return models.GatewayDesearchAISearchResponse(
+        **result.model_dump(),
+        cost=calculate_desearch_cost(DesearchEndpoint.AI_SEARCH, request.model),
+    )
+
 
-@gateway_router.post("/desearch/ai/links", response_model=WebLinksResponse)
+@gateway_router.post("/desearch/ai/links", response_model=models.GatewayDesearchWebLinksResponse)
 @cached_gateway_call
 @handle_provider_errors("Desearch")
 async def desearch_web_links_search(
-    request: DesearchWebLinksRequest,
-) -> WebLinksResponse:
+    request: models.DesearchWebLinksRequest,
+) -> models.GatewayDesearchWebLinksResponse:
     api_key = os.getenv("DESEARCH_API_KEY")
     if not api_key:
         raise HTTPException(
@@ -120,15 +122,21 @@ async def desearch_web_links_search(
         )
 
     client = DesearchClient(api_key=api_key)
-    return await client.web_links_search(
+    result = await client.web_links_search(
         prompt=request.prompt, model=request.model, tools=request.tools, count=request.count
     )
 
+    return models.GatewayDesearchWebLinksResponse(
+        **result.model_dump(),
+        cost=calculate_desearch_cost(DesearchEndpoint.AI_WEB_SEARCH, request.model),
+    )
 
-@gateway_router.post("/desearch/web/search", response_model=WebSearchResponse)
+@gateway_router.post("/desearch/web/search", response_model=models.GatewayDesearchWebSearchResponse)
 @cached_gateway_call
 @handle_provider_errors("Desearch")
-async def desearch_web_search(request: DesearchWebSearchRequest) -> WebSearchResponse:
+async def desearch_web_search(
+    request: models.DesearchWebSearchRequest,
+) -> models.GatewayDesearchWebSearchResponse:
     api_key = os.getenv("DESEARCH_API_KEY")
     if not api_key:
         raise HTTPException(
@@ -137,15 +145,21 @@ async def desearch_web_search(request: DesearchWebSearchRequest) -> WebSearchRes
         )
 
     client = DesearchClient(api_key=api_key)
-    return await client.web_search(
+    result = await client.web_search(
         query=request.query, num_results=request.num, start=request.start
     )
 
+    return models.GatewayDesearchWebSearchResponse(
+        **result.model_dump(),
+        cost=calculate_desearch_cost(DesearchEndpoint.WEB_SEARCH),
+    )
 
-@gateway_router.post("/desearch/web/crawl", response_model=WebCrawlResponse)
+@gateway_router.post("/desearch/web/crawl", response_model=models.GatewayDesearchWebCrawlResponse)
 @cached_gateway_call
 @handle_provider_errors("Desearch")
-async def desearch_web_crawl(request: DesearchWebCrawlRequest) -> WebCrawlResponse:
+async def desearch_web_crawl(
+    request: models.DesearchWebCrawlRequest,
+) -> models.GatewayDesearchWebCrawlResponse:
    api_key = os.getenv("DESEARCH_API_KEY")
     if not api_key:
         raise HTTPException(
@@ -154,7 +168,12 @@ async def desearch_web_crawl(request: DesearchWebCrawlRequest) -> WebCrawlRespon
         )
 
     client = DesearchClient(api_key=api_key)
-    return await client.web_crawl(url=request.url)
+    result = await client.web_crawl(url=request.url)
+
+    return models.GatewayDesearchWebCrawlResponse(
+        **result.model_dump(),
+        cost=calculate_desearch_cost(DesearchEndpoint.WEB_CRAWL),
+    )
 
 
 app.include_router(gateway_router)
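
Every gateway response now carries a cost field alongside the provider payload. From an agent's perspective, a sketch (URL and payload shape follow the deleted example.py's proxy convention; values invented):

    import httpx

    resp = httpx.post(
        "http://sandbox_proxy/api/gateway/desearch/web/crawl",
        json={"url": "https://example.com", "run_id": "run-123"},
        timeout=30.0,
    )
    resp.raise_for_status()
    body = resp.json()
    print(body["cost"])  # e.g. 0.0005 (WEB_CRAWL is priced at $0.05 per 100 calls)
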
diff --git a/neurons/validator/models/chutes.py b/neurons/validator/models/chutes.py
index 4d3e547..2e43c38 100644
--- a/neurons/validator/models/chutes.py
+++ b/neurons/validator/models/chutes.py
@@ -117,3 +117,148 @@ class ChuteStatus(BaseModel):
     avg_busy_ratio: float
     total_invocations: float
     total_rate_limit_errors: float
+
+
+class Chute(BaseModel):
+    name: str
+    input_cost: float
+    output_cost: float
+
+    def calculate_cost(self, completion: ChutesCompletion) -> float:
+        return (self.input_cost / 1_000_000) * completion.usage.prompt_tokens + (
+            self.output_cost / 1_000_000
+        ) * completion.usage.completion_tokens
+
+
+CHUTES_REGISTRY: dict[ChuteModel, Chute] = {
+    ChuteModel.DEEPSEEK_R1_SGTEST: Chute(
+        name=ChuteModel.DEEPSEEK_R1_SGTEST,
+        input_cost=0.3,
+        output_cost=1.2,
+    ),
+    ChuteModel.DEEPSEEK_R1_0528: Chute(
+        name=ChuteModel.DEEPSEEK_R1_0528,
+        input_cost=0.4,
+        output_cost=1.75,
+    ),
+    ChuteModel.DEEPSEEK_R1: Chute(
+        name=ChuteModel.DEEPSEEK_R1,
+        input_cost=0.3,
+        output_cost=1.2,
+    ),
+    ChuteModel.DEEPSEEK_V3_0324: Chute(
+        name=ChuteModel.DEEPSEEK_V3_0324,
+        input_cost=0.24,
+        output_cost=0.84,
+    ),
+    ChuteModel.DEEPSEEK_V3_1_TERMINUS: Chute(
+        name=ChuteModel.DEEPSEEK_V3_1_TERMINUS,
+        input_cost=0.23,
+        output_cost=0.9,
+    ),
+    ChuteModel.DEEPSEEK_V3_1: Chute(
+        name=ChuteModel.DEEPSEEK_V3_1,
+        input_cost=0.20,
+        output_cost=0.8,
+    ),
+    ChuteModel.DEEPSEEK_TNG_R1T2_CHIMERA: Chute(
+        name=ChuteModel.DEEPSEEK_TNG_R1T2_CHIMERA,
+        input_cost=0.3,
+        output_cost=1.2,
+    ),
+    ChuteModel.DEEPSEEK_V3_2_EXP: Chute(
+        name=ChuteModel.DEEPSEEK_V3_2_EXP,
+        input_cost=0.25,
+        output_cost=0.35,
+    ),
+    ChuteModel.GLM_4_6: Chute(
+        name=ChuteModel.GLM_4_6,
+        input_cost=0.4,
+        output_cost=1.75,
+    ),
+    ChuteModel.GLM_4_5: Chute(
+        name=ChuteModel.GLM_4_5,
+        input_cost=0.35,
+        output_cost=1.55,
+    ),
+    ChuteModel.GLM_4_5_AIR: Chute(
+        name=ChuteModel.GLM_4_5_AIR,
+        input_cost=0,
+        output_cost=0,
+    ),
+    ChuteModel.GEMMA_3_4B_IT: Chute(
+        name=ChuteModel.GEMMA_3_4B_IT,
+        input_cost=0,
+        output_cost=0,
+    ),
+    ChuteModel.GEMMA_3_27B_IT: Chute(
+        name=ChuteModel.GEMMA_3_27B_IT,
+        input_cost=0.13,
+        output_cost=0.52,
+    ),
+    ChuteModel.GEMMA_3_12B_IT: Chute(
+        name=ChuteModel.GEMMA_3_12B_IT,
+        input_cost=0.03,
+        output_cost=0.1,
+    ),
+    ChuteModel.QWEN3_32B: Chute(
+        name=ChuteModel.QWEN3_32B,
+        input_cost=0.05,
+        output_cost=0.2,
+    ),
+    ChuteModel.QWEN3_235B_A22B: Chute(
+        name=ChuteModel.QWEN3_235B_A22B,
+        input_cost=0.3,
+        output_cost=1.2,
+    ),
+    ChuteModel.QWEN2_5_VL_32B_INSTRUCT: Chute(
+        name=ChuteModel.QWEN2_5_VL_32B_INSTRUCT,
+        input_cost=0.05,
+        output_cost=0.22,
+    ),
+    ChuteModel.QWEN3_235B_A22B_INSTRUCT_2507: Chute(
+        name=ChuteModel.QWEN3_235B_A22B_INSTRUCT_2507,
+        input_cost=0.08,
+        output_cost=0.55,
+    ),
+    ChuteModel.QWEN3_VL_235B_A22B_THINKING: Chute(
+        name=ChuteModel.QWEN3_VL_235B_A22B_THINKING,
+        input_cost=0.3,
+        output_cost=1.2,
+    ),
+    ChuteModel.MISTRAL_SMALL_24B_INSTRUCT_2501: Chute(
+        name=ChuteModel.MISTRAL_SMALL_24B_INSTRUCT_2501,
+        input_cost=0.05,
+        output_cost=0.22,
+    ),
+    ChuteModel.GPT_OSS_20B: Chute(
+        name=ChuteModel.GPT_OSS_20B,
+        input_cost=0,
+        output_cost=0,
+    ),
+    ChuteModel.GPT_OSS_120B: Chute(
+        name=ChuteModel.GPT_OSS_120B,
+        input_cost=0.04,
+        output_cost=0.4,
+    ),
+}
+
+
+def get_chute(model: typing.Union[ChuteModel, str]) -> Chute:
+    if isinstance(model, str):
+        try:
+            model = ChuteModel(model)
+        except ValueError:
+            available = ", ".join(m.value for m in ChuteModel)
+            raise ValueError(f"Model '{model}' is not available. Available models: {available}")
+
+    return CHUTES_REGISTRY[model]
+
+
+def list_available_models() -> list[str]:
+    return [model.value for model in ChuteModel]
+
+
+def calculate_cost(model: typing.Union[ChuteModel, str], completion: ChutesCompletion) -> float:
+    chute = get_chute(model)
+    return chute.calculate_cost(completion)
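
The 1_000_000 divisor in Chute.calculate_cost implies the registry prices are per one million tokens. A worked example (token counts invented):

    chute = get_chute(ChuteModel.DEEPSEEK_V3_1)  # input_cost=0.20, output_cost=0.8
    prompt_tokens, completion_tokens = 1_200, 400
    cost = (chute.input_cost / 1_000_000) * prompt_tokens + (
        chute.output_cost / 1_000_000
    ) * completion_tokens
    # (0.20 / 1e6) * 1200 + (0.8 / 1e6) * 400 = 0.00024 + 0.00032 = 0.00056
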
diff --git a/neurons/validator/models/desearch.py b/neurons/validator/models/desearch.py
index f2180f1..7631fc7 100644
--- a/neurons/validator/models/desearch.py
+++ b/neurons/validator/models/desearch.py
@@ -89,3 +89,50 @@ class WebCrawlResponse(BaseModel):
     content: str
 
     model_config = ConfigDict(extra="allow")
+
+
+class DesearchEndpoint(StrEnum):
+    AI_SEARCH = "ai_search"
+    AI_WEB_SEARCH = "ai_web_search"
+    WEB_SEARCH = "web_search"
+    WEB_CRAWL = "web_crawl"
+
+
+# Cost per 100 searches
+DESEARCH_PRICING: typing.Dict[DesearchEndpoint, typing.Any] = {
+    DesearchEndpoint.AI_SEARCH: {
+        ModelEnum.NOVA: 0.6,
+        ModelEnum.ORBIT: 2.2,
+        ModelEnum.HORIZON: 2.6,
+    },
+    DesearchEndpoint.AI_WEB_SEARCH: {
+        ModelEnum.NOVA: 0.6,
+        ModelEnum.ORBIT: 1.7,
+        ModelEnum.HORIZON: 2.1,
+    },
+    DesearchEndpoint.WEB_SEARCH: 0.25,
+    DesearchEndpoint.WEB_CRAWL: 0.05,
+}
+
+
+def calculate_cost(
+    endpoint: DesearchEndpoint,
+    model: typing.Optional[ModelEnum] = None,
+) -> float:
+    pricing = DESEARCH_PRICING.get(endpoint)
+    if pricing is None:
+        raise ValueError(f"No pricing found for endpoint: {endpoint}")
+
+    if isinstance(pricing, dict):
+        if model is None:
+            raise ValueError(f"Model is required for {endpoint}")
+        cost_per_100 = pricing.get(model)
+        if cost_per_100 is None:
+            available = ", ".join(m.value for m in pricing.keys())
+            raise ValueError(
+                f"Model '{model}' not available for {endpoint}. Available models: {available}"
+            )
+    else:
+        cost_per_100 = pricing
+
+    return cost_per_100 / 100
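
Because the table stores cost per 100 calls, calculate_cost divides by 100 to price a single request. For example (values taken straight from the table above):

    calculate_cost(DesearchEndpoint.AI_SEARCH, ModelEnum.NOVA)  # 0.6 / 100 = 0.006
    calculate_cost(DesearchEndpoint.WEB_CRAWL)                  # 0.05 / 100 = 0.0005
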
diff --git a/neurons/validator/models/numinous_client.py b/neurons/validator/models/numinous_client.py
index 30544ce..8e7d974 100644
--- a/neurons/validator/models/numinous_client.py
+++ b/neurons/validator/models/numinous_client.py
@@ -4,12 +4,16 @@
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from neurons.validator.models.chutes import ChuteModel, Message
+from neurons.validator.models.chutes import ChuteModel, ChutesCompletion, Message
 from neurons.validator.models.desearch import (
+    AISearchResponse,
     DateFilterEnum,
     ModelEnum,
     ResultTypeEnum,
     ToolEnum,
+    WebCrawlResponse,
+    WebLinksResponse,
+    WebSearchResponse,
     WebToolEnum,
 )
 
@@ -195,3 +199,27 @@ class DesearchWebSearchRequest(GatewayCall):
 
 class DesearchWebCrawlRequest(GatewayCall):
     url: str = Field(..., description="The URL to crawl")
+
+
+class GatewayCallResponse(BaseModel):
+    cost: float
+
+
+class GatewayChutesCompletion(ChutesCompletion, GatewayCallResponse):
+    pass
+
+
+class GatewayDesearchAISearchResponse(AISearchResponse, GatewayCallResponse):
+    pass
+
+
+class GatewayDesearchWebLinksResponse(WebLinksResponse, GatewayCallResponse):
+    pass
+
+
+class GatewayDesearchWebSearchResponse(WebSearchResponse, GatewayCallResponse):
+    pass
+
+
+class GatewayDesearchWebCrawlResponse(WebCrawlResponse, GatewayCallResponse):
+    pass
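
These Gateway* models use pydantic multiple inheritance to graft the cost field onto each provider payload without touching the provider models themselves. A minimal sketch of the pattern (toy base model, not from the codebase):

    from pydantic import BaseModel

    class ProviderResponse(BaseModel):  # stand-in for e.g. WebCrawlResponse
        content: str

    class GatewayCallResponse(BaseModel):
        cost: float

    class GatewayProviderResponse(ProviderResponse, GatewayCallResponse):
        pass

    # Mirrors how app.py builds responses: provider fields plus the computed cost
    GatewayProviderResponse(**ProviderResponse(content="...").model_dump(), cost=0.0005)
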