From 11ad86d2deabd3c58916c0816da592469c697724 Mon Sep 17 00:00:00 2001 From: Patricia Bedard Date: Thu, 17 Oct 2024 12:44:43 -0400 Subject: [PATCH 1/8] add LLM model & manager, get-all, get and completion API endpoints --- apis/paios/openapi.yaml | 123 +++++++++++++ backend/api/LlmsView.py | 43 +++++ backend/api/__init__.py | 1 + backend/managers/LlmsManager.py | 167 ++++++++++++++++++ backend/models.py | 8 + backend/requirements.txt | 1 + backend/schemas.py | 8 + .../versions/73d50424c826_added_llm_table.py | 35 ++++ 8 files changed, 386 insertions(+) create mode 100644 backend/api/LlmsView.py create mode 100644 backend/managers/LlmsManager.py create mode 100644 migrations/versions/73d50424c826_added_llm_table.py diff --git a/apis/paios/openapi.yaml b/apis/paios/openapi.yaml index 68a4dd9e..c1259dd6 100644 --- a/apis/paios/openapi.yaml +++ b/apis/paios/openapi.yaml @@ -700,6 +700,8 @@ paths: - $ref: '#/components/parameters/id' /shares: get: + security: + - jwt: [] tags: - Share Management summary: Retrieve all share links @@ -721,6 +723,8 @@ paths: X-Total-Count: $ref: '#/components/headers/X-Total-Count' post: + security: + - jwt: [] summary: Create new share link tags: - Share Management @@ -737,6 +741,8 @@ paths: $ref: '#/components/schemas/ShareCreate' '/shares/{id}': get: + security: + - jwt: [] tags: - Share Management summary: Retrieve share link by id @@ -751,6 +757,8 @@ paths: schema: $ref: '#/components/schemas/Share' put: + security: + - jwt: [] tags: - Share Management summary: Update share link by id @@ -766,6 +774,8 @@ paths: '200': description: OK delete: + security: + - jwt: [] tags: - Share Management summary: Delete share link by id @@ -777,6 +787,74 @@ paths: description: No Content '404': description: Not Found + /llms: + get: + tags: + - LLM Management + summary: Retrieve all available LLMs + description: Get all installed / available LLMs. + parameters: + - $ref: '#/components/parameters/sort' + - $ref: '#/components/parameters/range' + - $ref: '#/components/parameters/filter' + responses: + '200': + description: OK + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Llm' + headers: + X-Total-Count: + $ref: '#/components/headers/X-Total-Count' + '/llms/{id}': + get: + tags: + - LLM Management + summary: Retrieve LLM by id + description: Retrieve the LLM with the specified id. + parameters: + - $ref: '#/components/parameters/kebab-dot_id' + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/Llm' + '404': + description: LLM not found + '/llms/{id}/completion': + post: + tags: + - LLM Management + summary: Invoke Completion on LLM + description: Invoke Completion on the LLM with the specified id. + operationId: backend.api.LlmsView.completion + parameters: + - $ref: '#/components/parameters/kebab-dot_id' + requestBody: + description: Messages to input to Completion function. 
+ content: + application/json: + schema: + type: object + properties: + messages: + $ref: '#/components/schemas/messagesList' + responses: + '200': + description: Completion succeeded + content: + application/json: + schema: + type: object + '400': + description: Completion failed + '404': + description: LLM not found /auth/webauthn/register-options: post: summary: Generate WebAuthn registration options @@ -936,6 +1014,13 @@ components: required: true schema: $ref: '#/components/schemas/kebab-snake_id' + kebab-dot_id: + name: id + in: path + description: id of the object + required: true + schema: + $ref: '#/components/schemas/kebab-dot_id' key: name: key in: path @@ -1003,6 +1088,12 @@ components: maxLength: 100 example: langchain-core pattern: ^[a-z0-9]+([_-][a-z0-9]+)*$ + kebab-dot_id: + type: string + minLength: 2 + maxLength: 100 + example: ollama-llama3.2 + pattern: ^[a-z0-9]+([.-][a-z0-9]+)*$ semVer: type: string example: '1.1.0' @@ -1068,6 +1159,16 @@ components: readOnly: true example: abcd-efgh-ijkl pattern: ^[a-z]{4}-[a-z]{4}-[a-z]{4}$ + messagesList: + type: array + example: [{"role": "user", "content": "What is Kwaai.ai?"}] + items: + type: object + properties: + role: + type: string + content: + type: string download: type: object properties: @@ -1390,6 +1491,28 @@ components: example: false required: - resource_id + Llm: + type: object + title: Llm + properties: + id: + type: string + name: + type: string + llm_name: + type: string + provider: + type: string + api_base: + type: string + nullable: true + is_active: + type: boolean + required: + - id + - name + - provider + - is_active RegistrationOptions: type: object properties: diff --git a/backend/api/LlmsView.py b/backend/api/LlmsView.py new file mode 100644 index 00000000..e0cdca33 --- /dev/null +++ b/backend/api/LlmsView.py @@ -0,0 +1,43 @@ +from starlette.responses import JSONResponse +from backend.managers.LlmsManager import LlmsManager +from backend.pagination import parse_pagination_params + +class LlmsView: + def __init__(self): + self.llmm = LlmsManager() + + async def get(self, id: str): + llm = await self.llmm.get_llm(id) + if llm is None: + return JSONResponse(headers={"error": "LLM not found"}, status_code=404) + return JSONResponse(llm.model_dump(), status_code=200) + + async def search(self, filter: str = None, range: str = None, sort: str = None): + result = parse_pagination_params(filter, range, sort) + if isinstance(result, JSONResponse): + return result + + offset, limit, sort_by, sort_order, filters = result + + llms, total_count = await self.llmm.retrieve_llms(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters) + headers = { + 'X-Total-Count': str(total_count), + 'Content-Range': f'shares {offset}-{offset + len(llms) - 1}/{total_count}' + } + return JSONResponse([llm.model_dump() for llm in llms], status_code=200, headers=headers) + + async def completion(self, id: str, body: dict): + print("completion. 
body: {}".format(body)) + messages = [] + if 'messages' in body and body['messages']: + messages = body['messages'] + llm = await self.llmm.get_llm(id) + if llm: + print("Selected LLM is {}".format(llm.llm_name)) + response = self.llmm.completion(llm, messages) + if response: + return JSONResponse(response.model_dump(), status_code=200) + else: + return JSONResponse(status_code=400, content={"message": "Completion failed"}) + else: + return JSONResponse(status_code=404, content={"message": "LLM not found"}) diff --git a/backend/api/__init__.py b/backend/api/__init__.py index 697d4c4d..78cf1bfe 100644 --- a/backend/api/__init__.py +++ b/backend/api/__init__.py @@ -9,3 +9,4 @@ from .PersonasView import PersonasView from .SharesView import SharesView from .AuthView import AuthView +from .LlmsView import LlmsView diff --git a/backend/managers/LlmsManager.py b/backend/managers/LlmsManager.py new file mode 100644 index 00000000..0c7872bb --- /dev/null +++ b/backend/managers/LlmsManager.py @@ -0,0 +1,167 @@ +import asyncio +import httpx +from threading import Lock +from sqlalchemy import select, insert, update, delete, func +from backend.models import Llm +from backend.db import db_session_context +from backend.schemas import LlmSchema +from backend.utils import get_env_key +from typing import List, Tuple, Optional, Dict, Any, Union +from litellm import Router +from litellm.utils import CustomStreamWrapper, ModelResponse + +import logging +logger = logging.getLogger(__name__) + +class LlmsManager: + _instance = None + _lock = Lock() + + def __new__(cls, *args, **kwargs): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(LlmsManager, cls).__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + if not hasattr(self, '_initialized'): # Ensure initialization happens only once + with self._lock: + if not hasattr(self, '_initialized'): + self.router = None + router_init_task = asyncio.create_task(self._init_router()) + asyncio.gather(router_init_task, return_exceptions=True) + self._initialized = True + + async def _init_router(self): + try: + # load models + ollama_task = asyncio.create_task(self._load_ollama_models()) + await asyncio.gather(ollama_task, return_exceptions=True) + # collect the available LLMs + llms, total_llms = await self.retrieve_llms() + # configure router + model_list = [] + for llm in llms: + params = {} + params["model"] = llm.llm_name + if llm.provider == "ollama": + params["api_base"] = llm.api_base + model = { + "model_name": llm.llm_name, + "litellm_params": params, + } + model_list.append(model) + print(model_list) + self.router = Router(model_list=model_list) + except Exception as e: + logger.exception(e) + + async def _load_ollama_models(self): + try: + ollama_urlroot = get_env_key("OLLAMA_URLROOT") # eg: http://localhost:11434 + except ValueError: + return # no Ollama server specified + # retrieve list of installed models + async with httpx.AsyncClient() as client: + response = await client.get("{}/api/tags".format(ollama_urlroot)) + if response.status_code == 200: + data = response.json() + available_models = [model_data['model'] for model_data in data.get("models", [])] + print(available_models) + else: + pass # FIX + # create / update Ollama family Llm objects + provider = "ollama" + async with db_session_context() as session: + # mark existing models as inactive + stmt = update(Llm).where(Llm.provider == provider).values(is_active=False) + result = await session.execute(stmt) + if result.rowcount > 0: + 
await session.commit() + # insert / update models + for model in available_models: + name = model.removesuffix(":latest") + llm_name = "{}/{}".format(provider,name) # what LiteLLM expects + safe_name = llm_name.replace("/", "-").replace(":", "-") + result = await session.execute(select(Llm).filter(Llm.id == safe_name)) + llm = result.scalar_one_or_none() + if llm: + stmt = update(Llm).where(Llm.id == safe_name).values(name=name, + llm_name=llm_name, + provider=provider, + api_base=ollama_urlroot, + is_active=True) + result = await session.execute(stmt) + if result.rowcount > 0: + await session.commit() + else: + new_llm = Llm(id=safe_name, name=name, llm_name=llm_name, + provider=provider, api_base=ollama_urlroot, + is_active=True) + session.add(new_llm) + await session.commit() + + async def get_llm(self, id: str) -> Optional[LlmSchema]: + async with db_session_context() as session: + result = await session.execute(select(Llm).filter(Llm.id == id)) + llm = result.scalar_one_or_none() + if llm: + return LlmSchema(id=llm.id, name=llm.name, llm_name=llm.llm_name, + provider=llm.provider, api_base=llm.api_base, is_active=llm.is_active) + return None + + async def retrieve_llms(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None, + sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[LlmSchema], int]: + async with db_session_context() as session: + query = select(Llm).filter(Llm.is_active == True) + + if filters: + for key, value in filters.items(): + if isinstance(value, list): + query = query.filter(getattr(Llm, key).in_(value)) + else: + query = query.filter(getattr(Llm, key) == value) + + if sort_by and sort_by in ['id', 'name', 'llm_name', 'provider', 'api_base', 'is_active']: + order_column = getattr(Llm, sort_by) + query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) + + query = query.offset(offset).limit(limit) + + result = await session.execute(query) + llms = [LlmSchema(id=llm.id, name=llm.name, llm_name=llm.llm_name, + provider=llm.provider, api_base=llm.api_base, + is_active=llm.is_active) + for llm in result.scalars().all()] + + # Get total count + count_query = select(func.count()).select_from(Llm).filter(Llm.is_active == True) + if filters: + for key, value in filters.items(): + if isinstance(value, list): + count_query = count_query.filter(getattr(Llm, key).in_(value)) + else: + count_query = count_query.filter(getattr(Llm, key) == value) + + total_count = await session.execute(count_query) + total_count = total_count.scalar() + + return llms, total_count + + def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, CustomStreamWrapper]: + response = self.router.completion(model=llm.llm_name, + messages=messages, + kwargs=optional_params) + print("completion response: {}".format(response)) + #message = response.choices[0].message.content + #return message + return response + + async def acompletion(self, llm, messages, **optional_params) -> Union[CustomStreamWrapper, ModelResponse]: + response = await self.router.acompletion(model=llm.llm_name, + messages=messages, + kwargs=optional_params) + print("acompletion response: {}".format(response)) + return response + \ No newline at end of file diff --git a/backend/models.py b/backend/models.py index 251cbe4f..28a1f17b 100644 --- a/backend/models.py +++ b/backend/models.py @@ -65,6 +65,14 @@ class Share(SQLModelBase, table=True): expiration_dt: datetime | None = Field(default=None) # the link expiration date/time (optional) 
is_revoked: bool = Field() +class Llm(SQLModelBase, table=True): + id: str = Field(primary_key=True) # the model's unique, URL-friendly name + name: str = Field() + llm_name: str = Field() # the model name known to LiteLLM + provider: str = Field() # model provider, eg "ollama" + api_base: str | None = Field(default=None) + is_active: bool = Field() # is the model installed / available? + # Resolve forward references User.model_rebuild() Cred.model_rebuild() diff --git a/backend/requirements.txt b/backend/requirements.txt index 0e5dd6b4..6cc65691 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -20,3 +20,4 @@ structlog webauthn greenlet pyjwt +litellm diff --git a/backend/schemas.py b/backend/schemas.py index fce66742..a42ba8c3 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -88,6 +88,14 @@ class ShareCreateSchema(ShareBaseSchema): class ShareSchema(ShareBaseSchema): id: str +class LlmSchema(BaseModel): + id: str + name: str + llm_name: str + provider: str + api_base: Optional[str] = None + is_active: bool + class RegistrationOptions(BaseModel): email: str diff --git a/migrations/versions/73d50424c826_added_llm_table.py b/migrations/versions/73d50424c826_added_llm_table.py new file mode 100644 index 00000000..32463cb8 --- /dev/null +++ b/migrations/versions/73d50424c826_added_llm_table.py @@ -0,0 +1,35 @@ +"""added llm table + +Revision ID: 73d50424c826 +Revises: 187855982332 +Create Date: 2024-10-10 17:31:43.949960 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = '73d50424c826' +down_revision: Union[str, None] = '187855982332' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table('llm', + sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('llm_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('api_base', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + + +def downgrade() -> None: + op.drop_table('llm') From 67f76dd2ab879c52ed91f2bfe254783c23b47cb4 Mon Sep 17 00:00:00 2001 From: Patricia Bedard Date: Tue, 29 Oct 2024 13:36:36 -0400 Subject: [PATCH 2/8] include OpenAI models in LLM router initialization --- apis/paios/openapi.yaml | 3 +- backend/api/LlmsView.py | 11 +++- backend/managers/LlmsManager.py | 98 ++++++++++++++++++++++++++------- backend/schemas.py | 2 +- 4 files changed, 89 insertions(+), 25 deletions(-) diff --git a/apis/paios/openapi.yaml b/apis/paios/openapi.yaml index c1259dd6..96e970fc 100644 --- a/apis/paios/openapi.yaml +++ b/apis/paios/openapi.yaml @@ -1499,7 +1499,7 @@ components: type: string name: type: string - llm_name: + full_name: type: string provider: type: string @@ -1511,6 +1511,7 @@ components: required: - id - name + - full_name - provider - is_active RegistrationOptions: diff --git a/backend/api/LlmsView.py b/backend/api/LlmsView.py index e0cdca33..65eb6cee 100644 --- a/backend/api/LlmsView.py +++ b/backend/api/LlmsView.py @@ -1,6 +1,7 @@ from starlette.responses import JSONResponse from backend.managers.LlmsManager import LlmsManager from backend.pagination import parse_pagination_params +from backend.schemas 
import LlmSchema class LlmsView: def __init__(self): @@ -10,7 +11,9 @@ async def get(self, id: str): llm = await self.llmm.get_llm(id) if llm is None: return JSONResponse(headers={"error": "LLM not found"}, status_code=404) - return JSONResponse(llm.model_dump(), status_code=200) + llm_schema = LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}", + provider=llm.provider, api_base=llm.api_base, is_active=llm.is_active) + return JSONResponse(llm_schema.model_dump(), status_code=200) async def search(self, filter: str = None, range: str = None, sort: str = None): result = parse_pagination_params(filter, range, sort) @@ -20,11 +23,15 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): offset, limit, sort_by, sort_order, filters = result llms, total_count = await self.llmm.retrieve_llms(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters) + results = [LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}", + provider=llm.provider, api_base=llm.api_base, + is_active=llm.is_active) + for llm in llms] headers = { 'X-Total-Count': str(total_count), 'Content-Range': f'shares {offset}-{offset + len(llms) - 1}/{total_count}' } - return JSONResponse([llm.model_dump() for llm in llms], status_code=200, headers=headers) + return JSONResponse([llm.model_dump() for llm in results], status_code=200, headers=headers) async def completion(self, id: str, body: dict): print("completion. body: {}".format(body)) diff --git a/backend/managers/LlmsManager.py b/backend/managers/LlmsManager.py index 0c7872bb..d813b040 100644 --- a/backend/managers/LlmsManager.py +++ b/backend/managers/LlmsManager.py @@ -4,7 +4,6 @@ from sqlalchemy import select, insert, update, delete, func from backend.models import Llm from backend.db import db_session_context -from backend.schemas import LlmSchema from backend.utils import get_env_key from typing import List, Tuple, Optional, Dict, Any, Union from litellm import Router @@ -37,22 +36,29 @@ async def _init_router(self): try: # load models ollama_task = asyncio.create_task(self._load_ollama_models()) - await asyncio.gather(ollama_task, return_exceptions=True) + openai_task = asyncio.create_task(self._load_openai_models()) + await asyncio.gather(ollama_task, + openai_task, + return_exceptions=True) # collect the available LLMs llms, total_llms = await self.retrieve_llms() # configure router model_list = [] for llm in llms: + model_name = f"{llm.provider}/{llm.name}" params = {} params["model"] = llm.llm_name if llm.provider == "ollama": params["api_base"] = llm.api_base + if llm.provider == "openai": + params["api_key"] = get_env_key("OPENAI_API_KEY") model = { - "model_name": llm.llm_name, + "model_name": model_name, "litellm_params": params, } model_list.append(model) - print(model_list) + #import pprint + #pprint.pprint(model_list) self.router = Router(model_list=model_list) except Exception as e: logger.exception(e) @@ -84,8 +90,7 @@ async def _load_ollama_models(self): name = model.removesuffix(":latest") llm_name = "{}/{}".format(provider,name) # what LiteLLM expects safe_name = llm_name.replace("/", "-").replace(":", "-") - result = await session.execute(select(Llm).filter(Llm.id == safe_name)) - llm = result.scalar_one_or_none() + llm = self.get_llm(safe_name) if llm: stmt = update(Llm).where(Llm.id == safe_name).values(name=name, llm_name=llm_name, @@ -102,17 +107,75 @@ async def _load_ollama_models(self): session.add(new_llm) await session.commit() - async def 
get_llm(self, id: str) -> Optional[LlmSchema]: + async def _load_openai_models(self): + try: + openai_key = get_env_key("OPENAI_API_KEY") + except ValueError: + print("No OpenAI API key specified. Skipping.") + return # no OpenAI API key specified + # retrieve list of installed models + async with httpx.AsyncClient() as client: + openai_urlroot = get_env_key("OPENAI_URLROOT", "https://api.openai.com") + headers = { + "Authorization": f"Bearer {openai_key}" + } + response = await client.get(f"{openai_urlroot}/v1/models", headers=headers) + print(vars(response)) + if response.status_code == 200: + data = response.json() + import json + pretty_json = json.dumps(data, indent=4) + print(pretty_json) + available_models = [model_data['id'] for model_data in data.get("data", [])] + print(available_models) + else: + print(f"Error: {response.status_code} - {response.text}") # FIX + # create / update OpenAI family Llm objects + provider = "openai" + async with db_session_context() as session: + # mark existing models as inactive + stmt = update(Llm).where(Llm.provider == provider).values(is_active=False) + result = await session.execute(stmt) + if result.rowcount > 0: + await session.commit() + # insert / update models + for model in available_models: + llm_provider = None + if any(substring in model for substring in {"gpt","o1","chatgpt"}): + llm_provider = "openai" + if any(substring in model for substring in {"ada","babbage","curie","davinci","instruct"}): + llm_provider = "text-completion-openai" + if llm_provider: + name = model + llm_name = "{}/{}".format(llm_provider,name) # what LiteLLM expects + safe_name = f"{provider}/{name}".replace("/", "-").replace(":", "-") + llm = self.get_llm(safe_name) + if llm: + stmt = update(Llm).where(Llm.id == safe_name).values(name=name, + llm_name=llm_name, + provider=provider, + api_base=openai_urlroot, + is_active=True) + result = await session.execute(stmt) + if result.rowcount > 0: + await session.commit() + else: + new_llm = Llm(id=safe_name, name=name, llm_name=llm_name, + provider=provider, api_base=openai_urlroot, + is_active=True) + session.add(new_llm) + await session.commit() + + async def get_llm(self, id: str) -> Optional[Llm]: async with db_session_context() as session: result = await session.execute(select(Llm).filter(Llm.id == id)) llm = result.scalar_one_or_none() if llm: - return LlmSchema(id=llm.id, name=llm.name, llm_name=llm.llm_name, - provider=llm.provider, api_base=llm.api_base, is_active=llm.is_active) + return llm return None async def retrieve_llms(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None, - sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[LlmSchema], int]: + sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[Llm], int]: async with db_session_context() as session: query = select(Llm).filter(Llm.is_active == True) @@ -123,17 +186,14 @@ async def retrieve_llms(self, offset: int = 0, limit: int = 100, sort_by: Option else: query = query.filter(getattr(Llm, key) == value) - if sort_by and sort_by in ['id', 'name', 'llm_name', 'provider', 'api_base', 'is_active']: + if sort_by and sort_by in ['id', 'name', 'provider', 'api_base', 'is_active']: order_column = getattr(Llm, sort_by) query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) query = query.offset(offset).limit(limit) result = await session.execute(query) - llms = [LlmSchema(id=llm.id, name=llm.name, llm_name=llm.llm_name, - provider=llm.provider, 
api_base=llm.api_base, - is_active=llm.is_active) - for llm in result.scalars().all()] + llms = result.scalars().all() # Get total count count_query = select(func.count()).select_from(Llm).filter(Llm.is_active == True) @@ -151,17 +211,13 @@ async def retrieve_llms(self, offset: int = 0, limit: int = 100, sort_by: Option def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, CustomStreamWrapper]: response = self.router.completion(model=llm.llm_name, - messages=messages, - kwargs=optional_params) + messages=messages) print("completion response: {}".format(response)) - #message = response.choices[0].message.content - #return message return response async def acompletion(self, llm, messages, **optional_params) -> Union[CustomStreamWrapper, ModelResponse]: response = await self.router.acompletion(model=llm.llm_name, - messages=messages, - kwargs=optional_params) + messages=messages) print("acompletion response: {}".format(response)) return response \ No newline at end of file diff --git a/backend/schemas.py b/backend/schemas.py index a42ba8c3..b68772e5 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -91,7 +91,7 @@ class ShareSchema(ShareBaseSchema): class LlmSchema(BaseModel): id: str name: str - llm_name: str + full_name: str provider: str api_base: Optional[str] = None is_active: bool From 51ad1d15053df50ff2512c6120c7d371ee2e872e Mon Sep 17 00:00:00 2001 From: Patricia Bedard Date: Thu, 31 Oct 2024 13:30:54 -0400 Subject: [PATCH 3/8] add support for optional completion parameters --- apis/paios/openapi.yaml | 9 +++++++++ backend/api/LlmsView.py | 21 +++++++++++++-------- backend/managers/LlmsManager.py | 33 +++++++++++++++++++-------------- 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/apis/paios/openapi.yaml b/apis/paios/openapi.yaml index 96e970fc..6a3b150a 100644 --- a/apis/paios/openapi.yaml +++ b/apis/paios/openapi.yaml @@ -844,6 +844,10 @@ paths: properties: messages: $ref: '#/components/schemas/messagesList' + optional_params: + $ref: '#/components/schemas/completionParamList' + required: + - messages responses: '200': description: Completion succeeded @@ -979,6 +983,8 @@ tags: description: Management of personas - name: Share Management description: Management of share links + - name: LLM Management + description: Discovery and invocation of LLM functionality components: securitySchemes: jwt: @@ -1169,6 +1175,9 @@ components: type: string content: type: string + completionParamList: + type: object + example: {"max_tokens": 50, "temperature": 0.2} download: type: object properties: diff --git a/backend/api/LlmsView.py b/backend/api/LlmsView.py index 65eb6cee..307abb5c 100644 --- a/backend/api/LlmsView.py +++ b/backend/api/LlmsView.py @@ -2,6 +2,7 @@ from backend.managers.LlmsManager import LlmsManager from backend.pagination import parse_pagination_params from backend.schemas import LlmSchema +from litellm.exceptions import BadRequestError class LlmsView: def __init__(self): @@ -35,16 +36,20 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): async def completion(self, id: str, body: dict): print("completion. 
body: {}".format(body)) - messages = [] - if 'messages' in body and body['messages']: - messages = body['messages'] llm = await self.llmm.get_llm(id) if llm: - print("Selected LLM is {}".format(llm.llm_name)) - response = self.llmm.completion(llm, messages) - if response: + messages = [] + if 'messages' in body and body['messages']: + messages = body['messages'] + opt_params = {} + if 'optional_params' in body and body['optional_params']: + opt_params = body['optional_params'] + try: + response = self.llmm.completion(llm, messages, **opt_params) return JSONResponse(response.model_dump(), status_code=200) - else: - return JSONResponse(status_code=400, content={"message": "Completion failed"}) + except BadRequestError as e: + return JSONResponse(status_code=400, content={"message": e.message}) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "Completion failed."}) else: return JSONResponse(status_code=404, content={"message": "LLM not found"}) diff --git a/backend/managers/LlmsManager.py b/backend/managers/LlmsManager.py index d813b040..081ee94f 100644 --- a/backend/managers/LlmsManager.py +++ b/backend/managers/LlmsManager.py @@ -6,7 +6,7 @@ from backend.db import db_session_context from backend.utils import get_env_key from typing import List, Tuple, Optional, Dict, Any, Union -from litellm import Router +from litellm import Router, completion from litellm.utils import CustomStreamWrapper, ModelResponse import logging @@ -74,7 +74,7 @@ async def _load_ollama_models(self): if response.status_code == 200: data = response.json() available_models = [model_data['model'] for model_data in data.get("models", [])] - print(available_models) + #print(available_models) else: pass # FIX # create / update Ollama family Llm objects @@ -90,7 +90,7 @@ async def _load_ollama_models(self): name = model.removesuffix(":latest") llm_name = "{}/{}".format(provider,name) # what LiteLLM expects safe_name = llm_name.replace("/", "-").replace(":", "-") - llm = self.get_llm(safe_name) + llm = await self.get_llm(safe_name) if llm: stmt = update(Llm).where(Llm.id == safe_name).values(name=name, llm_name=llm_name, @@ -120,14 +120,10 @@ async def _load_openai_models(self): "Authorization": f"Bearer {openai_key}" } response = await client.get(f"{openai_urlroot}/v1/models", headers=headers) - print(vars(response)) if response.status_code == 200: data = response.json() - import json - pretty_json = json.dumps(data, indent=4) - print(pretty_json) available_models = [model_data['id'] for model_data in data.get("data", [])] - print(available_models) + #print(available_models) else: print(f"Error: {response.status_code} - {response.text}") # FIX # create / update OpenAI family Llm objects @@ -149,7 +145,7 @@ async def _load_openai_models(self): name = model llm_name = "{}/{}".format(llm_provider,name) # what LiteLLM expects safe_name = f"{provider}/{name}".replace("/", "-").replace(":", "-") - llm = self.get_llm(safe_name) + llm = await self.get_llm(safe_name) if llm: stmt = update(Llm).where(Llm.id == safe_name).values(name=name, llm_name=llm_name, @@ -210,14 +206,23 @@ async def retrieve_llms(self, offset: int = 0, limit: int = 100, sort_by: Option return llms, total_count def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, CustomStreamWrapper]: - response = self.router.completion(model=llm.llm_name, - messages=messages) - print("completion response: {}".format(response)) - return response + try: + response = self.router.completion(model=llm.llm_name, + 
messages=messages, + **optional_params) + #response = completion(model=llm.llm_name, + # messages=messages, + # **kwargs) + print("completion response: {}".format(response)) + return response + except Exception as e: + logger.info(f"completion failed with error: {e.message}") + raise async def acompletion(self, llm, messages, **optional_params) -> Union[CustomStreamWrapper, ModelResponse]: response = await self.router.acompletion(model=llm.llm_name, - messages=messages) + messages=messages, + **optional_params) print("acompletion response: {}".format(response)) return response \ No newline at end of file From f1a0e683b3599533b9e5f54d58a8f60849177ceb Mon Sep 17 00:00:00 2001 From: Patricia Bedard Date: Thu, 21 Nov 2024 13:01:16 -0500 Subject: [PATCH 4/8] refactored LlmsManager model loading methods --- backend/managers/LlmsManager.py | 230 +++++++++++++++----------------- 1 file changed, 107 insertions(+), 123 deletions(-) diff --git a/backend/managers/LlmsManager.py b/backend/managers/LlmsManager.py index 081ee94f..8ab857f5 100644 --- a/backend/managers/LlmsManager.py +++ b/backend/managers/LlmsManager.py @@ -32,6 +32,64 @@ def __init__(self): asyncio.gather(router_init_task, return_exceptions=True) self._initialized = True + async def get_llm(self, id: str) -> Optional[Llm]: + async with db_session_context() as session: + result = await session.execute(select(Llm).filter(Llm.id == id)) + llm = result.scalar_one_or_none() + if llm: + return llm + return None + + async def retrieve_llms(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None, + sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[Llm], int]: + async with db_session_context() as session: + query = select(Llm).filter(Llm.is_active == True) + + if filters: + for key, value in filters.items(): + if isinstance(value, list): + query = query.filter(getattr(Llm, key).in_(value)) + else: + query = query.filter(getattr(Llm, key) == value) + + if sort_by and sort_by in ['id', 'name', 'provider', 'api_base', 'is_active']: + order_column = getattr(Llm, sort_by) + query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) + + query = query.offset(offset).limit(limit) + + result = await session.execute(query) + llms = result.scalars().all() + + # Get total count + count_query = select(func.count()).select_from(Llm).filter(Llm.is_active == True) + if filters: + for key, value in filters.items(): + if isinstance(value, list): + count_query = count_query.filter(getattr(Llm, key).in_(value)) + else: + count_query = count_query.filter(getattr(Llm, key) == value) + + total_count = await session.execute(count_query) + total_count = total_count.scalar() + + return llms, total_count + + def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, CustomStreamWrapper]: + try: + response = self.router.completion(model=llm.llm_name, + messages=messages, + **optional_params) + # below is the direct way to call the LLM (i.e. 
not using the router): + #response = completion(model=llm.llm_name, + # messages=messages, + # **optional_params) + print("completion response: {}".format(response)) + return response + except Exception as e: + logger.info(f"completion failed with error: {e.message}") + raise + async def _init_router(self): try: # load models @@ -65,54 +123,34 @@ async def _init_router(self): async def _load_ollama_models(self): try: - ollama_urlroot = get_env_key("OLLAMA_URLROOT") # eg: http://localhost:11434 + ollama_urlroot = get_env_key("OLLAMA_URLROOT", "http://localhost:11434") except ValueError: - return # no Ollama server specified + print("No Ollama server specified. Skipping.") + return # no Ollama server specified, skip # retrieve list of installed models async with httpx.AsyncClient() as client: response = await client.get("{}/api/tags".format(ollama_urlroot)) if response.status_code == 200: data = response.json() available_models = [model_data['model'] for model_data in data.get("models", [])] - #print(available_models) else: - pass # FIX + print(f"Error: {response.status_code} - {response.text}") # create / update Ollama family Llm objects provider = "ollama" - async with db_session_context() as session: - # mark existing models as inactive - stmt = update(Llm).where(Llm.provider == provider).values(is_active=False) - result = await session.execute(stmt) - if result.rowcount > 0: - await session.commit() - # insert / update models - for model in available_models: - name = model.removesuffix(":latest") - llm_name = "{}/{}".format(provider,name) # what LiteLLM expects - safe_name = llm_name.replace("/", "-").replace(":", "-") - llm = await self.get_llm(safe_name) - if llm: - stmt = update(Llm).where(Llm.id == safe_name).values(name=name, - llm_name=llm_name, - provider=provider, - api_base=ollama_urlroot, - is_active=True) - result = await session.execute(stmt) - if result.rowcount > 0: - await session.commit() - else: - new_llm = Llm(id=safe_name, name=name, llm_name=llm_name, - provider=provider, api_base=ollama_urlroot, - is_active=True) - session.add(new_llm) - await session.commit() + models = {} + for model in available_models: + name = model.removesuffix(":latest") + llm_name = "{}/{}".format(provider,name) # what LiteLLM expects + safe_name = llm_name.replace("/", "-").replace(":", "-") # URL-friendly ID + models[model] = {"id": safe_name, "name": name, "llm_name": llm_name, "provider": provider, "api_base": ollama_urlroot} + await self._persist_models(provider=provider, models=models) async def _load_openai_models(self): try: openai_key = get_env_key("OPENAI_API_KEY") except ValueError: print("No OpenAI API key specified. 
Skipping.") - return # no OpenAI API key specified + return # no OpenAI API key specified, skip # retrieve list of installed models async with httpx.AsyncClient() as client: openai_urlroot = get_env_key("OPENAI_URLROOT", "https://api.openai.com") @@ -123,11 +161,25 @@ async def _load_openai_models(self): if response.status_code == 200: data = response.json() available_models = [model_data['id'] for model_data in data.get("data", [])] - #print(available_models) else: - print(f"Error: {response.status_code} - {response.text}") # FIX + print(f"Error: {response.status_code} - {response.text}") # create / update OpenAI family Llm objects provider = "openai" + models = {} + for model in available_models: + llm_provider = None + if any(substring in model for substring in {"gpt","o1","chatgpt"}): + llm_provider = "openai" + if any(substring in model for substring in {"ada","babbage","curie","davinci","instruct"}): + llm_provider = "text-completion-openai" + if llm_provider: + name = model + llm_name = "{}/{}".format(llm_provider,name) # what LiteLLM expects + safe_name = f"{provider}/{name}".replace("/", "-").replace(":", "-") # URL-friendly ID + models[model] = {"id": safe_name, "name": name, "llm_name": llm_name, "provider": provider, "api_base": openai_urlroot} + await self._persist_models(provider=provider, models=models) + + async def _persist_models(self, provider, models): async with db_session_context() as session: # mark existing models as inactive stmt = update(Llm).where(Llm.provider == provider).values(is_active=False) @@ -135,94 +187,26 @@ async def _load_openai_models(self): if result.rowcount > 0: await session.commit() # insert / update models - for model in available_models: - llm_provider = None - if any(substring in model for substring in {"gpt","o1","chatgpt"}): - llm_provider = "openai" - if any(substring in model for substring in {"ada","babbage","curie","davinci","instruct"}): - llm_provider = "text-completion-openai" - if llm_provider: - name = model - llm_name = "{}/{}".format(llm_provider,name) # what LiteLLM expects - safe_name = f"{provider}/{name}".replace("/", "-").replace(":", "-") - llm = await self.get_llm(safe_name) - if llm: - stmt = update(Llm).where(Llm.id == safe_name).values(name=name, - llm_name=llm_name, - provider=provider, - api_base=openai_urlroot, - is_active=True) - result = await session.execute(stmt) - if result.rowcount > 0: - await session.commit() - else: - new_llm = Llm(id=safe_name, name=name, llm_name=llm_name, - provider=provider, api_base=openai_urlroot, - is_active=True) - session.add(new_llm) + for model in models: + parameters = models[model] + model_id = parameters["id"] + llm = await self.get_llm(model_id) + if llm: + stmt = update(Llm).where(Llm.id == model_id).values(name=parameters["name"], + llm_name=parameters["llm_name"], + provider=parameters["provider"], + api_base=parameters["api_base"], + is_active=True) + result = await session.execute(stmt) + if result.rowcount > 0: await session.commit() + else: + new_llm = Llm(id=model_id, + name=parameters["name"], + llm_name=parameters["llm_name"], + provider=parameters["provider"], + api_base=parameters["api_base"], + is_active=True) + session.add(new_llm) + await session.commit() - async def get_llm(self, id: str) -> Optional[Llm]: - async with db_session_context() as session: - result = await session.execute(select(Llm).filter(Llm.id == id)) - llm = result.scalar_one_or_none() - if llm: - return llm - return None - - async def retrieve_llms(self, offset: int = 0, limit: int = 100, 
sort_by: Optional[str] = None, - sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[Llm], int]: - async with db_session_context() as session: - query = select(Llm).filter(Llm.is_active == True) - - if filters: - for key, value in filters.items(): - if isinstance(value, list): - query = query.filter(getattr(Llm, key).in_(value)) - else: - query = query.filter(getattr(Llm, key) == value) - - if sort_by and sort_by in ['id', 'name', 'provider', 'api_base', 'is_active']: - order_column = getattr(Llm, sort_by) - query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) - - query = query.offset(offset).limit(limit) - - result = await session.execute(query) - llms = result.scalars().all() - - # Get total count - count_query = select(func.count()).select_from(Llm).filter(Llm.is_active == True) - if filters: - for key, value in filters.items(): - if isinstance(value, list): - count_query = count_query.filter(getattr(Llm, key).in_(value)) - else: - count_query = count_query.filter(getattr(Llm, key) == value) - - total_count = await session.execute(count_query) - total_count = total_count.scalar() - - return llms, total_count - - def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, CustomStreamWrapper]: - try: - response = self.router.completion(model=llm.llm_name, - messages=messages, - **optional_params) - #response = completion(model=llm.llm_name, - # messages=messages, - # **kwargs) - print("completion response: {}".format(response)) - return response - except Exception as e: - logger.info(f"completion failed with error: {e.message}") - raise - - async def acompletion(self, llm, messages, **optional_params) -> Union[CustomStreamWrapper, ModelResponse]: - response = await self.router.acompletion(model=llm.llm_name, - messages=messages, - **optional_params) - print("acompletion response: {}".format(response)) - return response - \ No newline at end of file From 72745c530ca9f7253b76407973742dc5d7250cd1 Mon Sep 17 00:00:00 2001 From: Patricia Bedard Date: Thu, 5 Dec 2024 11:08:08 -0500 Subject: [PATCH 5/8] added JWT security for LLMs API calls --- apis/paios/openapi.yaml | 6 ++++++ backend/managers/LlmsManager.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/apis/paios/openapi.yaml b/apis/paios/openapi.yaml index 6a3b150a..fcffd256 100644 --- a/apis/paios/openapi.yaml +++ b/apis/paios/openapi.yaml @@ -789,6 +789,8 @@ paths: description: Not Found /llms: get: + security: + - jwt: [] tags: - LLM Management summary: Retrieve all available LLMs @@ -811,6 +813,8 @@ paths: $ref: '#/components/headers/X-Total-Count' '/llms/{id}': get: + security: + - jwt: [] tags: - LLM Management summary: Retrieve LLM by id @@ -828,6 +832,8 @@ paths: description: LLM not found '/llms/{id}/completion': post: + security: + - jwt: [] tags: - LLM Management summary: Invoke Completion on LLM diff --git a/backend/managers/LlmsManager.py b/backend/managers/LlmsManager.py index 8ab857f5..c0d64f61 100644 --- a/backend/managers/LlmsManager.py +++ b/backend/managers/LlmsManager.py @@ -123,7 +123,7 @@ async def _init_router(self): async def _load_ollama_models(self): try: - ollama_urlroot = get_env_key("OLLAMA_URLROOT", "http://localhost:11434") + ollama_urlroot = get_env_key("OLLAMA_URLROOT") except ValueError: print("No Ollama server specified. 
Skipping.") return # no Ollama server specified, skip From 1c1b981f9cebaae0a486406d4e86b879cd4dfaa8 Mon Sep 17 00:00:00 2001 From: Patricia Bedard Date: Thu, 5 Dec 2024 17:41:22 -0500 Subject: [PATCH 6/8] refactoring: LlmsManager => ModelsManager --- backend/api/LlmsView.py | 12 ++++++------ .../managers/{LlmsManager.py => ModelsManager.py} | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) rename backend/managers/{LlmsManager.py => ModelsManager.py} (97%) diff --git a/backend/api/LlmsView.py b/backend/api/LlmsView.py index 307abb5c..327e91c5 100644 --- a/backend/api/LlmsView.py +++ b/backend/api/LlmsView.py @@ -1,15 +1,15 @@ from starlette.responses import JSONResponse -from backend.managers.LlmsManager import LlmsManager +from backend.managers.ModelsManager import ModelsManager from backend.pagination import parse_pagination_params from backend.schemas import LlmSchema from litellm.exceptions import BadRequestError class LlmsView: def __init__(self): - self.llmm = LlmsManager() + self.mm = ModelsManager() async def get(self, id: str): - llm = await self.llmm.get_llm(id) + llm = await self.mm.get_llm(id) if llm is None: return JSONResponse(headers={"error": "LLM not found"}, status_code=404) llm_schema = LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}", @@ -23,7 +23,7 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): offset, limit, sort_by, sort_order, filters = result - llms, total_count = await self.llmm.retrieve_llms(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters) + llms, total_count = await self.mm.retrieve_llms(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters) results = [LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}", provider=llm.provider, api_base=llm.api_base, is_active=llm.is_active) @@ -36,7 +36,7 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): async def completion(self, id: str, body: dict): print("completion. 
body: {}".format(body)) - llm = await self.llmm.get_llm(id) + llm = await self.mm.get_llm(id) if llm: messages = [] if 'messages' in body and body['messages']: @@ -45,7 +45,7 @@ async def completion(self, id: str, body: dict): if 'optional_params' in body and body['optional_params']: opt_params = body['optional_params'] try: - response = self.llmm.completion(llm, messages, **opt_params) + response = self.mm.completion(llm, messages, **opt_params) return JSONResponse(response.model_dump(), status_code=200) except BadRequestError as e: return JSONResponse(status_code=400, content={"message": e.message}) diff --git a/backend/managers/LlmsManager.py b/backend/managers/ModelsManager.py similarity index 97% rename from backend/managers/LlmsManager.py rename to backend/managers/ModelsManager.py index c0d64f61..c388e1ac 100644 --- a/backend/managers/LlmsManager.py +++ b/backend/managers/ModelsManager.py @@ -12,7 +12,7 @@ import logging logger = logging.getLogger(__name__) -class LlmsManager: +class ModelsManager: _instance = None _lock = Lock() @@ -20,7 +20,7 @@ def __new__(cls, *args, **kwargs): if not cls._instance: with cls._lock: if not cls._instance: - cls._instance = super(LlmsManager, cls).__new__(cls, *args, **kwargs) + cls._instance = super(ModelsManager, cls).__new__(cls, *args, **kwargs) return cls._instance def __init__(self): @@ -80,7 +80,7 @@ def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, C response = self.router.completion(model=llm.llm_name, messages=messages, **optional_params) - # below is the direct way to call the LLM (i.e. not using the router): + # below is the direct way to call the model (i.e. not using the router): #response = completion(model=llm.llm_name, # messages=messages, # **optional_params) @@ -98,7 +98,7 @@ async def _init_router(self): await asyncio.gather(ollama_task, openai_task, return_exceptions=True) - # collect the available LLMs + # collect the available models llms, total_llms = await self.retrieve_llms() # configure router model_list = [] From cda38e6639e36ef2e34cbc99225a198f8d5d55e7 Mon Sep 17 00:00:00 2001 From: Patricia Bedard Date: Wed, 8 Jan 2025 19:37:03 -0500 Subject: [PATCH 7/8] Integrate aisuite in ModelsManager to replace LiteLLM --- apis/paios/openapi.yaml | 4 +- backend/api/LlmsView.py | 12 ++--- backend/managers/ModelsManager.py | 45 ++++++++++++------- backend/models.py | 3 +- backend/requirements.txt | 1 + backend/schemas.py | 2 +- .../versions/73d50424c826_added_llm_table.py | 3 +- 7 files changed, 42 insertions(+), 28 deletions(-) diff --git a/apis/paios/openapi.yaml b/apis/paios/openapi.yaml index fcffd256..e2b4d42e 100644 --- a/apis/paios/openapi.yaml +++ b/apis/paios/openapi.yaml @@ -860,7 +860,7 @@ paths: content: application/json: schema: - type: object + type: string '400': description: Completion failed '404': @@ -1173,7 +1173,7 @@ components: pattern: ^[a-z]{4}-[a-z]{4}-[a-z]{4}$ messagesList: type: array - example: [{"role": "user", "content": "What is Kwaai.ai?"}] + example: [{"role": "user", "content": "What is Personal AI?"}] items: type: object properties: diff --git a/backend/api/LlmsView.py b/backend/api/LlmsView.py index 327e91c5..5dfa643e 100644 --- a/backend/api/LlmsView.py +++ b/backend/api/LlmsView.py @@ -12,8 +12,8 @@ async def get(self, id: str): llm = await self.mm.get_llm(id) if llm is None: return JSONResponse(headers={"error": "LLM not found"}, status_code=404) - llm_schema = LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}", - provider=llm.provider, 
api_base=llm.api_base, is_active=llm.is_active) + llm_schema = LlmSchema(id=llm.id, name=llm.name, provider=llm.provider, full_name=llm.aisuite_name, + api_base=llm.api_base, is_active=llm.is_active) return JSONResponse(llm_schema.model_dump(), status_code=200) async def search(self, filter: str = None, range: str = None, sort: str = None): @@ -24,9 +24,8 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): offset, limit, sort_by, sort_order, filters = result llms, total_count = await self.mm.retrieve_llms(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters) - results = [LlmSchema(id=llm.id, name=llm.name, full_name=f"{llm.provider}/{llm.name}", - provider=llm.provider, api_base=llm.api_base, - is_active=llm.is_active) + results = [LlmSchema(id=llm.id, name=llm.name, provider=llm.provider, full_name=llm.aisuite_name, + api_base=llm.api_base, is_active=llm.is_active) for llm in llms] headers = { 'X-Total-Count': str(total_count), @@ -46,7 +45,8 @@ async def completion(self, id: str, body: dict): opt_params = body['optional_params'] try: response = self.mm.completion(llm, messages, **opt_params) - return JSONResponse(response.model_dump(), status_code=200) + #return JSONResponse(response.model_dump(), status_code=200) # LiteLLM response handling + return JSONResponse(response.choices[0].message.content, status_code=200) # aisuite response handling except BadRequestError as e: return JSONResponse(status_code=400, content={"message": e.message}) except Exception as e: diff --git a/backend/managers/ModelsManager.py b/backend/managers/ModelsManager.py index c388e1ac..a00d947c 100644 --- a/backend/managers/ModelsManager.py +++ b/backend/managers/ModelsManager.py @@ -1,5 +1,6 @@ import asyncio import httpx +import aisuite as ai from threading import Lock from sqlalchemy import select, insert, update, delete, func from backend.models import Llm @@ -27,9 +28,10 @@ def __init__(self): if not hasattr(self, '_initialized'): # Ensure initialization happens only once with self._lock: if not hasattr(self, '_initialized'): + self.ai_client = ai.Client() self.router = None - router_init_task = asyncio.create_task(self._init_router()) - asyncio.gather(router_init_task, return_exceptions=True) + model_load_task = asyncio.create_task(self._load_models()) + asyncio.gather(model_load_task, return_exceptions=True) self._initialized = True async def get_llm(self, id: str) -> Optional[Llm]: @@ -77,20 +79,25 @@ async def retrieve_llms(self, offset: int = 0, limit: int = 100, sort_by: Option def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, CustomStreamWrapper]: try: - response = self.router.completion(model=llm.llm_name, - messages=messages, - **optional_params) - # below is the direct way to call the model (i.e. not using the router): + response = self.ai_client.chat.completions.create(model=llm.aisuite_name, + messages=messages, + **optional_params) + # below is the way to call the model using the LiteLLM router + #response = self.router.completion(model=llm.llm_name, + # messages=messages, + # **optional_params) + # below is the direct way to call the model using LiteLLM (i.e. 
not using the router): #response = completion(model=llm.llm_name, # messages=messages, # **optional_params) print("completion response: {}".format(response)) + #print("completion response content: {}".format(response.choices[0].message.content)) return response except Exception as e: logger.info(f"completion failed with error: {e.message}") raise - async def _init_router(self): + async def _load_models(self): try: # load models ollama_task = asyncio.create_task(self._load_ollama_models()) @@ -123,13 +130,13 @@ async def _init_router(self): async def _load_ollama_models(self): try: - ollama_urlroot = get_env_key("OLLAMA_URLROOT") + ollama_api_url = get_env_key("OLLAMA_API_URL") except ValueError: print("No Ollama server specified. Skipping.") return # no Ollama server specified, skip # retrieve list of installed models async with httpx.AsyncClient() as client: - response = await client.get("{}/api/tags".format(ollama_urlroot)) + response = await client.get("{}/api/tags".format(ollama_api_url)) if response.status_code == 200: data = response.json() available_models = [model_data['model'] for model_data in data.get("models", [])] @@ -140,24 +147,25 @@ async def _load_ollama_models(self): models = {} for model in available_models: name = model.removesuffix(":latest") + aisuite_name = "{}:{}".format(provider,name) # what aisuite expects llm_name = "{}/{}".format(provider,name) # what LiteLLM expects safe_name = llm_name.replace("/", "-").replace(":", "-") # URL-friendly ID - models[model] = {"id": safe_name, "name": name, "llm_name": llm_name, "provider": provider, "api_base": ollama_urlroot} + models[model] = {"id": safe_name, "name": name, "provider": provider, "aisuite_name": aisuite_name, "llm_name": llm_name, "api_base": ollama_api_url} await self._persist_models(provider=provider, models=models) async def _load_openai_models(self): try: - openai_key = get_env_key("OPENAI_API_KEY") + openai_api_key = get_env_key("OPENAI_API_KEY") except ValueError: print("No OpenAI API key specified. 
Skipping.") return # no OpenAI API key specified, skip # retrieve list of installed models async with httpx.AsyncClient() as client: - openai_urlroot = get_env_key("OPENAI_URLROOT", "https://api.openai.com") + openai_api_url = get_env_key("OPENAI_API_URL", "https://api.openai.com") headers = { - "Authorization": f"Bearer {openai_key}" + "Authorization": f"Bearer {openai_api_key}" } - response = await client.get(f"{openai_urlroot}/v1/models", headers=headers) + response = await client.get(f"{openai_api_url}/v1/models", headers=headers) if response.status_code == 200: data = response.json() available_models = [model_data['id'] for model_data in data.get("data", [])] @@ -174,9 +182,10 @@ async def _load_openai_models(self): llm_provider = "text-completion-openai" if llm_provider: name = model + aisuite_name = "{}:{}".format(provider,name) # what aisuite expects llm_name = "{}/{}".format(llm_provider,name) # what LiteLLM expects safe_name = f"{provider}/{name}".replace("/", "-").replace(":", "-") # URL-friendly ID - models[model] = {"id": safe_name, "name": name, "llm_name": llm_name, "provider": provider, "api_base": openai_urlroot} + models[model] = {"id": safe_name, "name": name, "provider": provider, "aisuite_name": aisuite_name, "llm_name": llm_name, "api_base": openai_api_url} await self._persist_models(provider=provider, models=models) async def _persist_models(self, provider, models): @@ -193,8 +202,9 @@ async def _persist_models(self, provider, models): llm = await self.get_llm(model_id) if llm: stmt = update(Llm).where(Llm.id == model_id).values(name=parameters["name"], - llm_name=parameters["llm_name"], provider=parameters["provider"], + aisuite_name=parameters["aisuite_name"], + llm_name=parameters["llm_name"], api_base=parameters["api_base"], is_active=True) result = await session.execute(stmt) @@ -203,8 +213,9 @@ async def _persist_models(self, provider, models): else: new_llm = Llm(id=model_id, name=parameters["name"], - llm_name=parameters["llm_name"], provider=parameters["provider"], + aisuite_name=parameters["aisuite_name"], + llm_name=parameters["llm_name"], api_base=parameters["api_base"], is_active=True) session.add(new_llm) diff --git a/backend/models.py b/backend/models.py index 28a1f17b..d1e38939 100644 --- a/backend/models.py +++ b/backend/models.py @@ -68,8 +68,9 @@ class Share(SQLModelBase, table=True): class Llm(SQLModelBase, table=True): id: str = Field(primary_key=True) # the model's unique, URL-friendly name name: str = Field() - llm_name: str = Field() # the model name known to LiteLLM provider: str = Field() # model provider, eg "ollama" + aisuite_name: str = Field() # the model name known to aisuite + llm_name: str = Field() # the model name known to LiteLLM api_base: str | None = Field(default=None) is_active: bool = Field() # is the model installed / available? 
diff --git a/backend/requirements.txt b/backend/requirements.txt index 6cc65691..6b1abc90 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -21,3 +21,4 @@ webauthn greenlet pyjwt litellm +aisuite[ollama,openai] \ No newline at end of file diff --git a/backend/schemas.py b/backend/schemas.py index b68772e5..0012726c 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -91,8 +91,8 @@ class ShareSchema(ShareBaseSchema): class LlmSchema(BaseModel): id: str name: str - full_name: str provider: str + full_name: str api_base: Optional[str] = None is_active: bool diff --git a/migrations/versions/73d50424c826_added_llm_table.py b/migrations/versions/73d50424c826_added_llm_table.py index 32463cb8..357044aa 100644 --- a/migrations/versions/73d50424c826_added_llm_table.py +++ b/migrations/versions/73d50424c826_added_llm_table.py @@ -23,8 +23,9 @@ def upgrade() -> None: op.create_table('llm', sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('llm_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('aisuite_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('llm_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('api_base', sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column('is_active', sa.Boolean(), nullable=False), sa.PrimaryKeyConstraint('id') From 8d48bc44381c00e30d07427d42a8fb7d3bbb0fb2 Mon Sep 17 00:00:00 2001 From: Patricia Bedard Date: Thu, 9 Jan 2025 11:47:36 -0500 Subject: [PATCH 8/8] fix get_llm to return active models only; cleaned up logging --- backend/api/LlmsView.py | 3 +-- backend/managers/ModelsManager.py | 35 ++++++++++++++----------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/backend/api/LlmsView.py b/backend/api/LlmsView.py index 5dfa643e..cd8c0a42 100644 --- a/backend/api/LlmsView.py +++ b/backend/api/LlmsView.py @@ -45,8 +45,7 @@ async def completion(self, id: str, body: dict): opt_params = body['optional_params'] try: response = self.mm.completion(llm, messages, **opt_params) - #return JSONResponse(response.model_dump(), status_code=200) # LiteLLM response handling - return JSONResponse(response.choices[0].message.content, status_code=200) # aisuite response handling + return JSONResponse(response.choices[0].message.content, status_code=200) except BadRequestError as e: return JSONResponse(status_code=400, content={"message": e.message}) except Exception as e: diff --git a/backend/managers/ModelsManager.py b/backend/managers/ModelsManager.py index a00d947c..1d1499a7 100644 --- a/backend/managers/ModelsManager.py +++ b/backend/managers/ModelsManager.py @@ -10,8 +10,9 @@ from litellm import Router, completion from litellm.utils import CustomStreamWrapper, ModelResponse -import logging -logger = logging.getLogger(__name__) +# set up logging +from common.log import get_logger +logger = get_logger(__name__) class ModelsManager: _instance = None @@ -34,9 +35,12 @@ def __init__(self): asyncio.gather(model_load_task, return_exceptions=True) self._initialized = True - async def get_llm(self, id: str) -> Optional[Llm]: + async def get_llm(self, id: str, only_active=True) -> Optional[Llm]: async with db_session_context() as session: - result = await session.execute(select(Llm).filter(Llm.id == id)) + query = select(Llm).filter(Llm.id == id) + if only_active: + query = 
query.filter(getattr(Llm, "is_active") == True) + result = await session.execute(query) llm = result.scalar_one_or_none() if llm: return llm @@ -86,12 +90,7 @@ def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, C #response = self.router.completion(model=llm.llm_name, # messages=messages, # **optional_params) - # below is the direct way to call the model using LiteLLM (i.e. not using the router): - #response = completion(model=llm.llm_name, - # messages=messages, - # **optional_params) - print("completion response: {}".format(response)) - #print("completion response content: {}".format(response.choices[0].message.content)) + logger.debug("completion response: {}".format(response)) return response except Exception as e: logger.info(f"completion failed with error: {e.message}") @@ -122,8 +121,6 @@ async def _load_models(self): "litellm_params": params, } model_list.append(model) - #import pprint - #pprint.pprint(model_list) self.router = Router(model_list=model_list) except Exception as e: logger.exception(e) @@ -132,8 +129,8 @@ async def _load_ollama_models(self): try: ollama_api_url = get_env_key("OLLAMA_API_URL") except ValueError: - print("No Ollama server specified. Skipping.") - return # no Ollama server specified, skip + logger.info("No Ollama server specified. Skipping.") + return # retrieve list of installed models async with httpx.AsyncClient() as client: response = await client.get("{}/api/tags".format(ollama_api_url)) @@ -141,7 +138,7 @@ async def _load_ollama_models(self): data = response.json() available_models = [model_data['model'] for model_data in data.get("models", [])] else: - print(f"Error: {response.status_code} - {response.text}") + logger.warning(f"Error when retrieving models: {response.status_code} - {response.text}") # create / update Ollama family Llm objects provider = "ollama" models = {} @@ -157,8 +154,8 @@ async def _load_openai_models(self): try: openai_api_key = get_env_key("OPENAI_API_KEY") except ValueError: - print("No OpenAI API key specified. Skipping.") - return # no OpenAI API key specified, skip + logger.info("No OpenAI API key specified. Skipping.") + return # retrieve list of installed models async with httpx.AsyncClient() as client: openai_api_url = get_env_key("OPENAI_API_URL", "https://api.openai.com") @@ -170,7 +167,7 @@ async def _load_openai_models(self): data = response.json() available_models = [model_data['id'] for model_data in data.get("data", [])] else: - print(f"Error: {response.status_code} - {response.text}") + logger.warning(f"Error when retrieving models: {response.status_code} - {response.text}") # create / update OpenAI family Llm objects provider = "openai" models = {} @@ -199,7 +196,7 @@ async def _persist_models(self, provider, models): for model in models: parameters = models[model] model_id = parameters["id"] - llm = await self.get_llm(model_id) + llm = await self.get_llm(model_id, only_active=False) if llm: stmt = update(Llm).where(Llm.id == model_id).values(name=parameters["name"], provider=parameters["provider"],