From 7e5be4ea842fcb15be9126d3601bae5d029e73b4 Mon Sep 17 00:00:00 2001 From: charles Date: Tue, 8 Jul 2025 22:01:04 -0700 Subject: [PATCH] Fix the model mapping when adding provider keys --- app/api/schemas/provider_key.py | 17 ++++++++++++++++- forge-cli.py | 11 +++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/app/api/schemas/provider_key.py b/app/api/schemas/provider_key.py index 94f284f..aaaac2f 100644 --- a/app/api/schemas/provider_key.py +++ b/app/api/schemas/provider_key.py @@ -1,6 +1,7 @@ +import json from datetime import datetime -from pydantic import BaseModel, ConfigDict, computed_field, Field +from pydantic import BaseModel, ConfigDict, computed_field, Field, field_validator from app.core.logger import get_logger from app.core.security import decrypt_api_key @@ -78,6 +79,20 @@ class ProviderKeyInDBBase(BaseModel): encrypted_api_key: str model_config = ConfigDict(from_attributes=True) + @field_validator("model_mapping", mode="before") + @classmethod + def parse_model_mapping(cls, v): + """Parse JSON string to dictionary for model_mapping field.""" + if v is None: + return None + if isinstance(v, str): + try: + return json.loads(v) + except json.JSONDecodeError: + logger.warning(f"Failed to parse model_mapping JSON: {v}") + return {} + return v + class ProviderKey(ProviderKeyInDBBase): @computed_field diff --git a/forge-cli.py b/forge-cli.py index 41fc5c5..1947449 100755 --- a/forge-cli.py +++ b/forge-cli.py @@ -152,8 +152,8 @@ def add_provider_key( provider_name: str, api_key: str, base_url: str | None = None, - model_mapping: dict[str, str] | None = None, - config: dict[str, str] | None = None, + model_mapping: str | None = None, + config: str | None = None, ) -> bool: """Add a provider key""" if not self.token: @@ -170,7 +170,7 @@ def add_provider_key( "provider_name": provider_name, "api_key": api_key, "base_url": base_url, - "model_mapping": model_mapping, + "model_mapping": json.loads(model_mapping) if model_mapping else None, "config": json.loads(config) if config else None, } @@ -209,6 +209,8 @@ def list_provider_keys(self) -> list[dict[str, Any]]: print(f" Base URL: {key['base_url']}") if key.get("config"): print(f" Config: {key['config']}") + if key.get("model_mapping"): + print(f" Model Mapping: {key['model_mapping']}") return keys else: print(f"❌ Error listing provider keys: {response.status_code}") @@ -612,7 +614,8 @@ def main(): key = getpass("Enter provider API key: ") base_url = input("Enter provider base URL (optional, press Enter to skip): ") config = input("Enter provider config in json string format (optional, press Enter to skip): ") - forge.add_provider_key(provider, key, base_url=base_url, config=config) + model_mapping = input("Enter model ampping config in json string format (optional, press Enter to skip): ") + forge.add_provider_key(provider, key, base_url=base_url, config=config, model_mapping=model_mapping) elif choice == "9": if not forge.token: