diff --git a/apis/paios/openapi.yaml b/apis/paios/openapi.yaml
index 68a4dd9e..e2b4d42e 100644
--- a/apis/paios/openapi.yaml
+++ b/apis/paios/openapi.yaml
@@ -700,6 +700,8 @@ paths:
         - $ref: '#/components/parameters/id'
   /shares:
     get:
+      security:
+        - jwt: []
       tags:
         - Share Management
       summary: Retrieve all share links
@@ -721,6 +723,8 @@ paths:
             X-Total-Count:
               $ref: '#/components/headers/X-Total-Count'
     post:
+      security:
+        - jwt: []
       summary: Create new share link
       tags:
         - Share Management
@@ -737,6 +741,8 @@ paths:
               $ref: '#/components/schemas/ShareCreate'
   '/shares/{id}':
     get:
+      security:
+        - jwt: []
       tags:
         - Share Management
       summary: Retrieve share link by id
@@ -751,6 +757,8 @@ paths:
             schema:
               $ref: '#/components/schemas/Share'
     put:
+      security:
+        - jwt: []
       tags:
         - Share Management
       summary: Update share link by id
@@ -766,6 +774,8 @@ paths:
         '200':
           description: OK
     delete:
+      security:
+        - jwt: []
       tags:
         - Share Management
       summary: Delete share link by id
@@ -777,6 +787,84 @@ paths:
           description: No Content
         '404':
           description: Not Found
+  /llms:
+    get:
+      security:
+        - jwt: []
+      tags:
+        - LLM Management
+      summary: Retrieve all available LLMs
+      description: Get all installed / available LLMs.
+      parameters:
+        - $ref: '#/components/parameters/sort'
+        - $ref: '#/components/parameters/range'
+        - $ref: '#/components/parameters/filter'
+      responses:
+        '200':
+          description: OK
+          content:
+            application/json:
+              schema:
+                type: array
+                items:
+                  $ref: '#/components/schemas/Llm'
+          headers:
+            X-Total-Count:
+              $ref: '#/components/headers/X-Total-Count'
+  '/llms/{id}':
+    get:
+      security:
+        - jwt: []
+      tags:
+        - LLM Management
+      summary: Retrieve LLM by id
+      description: Retrieve the LLM with the specified id.
+      parameters:
+        - $ref: '#/components/parameters/kebab-dot_id'
+      responses:
+        '200':
+          description: OK
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/Llm'
+        '404':
+          description: LLM not found
+  '/llms/{id}/completion':
+    post:
+      security:
+        - jwt: []
+      tags:
+        - LLM Management
+      summary: Invoke Completion on LLM
+      description: Invoke Completion on the LLM with the specified id.
+      operationId: backend.api.LlmsView.completion
+      parameters:
+        - $ref: '#/components/parameters/kebab-dot_id'
+      requestBody:
+        description: Messages to pass to the completion function.
+        content:
+          application/json:
+            schema:
+              type: object
+              properties:
+                messages:
+                  $ref: '#/components/schemas/messagesList'
+                optional_params:
+                  $ref: '#/components/schemas/completionParamList'
+              required:
+                - messages
+      responses:
+        '200':
+          description: Completion succeeded
+          content:
+            application/json:
+              schema:
+                type: string
+        '400':
+          description: Completion failed
+        '404':
+          description: LLM not found
   /auth/webauthn/register-options:
     post:
       summary: Generate WebAuthn registration options
@@ -901,6 +989,8 @@ tags:
     description: Management of personas
   - name: Share Management
     description: Management of share links
+  - name: LLM Management
+    description: Discovery and invocation of LLM functionality
 components:
   securitySchemes:
     jwt:
@@ -936,6 +1026,13 @@ components:
       required: true
       schema:
         $ref: '#/components/schemas/kebab-snake_id'
+    kebab-dot_id:
+      name: id
+      in: path
+      description: id of the object
+      required: true
+      schema:
+        $ref: '#/components/schemas/kebab-dot_id'
     key:
       name: key
       in: path
@@ -1003,6 +1100,12 @@ components:
       maxLength: 100
       example: langchain-core
       pattern: ^[a-z0-9]+([_-][a-z0-9]+)*$
+    kebab-dot_id:
+      type: string
+      minLength: 2
+      maxLength: 100
+      example: ollama-llama3.2
+      pattern: ^[a-z0-9]+([.-][a-z0-9]+)*$
     semVer:
       type: string
       example: '1.1.0'
@@ -1068,6 +1171,19 @@ components:
       readOnly: true
       example: abcd-efgh-ijkl
       pattern: ^[a-z]{4}-[a-z]{4}-[a-z]{4}$
+    messagesList:
+      type: array
+      example: [{"role": "user", "content": "What is Personal AI?"}]
+      items:
+        type: object
+        properties:
+          role:
+            type: string
+          content:
+            type: string
+    completionParamList:
+      type: object
+      example: {"max_tokens": 50, "temperature": 0.2}
     download:
       type: object
       properties:
@@ -1390,6 +1506,29 @@ components:
           example: false
       required:
         - resource_id
+    Llm:
+      type: object
+      title: Llm
+      properties:
+        id:
+          type: string
+        name:
+          type: string
+        full_name:
+          type: string
+        provider:
+          type: string
+        api_base:
+          type: string
+          nullable: true
+        is_active:
+          type: boolean
+      required:
+        - id
+        - name
+        - full_name
+        - provider
+        - is_active
     RegistrationOptions:
       type: object
       properties:
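The new /llms routes are JWT-protected, matching the share routes above. As a minimal sketch of exercising them over HTTP (the base URL and token below are hypothetical placeholders, and ollama-llama3.2 is the example id from the kebab-dot_id schema), using httpx, which the backend already depends on:

    import httpx

    BASE_URL = "http://localhost:8443"           # hypothetical; substitute your backend's address
    HEADERS = {"Authorization": "Bearer <jwt>"}  # token obtained from the backend's auth flow

    # List the installed / available models (GET /llms).
    llms = httpx.get(f"{BASE_URL}/llms", headers=HEADERS).json()

    # Invoke a completion on one of them (POST /llms/{id}/completion).
    body = {
        "messages": [{"role": "user", "content": "What is Personal AI?"}],
        "optional_params": {"max_tokens": 50, "temperature": 0.2},
    }
    response = httpx.post(f"{BASE_URL}/llms/ollama-llama3.2/completion",
                          headers=HEADERS, json=body)
    print(response.json())  # on success, the completion text as a plain string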
diff --git a/backend/api/LlmsView.py b/backend/api/LlmsView.py
new file mode 100644
index 00000000..cd8c0a42
--- /dev/null
+++ b/backend/api/LlmsView.py
@@ -0,0 +1,57 @@
+from starlette.responses import JSONResponse
+from backend.managers.ModelsManager import ModelsManager
+from backend.pagination import parse_pagination_params
+from backend.schemas import LlmSchema
+from litellm.exceptions import BadRequestError
+from common.log import get_logger
+
+logger = get_logger(__name__)
+
+class LlmsView:
+    def __init__(self):
+        self.mm = ModelsManager()
+
+    async def get(self, id: str):
+        llm = await self.mm.get_llm(id)
+        if llm is None:
+            return JSONResponse(content={"error": "LLM not found"}, status_code=404)
+        llm_schema = LlmSchema(id=llm.id, name=llm.name, provider=llm.provider, full_name=llm.aisuite_name,
+                               api_base=llm.api_base, is_active=llm.is_active)
+        return JSONResponse(llm_schema.model_dump(), status_code=200)
+
+    async def search(self, filter: str = None, range: str = None, sort: str = None):
+        result = parse_pagination_params(filter, range, sort)
+        if isinstance(result, JSONResponse):
+            return result
+
+        offset, limit, sort_by, sort_order, filters = result
+
+        llms, total_count = await self.mm.retrieve_llms(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters)
+        results = [LlmSchema(id=llm.id, name=llm.name, provider=llm.provider, full_name=llm.aisuite_name,
+                             api_base=llm.api_base, is_active=llm.is_active)
+                   for llm in llms]
+        headers = {
+            'X-Total-Count': str(total_count),
+            'Content-Range': f'llms {offset}-{offset + len(llms) - 1}/{total_count}'
+        }
+        return JSONResponse([llm.model_dump() for llm in results], status_code=200, headers=headers)
+
+    async def completion(self, id: str, body: dict):
+        logger.debug("completion body: {}".format(body))
+        llm = await self.mm.get_llm(id)
+        if llm:
+            messages = []
+            if 'messages' in body and body['messages']:
+                messages = body['messages']
+            opt_params = {}
+            if 'optional_params' in body and body['optional_params']:
+                opt_params = body['optional_params']
+            try:
+                response = self.mm.completion(llm, messages, **opt_params)
+                return JSONResponse(response.choices[0].message.content, status_code=200)
+            except BadRequestError as e:
+                return JSONResponse(status_code=400, content={"message": e.message})
+            except Exception:
+                return JSONResponse(status_code=400, content={"message": "Completion failed."})
+        else:
+            return JSONResponse(status_code=404, content={"message": "LLM not found"})
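For list requests, search builds the same pagination headers the existing list endpoints return; as a worked example, with three active models and the default offset of 0 the response carries X-Total-Count: 3 and Content-Range: llms 0-2/3.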
diff --git a/backend/api/__init__.py b/backend/api/__init__.py
index 697d4c4d..78cf1bfe 100644
--- a/backend/api/__init__.py
+++ b/backend/api/__init__.py
@@ -9,3 +9,4 @@
 from .PersonasView import PersonasView
 from .SharesView import SharesView
 from .AuthView import AuthView
+from .LlmsView import LlmsView
diff --git a/backend/managers/ModelsManager.py b/backend/managers/ModelsManager.py
new file mode 100644
index 00000000..1d1499a7
--- /dev/null
+++ b/backend/managers/ModelsManager.py
@@ -0,0 +1,222 @@
+import asyncio
+import httpx
+import aisuite as ai
+from threading import Lock
+from sqlalchemy import select, insert, update, delete, func
+from backend.models import Llm
+from backend.db import db_session_context
+from backend.utils import get_env_key
+from typing import List, Tuple, Optional, Dict, Any, Union
+from litellm import Router, completion
+from litellm.utils import CustomStreamWrapper, ModelResponse
+
+# set up logging
+from common.log import get_logger
+logger = get_logger(__name__)
+
+class ModelsManager:
+    _instance = None
+    _lock = Lock()
+
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            with cls._lock:
+                if not cls._instance:
+                    cls._instance = super(ModelsManager, cls).__new__(cls, *args, **kwargs)
+        return cls._instance
+
+    def __init__(self):
+        if not hasattr(self, '_initialized'): # Ensure initialization happens only once
+            with self._lock:
+                if not hasattr(self, '_initialized'):
+                    self.ai_client = ai.Client()
+                    self.router = None
+                    model_load_task = asyncio.create_task(self._load_models())
+                    asyncio.gather(model_load_task, return_exceptions=True)
+                    self._initialized = True
+
+    async def get_llm(self, id: str, only_active=True) -> Optional[Llm]:
+        async with db_session_context() as session:
+            query = select(Llm).filter(Llm.id == id)
+            if only_active:
+                query = query.filter(getattr(Llm, "is_active") == True)
+            result = await session.execute(query)
+            llm = result.scalar_one_or_none()
+            if llm:
+                return llm
+            return None
+
+    async def retrieve_llms(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None,
+                            sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[Llm], int]:
+        async with db_session_context() as session:
+            query = select(Llm).filter(Llm.is_active == True)
+
+            if filters:
+                for key, value in filters.items():
+                    if isinstance(value, list):
+                        query = query.filter(getattr(Llm, key).in_(value))
+                    else:
+                        query = query.filter(getattr(Llm, key) == value)
+
+            if sort_by and sort_by in ['id', 'name', 'provider', 'api_base', 'is_active']:
+                order_column = getattr(Llm, sort_by)
+                query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column)
+
+            query = query.offset(offset).limit(limit)
+
+            result = await session.execute(query)
+            llms = result.scalars().all()
+
+            # Get total count
+            count_query = select(func.count()).select_from(Llm).filter(Llm.is_active == True)
+            if filters:
+                for key, value in filters.items():
+                    if isinstance(value, list):
+                        count_query = count_query.filter(getattr(Llm, key).in_(value))
+                    else:
+                        count_query = count_query.filter(getattr(Llm, key) == value)
+
+            total_count = await session.execute(count_query)
+            total_count = total_count.scalar()
+
+            return llms, total_count
+
+    def completion(self, llm, messages, **optional_params) -> Union[ModelResponse, CustomStreamWrapper]:
+        try:
+            response = self.ai_client.chat.completions.create(model=llm.aisuite_name,
+                                                              messages=messages,
+                                                              **optional_params)
+            # below is the way to call the model using the LiteLLM router
+            #response = self.router.completion(model=llm.llm_name,
+            #                                  messages=messages,
+            #                                  **optional_params)
+            logger.debug("completion response: {}".format(response))
+            return response
+        except Exception as e:
+            logger.info(f"completion failed with error: {e}")
+            raise
+
+    async def _load_models(self):
+        try:
+            # load models
+            ollama_task = asyncio.create_task(self._load_ollama_models())
+            openai_task = asyncio.create_task(self._load_openai_models())
+            await asyncio.gather(ollama_task,
+                                 openai_task,
+                                 return_exceptions=True)
+            # collect the available models
+            llms, total_llms = await self.retrieve_llms()
+            # configure router
+            model_list = []
+            for llm in llms:
+                model_name = f"{llm.provider}/{llm.name}"
+                params = {}
+                params["model"] = llm.llm_name
+                if llm.provider == "ollama":
+                    params["api_base"] = llm.api_base
+                if llm.provider == "openai":
+                    params["api_key"] = get_env_key("OPENAI_API_KEY")
+                model = {
+                    "model_name": model_name,
+                    "litellm_params": params,
+                }
+                model_list.append(model)
+            self.router = Router(model_list=model_list)
+        except Exception as e:
+            logger.exception(e)
+
+    async def _load_ollama_models(self):
+        try:
+            ollama_api_url = get_env_key("OLLAMA_API_URL")
+        except ValueError:
+            logger.info("No Ollama server specified. Skipping.")
+            return
+        # retrieve list of installed models
+        async with httpx.AsyncClient() as client:
+            response = await client.get("{}/api/tags".format(ollama_api_url))
+            if response.status_code == 200:
+                data = response.json()
+                available_models = [model_data['model'] for model_data in data.get("models", [])]
+            else:
+                logger.warning(f"Error when retrieving models: {response.status_code} - {response.text}")
+                return
+        # create / update Ollama family Llm objects
+        provider = "ollama"
+        models = {}
+        for model in available_models:
+            name = model.removesuffix(":latest")
+            aisuite_name = "{}:{}".format(provider, name) # what aisuite expects
+            llm_name = "{}/{}".format(provider, name) # what LiteLLM expects
+            safe_name = llm_name.replace("/", "-").replace(":", "-") # URL-friendly ID
+            models[model] = {"id": safe_name, "name": name, "provider": provider, "aisuite_name": aisuite_name, "llm_name": llm_name, "api_base": ollama_api_url}
+        await self._persist_models(provider=provider, models=models)
+
Skipping.") + return + # retrieve list of installed models + async with httpx.AsyncClient() as client: + openai_api_url = get_env_key("OPENAI_API_URL", "https://api.openai.com") + headers = { + "Authorization": f"Bearer {openai_api_key}" + } + response = await client.get(f"{openai_api_url}/v1/models", headers=headers) + if response.status_code == 200: + data = response.json() + available_models = [model_data['id'] for model_data in data.get("data", [])] + else: + logger.warning(f"Error when retrieving models: {response.status_code} - {response.text}") + # create / update OpenAI family Llm objects + provider = "openai" + models = {} + for model in available_models: + llm_provider = None + if any(substring in model for substring in {"gpt","o1","chatgpt"}): + llm_provider = "openai" + if any(substring in model for substring in {"ada","babbage","curie","davinci","instruct"}): + llm_provider = "text-completion-openai" + if llm_provider: + name = model + aisuite_name = "{}:{}".format(provider,name) # what aisuite expects + llm_name = "{}/{}".format(llm_provider,name) # what LiteLLM expects + safe_name = f"{provider}/{name}".replace("/", "-").replace(":", "-") # URL-friendly ID + models[model] = {"id": safe_name, "name": name, "provider": provider, "aisuite_name": aisuite_name, "llm_name": llm_name, "api_base": openai_api_url} + await self._persist_models(provider=provider, models=models) + + async def _persist_models(self, provider, models): + async with db_session_context() as session: + # mark existing models as inactive + stmt = update(Llm).where(Llm.provider == provider).values(is_active=False) + result = await session.execute(stmt) + if result.rowcount > 0: + await session.commit() + # insert / update models + for model in models: + parameters = models[model] + model_id = parameters["id"] + llm = await self.get_llm(model_id, only_active=False) + if llm: + stmt = update(Llm).where(Llm.id == model_id).values(name=parameters["name"], + provider=parameters["provider"], + aisuite_name=parameters["aisuite_name"], + llm_name=parameters["llm_name"], + api_base=parameters["api_base"], + is_active=True) + result = await session.execute(stmt) + if result.rowcount > 0: + await session.commit() + else: + new_llm = Llm(id=model_id, + name=parameters["name"], + provider=parameters["provider"], + aisuite_name=parameters["aisuite_name"], + llm_name=parameters["llm_name"], + api_base=parameters["api_base"], + is_active=True) + session.add(new_llm) + await session.commit() + diff --git a/backend/models.py b/backend/models.py index 251cbe4f..d1e38939 100644 --- a/backend/models.py +++ b/backend/models.py @@ -65,6 +65,15 @@ class Share(SQLModelBase, table=True): expiration_dt: datetime | None = Field(default=None) # the link expiration date/time (optional) is_revoked: bool = Field() +class Llm(SQLModelBase, table=True): + id: str = Field(primary_key=True) # the model's unique, URL-friendly name + name: str = Field() + provider: str = Field() # model provider, eg "ollama" + aisuite_name: str = Field() # the model name known to aisuite + llm_name: str = Field() # the model name known to LiteLLM + api_base: str | None = Field(default=None) + is_active: bool = Field() # is the model installed / available? 
diff --git a/backend/models.py b/backend/models.py
index 251cbe4f..d1e38939 100644
--- a/backend/models.py
+++ b/backend/models.py
@@ -65,6 +65,15 @@ class Share(SQLModelBase, table=True):
     expiration_dt: datetime | None = Field(default=None) # the link expiration date/time (optional)
     is_revoked: bool = Field()
 
+class Llm(SQLModelBase, table=True):
+    id: str = Field(primary_key=True) # the model's unique, URL-friendly name
+    name: str = Field()
+    provider: str = Field() # model provider, eg "ollama"
+    aisuite_name: str = Field() # the model name known to aisuite
+    llm_name: str = Field() # the model name known to LiteLLM
+    api_base: str | None = Field(default=None)
+    is_active: bool = Field() # is the model installed / available?
+
 # Resolve forward references
 User.model_rebuild()
 Cred.model_rebuild()
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 0e5dd6b4..6b1abc90 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -20,3 +20,5 @@ structlog
 webauthn
 greenlet
 pyjwt
+litellm
+aisuite[ollama,openai]
\ No newline at end of file
diff --git a/backend/schemas.py b/backend/schemas.py
index fce66742..0012726c 100644
--- a/backend/schemas.py
+++ b/backend/schemas.py
@@ -88,6 +88,14 @@ class ShareCreateSchema(ShareBaseSchema):
 class ShareSchema(ShareBaseSchema):
     id: str
 
+class LlmSchema(BaseModel):
+    id: str
+    name: str
+    provider: str
+    full_name: str
+    api_base: Optional[str] = None
+    is_active: bool
+
 class RegistrationOptions(BaseModel):
     email: str
 
diff --git a/migrations/versions/73d50424c826_added_llm_table.py b/migrations/versions/73d50424c826_added_llm_table.py
new file mode 100644
index 00000000..357044aa
--- /dev/null
+++ b/migrations/versions/73d50424c826_added_llm_table.py
@@ -0,0 +1,36 @@
+"""added llm table
+
+Revision ID: 73d50424c826
+Revises: 187855982332
+Create Date: 2024-10-10 17:31:43.949960
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel
+
+
+# revision identifiers, used by Alembic.
+revision: str = '73d50424c826'
+down_revision: Union[str, None] = '187855982332'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.create_table('llm',
+    sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('aisuite_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('llm_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('api_base', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+    sa.Column('is_active', sa.Boolean(), nullable=False),
+    sa.PrimaryKeyConstraint('id')
+    )
+
+
+def downgrade() -> None:
+    op.drop_table('llm')
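With the migration included, the llm table is created by the usual Alembic flow (alembic upgrade head applies revision 73d50424c826 on top of 187855982332), and downgrade drops it again.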