diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index a609248..4131026 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -21,7 +21,7 @@ Circular dependency is intentionally broken by patching `memory._on_evict = load All services live on `app.state`: `settings`, `registry`, `runtime_manager`, `session_manager`. Access them in route handlers via `request.app.state.`. -**ModelRegistry** maps HuggingFace repo IDs → tasks in a JSON sidecar (`data_dir/model_registry.json`). Models are fetched from the HF cache via `huggingface_hub`; the registry is populated by `POST /v1/models/pull`. +**ModelRegistry** maps model IDs → `{task, source}` in a JSON sidecar (`data_dir/model_registry.json`). HuggingFace models (`source="hf"`) are fetched via `huggingface_hub`; pip-based OCR backends (`source="pip"`) are installed via `mataserver/core/pip_installer.py`. The registry supports both old flat format (`{"model": "task"}`) and new dict-of-dicts format (`{"model": {"task": "...", "source": "..."}}`), auto-migrating on read. The registry is populated by `POST /v1/models/pull`. 
## Key Conventions @@ -102,4 +102,27 @@ Two-step: `POST /v1/sessions` creates a session → `WS /v1/stream/{session_id}` | `mataserver/schemas/requests.py` | `InferParams`, `to_mata_kwargs()`, `SUPPORTED_TASKS` | | `mataserver/core/result_converter.py` | MATA `VisionResult` → `InferResponse` dispatch | | `mataserver/api/deps.py` | Auth dependencies (HTTP + WebSocket) | -| `mataserver/models/registry.py` | Persistent HF model ID → task map | +| `mataserver/models/registry.py` | Persistent model ID → task + source map | +| `mataserver/core/backend_catalog.py` | Static catalog of pip-based OCR backends | +| `mataserver/core/pip_installer.py` | Pip install helper for non-HF backends | + +### Backend Catalog (pip-based backends) + +`mataserver/core/backend_catalog.py` is a **static Python catalog** (not JSON/YAML) that maps short backend names to installation metadata. This prevents arbitrary pip installs from user input. + +```python +from mataserver.core.backend_catalog import lookup, is_cataloged, get_source_type + +entry = lookup("easyocr") # CatalogEntry or None +is_cataloged("easyocr") # True +get_source_type("easyocr") # "pip" +get_source_type("org/model") # "hf" +``` + +Currently cataloged pip backends: `easyocr`, `paddleocr`, `tesseract`. + +When adding a new pip backend: + +1. Add a `CatalogEntry` to `_CATALOG` in `backend_catalog.py`. +2. `pull.py` and `mataserver/api/v1/models.py` dispatch automatically. +3. Register a result converter with `@_register("ocr")` in `result_converter.py` if needed. diff --git a/Dockerfile b/Dockerfile index 8638d6c..0882c53 100644 --- a/Dockerfile +++ b/Dockerfile @@ -72,4 +72,24 @@ ENV MATA_SERVER_PORT=8110 ENV MATA_SERVER_DATA_DIR=/var/lib/mataserver ENV PYTHONPATH=/usr/local/lib/python3.11/site-packages +# Optional: Pre-install OCR backends into the image at build time. +# Pre-baking avoids runtime pip installs and removes the need for outbound internet +# access in the container. Uncomment the backends you need. 
+# +# EasyOCR: +# RUN pip install --no-cache-dir easyocr +# +# PaddleOCR: +# RUN pip install --no-cache-dir paddlepaddle paddleocr +# +# Tesseract (requires system binary + Python binding): +# RUN apt-get update && apt-get install -y --no-install-recommends tesseract-ocr \ +# && rm -rf /var/lib/apt/lists/* \ +# && pip install --no-cache-dir pytesseract +# +# After installing pip packages, register each backend so it appears in the registry: +# RUN mataserver pull easyocr --task ocr +# RUN mataserver pull paddleocr --task ocr +# RUN mataserver pull tesseract --task ocr + ENTRYPOINT ["mataserver", "serve"] diff --git a/README.md b/README.md index 27e32b1..5908df1 100644 --- a/README.md +++ b/README.md @@ -93,16 +93,16 @@ curl http://localhost:8110/v1/health The `mataserver` console script provides commands for server management and model operations. -| Command | Description | -| ------------------------------ | ---------------------------------------------- | -| `mataserver serve` | Start the inference server | -| `mataserver pull --task T` | Download and register a model from HuggingFace | -| `mataserver list` | List all registered models (alias: `ls`) | -| `mataserver show ` | Show detailed info for a model | -| `mataserver rm ` | Remove a model from the registry | -| `mataserver load ` | Preload a model into memory (alias: `warmup`) | -| `mataserver stop ` | Unload a model from memory | -| `mataserver version` | Print version (also: `mataserver -v`) | +| Command | Description | +| ------------------------------ | ------------------------------------------------------------------ | +| `mataserver serve` | Start the inference server | +| `mataserver pull --task T` | Download/install and register a model (HuggingFace or pip backend) | +| `mataserver list` | List all registered models (alias: `ls`) | +| `mataserver show ` | Show detailed info for a model | +| `mataserver rm ` | Remove a model from the registry | +| `mataserver load ` | Preload a model into memory 
(alias: `warmup`) | +| `mataserver stop ` | Unload a model from memory | +| `mataserver version` | Print version (also: `mataserver -v`) | For full usage details, argument references, and examples, see [docs/api.md](docs/api.md#cli). @@ -167,17 +167,53 @@ curl http://localhost:8110/v1/health { "status": "ok", "version": "0.1.0", "gpu_available": false } ``` -### Pull a model from HuggingFace +### Pull a model ```bash +# HuggingFace model curl -X POST http://localhost:8110/v1/models/pull \ -H "Authorization: Bearer your-api-key" \ -H "Content-Type: application/json" \ - -d '{"source": "hf://datamata/rtdetr-l"}' + -d '{"model": "datamata/rtdetr-l", "task": "detect"}' ``` ```json -{ "status": "pulling", "model": "datamata/rtdetr-l" } +{ "status": "pulled", "model": "datamata/rtdetr-l" } +``` + +```bash +# Pip-based OCR backend +curl -X POST http://localhost:8110/v1/models/pull \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{"model": "easyocr", "task": "ocr"}' +``` + +Or via the CLI: + +```bash +# HuggingFace Task Detection model (Example: RT-DETR ResNet-18 backbone) +mataserver pull PekingU/rtdetr_r18vd --task detect + +# HuggingFace Task Classification model (Example: ResNet-50) +mataserver pull microsoft/resnet-50 --task classify + +# HuggingFace Task Segmentation model (Example: Mask2Former Swin-Tiny trained on COCO) +mataserver pull facebook/mask2former-swin-tiny-coco-instance --task segment + +# HuggingFace Task Depth model (Example: Depth Anything V2 Small) +mataserver pull depth-anything/Depth-Anything-V2-Small-hf --task depth + +# HuggingFace Task Visual Language Model (VLM) +mataserver pull Qwen/Qwen3-VL-2B-Instruct --task vlm + +# HuggingFace OCR model +mataserver pull stepfun-ai/GOT-OCR-2.0-hf --task ocr + +# Pip-installed OCR backends +mataserver pull easyocr --task ocr +mataserver pull paddleocr --task ocr +mataserver pull tesseract --task ocr # requires tesseract system binary ``` ### Single-shot inference 
(base64 JSON) diff --git a/docs/api.md b/docs/api.md index 404f4bf..a216582 100644 --- a/docs/api.md +++ b/docs/api.md @@ -46,7 +46,10 @@ mataserver serve ### `mataserver pull` -Download a HuggingFace model into the local cache and register it with the server so it appears in `GET /v1/models` and can be used for inference. +Download a model and register it with the server so it appears in `GET /v1/models` and can be used for inference. Supports two backend types: + +- **HuggingFace models** — downloaded via `huggingface_hub.snapshot_download()` and stored in the standard HF cache (`~/.cache/huggingface`). These are identified by a `org/repo-name` slash-separated ID. +- **Pip-based OCR backends** — installed via `pip` into the current Python environment. These are identified by a short backend name (e.g. `easyocr`, `paddleocr`, `tesseract`). ```bash mataserver pull --task @@ -54,24 +57,35 @@ mataserver pull --task | Argument | Description | | ---------- | ----------------------------------------------------------------------------------------------- | -| `MODEL_ID` | HuggingFace repo ID, e.g. `facebook/detr-resnet-50` | +| `MODEL_ID` | HuggingFace repo ID (`org/name`) or pip backend name (`easyocr`, `paddleocr`, `tesseract`) | | `--task` | Inference task. One of: `classify`, `depth`, `detect`, `ocr`, `pose`, `segment`, `track`, `vlm` | -The model is downloaded via `huggingface_hub.snapshot_download()` and stored in the standard HuggingFace cache directory (`~/.cache/huggingface` by default, or the path set by `HF_HUB_CACHE`). The model ID and task are then written to `model_registry.json` in `MATA_SERVER_DATA_DIR`. +After a successful pull the model ID, task, and source type are written to `model_registry.json` in `MATA_SERVER_DATA_DIR`. 
**Examples**: ```bash -# Object detection +# Object detection (HuggingFace) mataserver pull facebook/detr-resnet-50 --task=detect -# Image classification +# Image classification (HuggingFace) mataserver pull google/vit-base-patch16-224 --task=classify -# Depth estimation +# Depth estimation (HuggingFace) mataserver pull LiheYoung/depth-anything-base-hf --task=depth + +# OCR — HuggingFace models +mataserver pull stepfun-ai/GOT-OCR-2.0-hf --task ocr +mataserver pull microsoft/trocr-base-printed --task ocr + +# OCR — pip-based backends +mataserver pull easyocr --task ocr +mataserver pull paddleocr --task ocr +mataserver pull tesseract --task ocr # also requires the tesseract system binary ``` +> **Pip OCR backends**: `easyocr`, `paddleocr`, and `tesseract` are installed as Python packages into the active virtual environment rather than downloaded from HuggingFace. `tesseract` additionally requires the `tesseract-ocr` system binary; if it is not found on `PATH` a warning is printed but the pull still succeeds. See [OCR Backends](#ocr-backends) for details. + **Exit codes**: | Code | Meaning | @@ -98,8 +112,9 @@ mataserver list | Column | Description | | ----------- | ---------------------------------------------------------------- | -| `MODEL` | HuggingFace repo ID | +| `MODEL` | HuggingFace repo ID or pip backend name | | `TASK` | Inference task (`detect`, `segment`, etc.) | +| `SOURCE` | `hf` for HuggingFace models, `pip` for pip-based backends | | `SIZE (MB)` | On-disk size in MB from the HuggingFace cache, or `—` if unknown | If no models are registered, prints `"No models registered."`. @@ -109,10 +124,12 @@ If no models are registered, prints `"No models registered."`. 
```bash mataserver list -MODEL TASK SIZE (MB) -------------------------------------------------------- -facebook/detr-resnet-50 detect 167.3 -google/vit-base-patch16-224 classify 327.5 +MODEL TASK SOURCE SIZE (MB) +-------------------------------------------------------------- +facebook/detr-resnet-50 detect hf 167.3 +google/vit-base-patch16-224 classify hf 327.5 +easyocr ocr pip — +tesseract ocr pip — ``` **Exit codes**: @@ -137,24 +154,58 @@ mataserver show **Output fields**: -| Field | Description | -| --------------- | ----------------------------- | -| `model` | HuggingFace repo ID | -| `task` | Registered inference task | -| `size` | On-disk size in MB, or `—` | -| `last_accessed` | Timestamp of last use, or `—` | +| Field | Description | +| --------------- | ------------------------------------------------------------------- | +| `model` | HuggingFace repo ID or pip backend name | +| `task` | Registered inference task | +| `source` | `hf` (HuggingFace) or `pip` (pip-based backend) | +| `size` | On-disk size in MB from HF cache, or `—` (pip models have no cache) | +| `last_accessed` | Timestamp of last HF cache access, or `—` (pip models) | +| `pip_packages` | _(pip only)_ Comma-separated list of installed pip packages | +| `installed` | _(pip only)_ `yes` / `no` — whether the package is importable | +| `system_binary` | _(pip only, if applicable)_ Binary name and whether it was found | -**Example**: +**Example — HuggingFace model**: ```bash mataserver show facebook/detr-resnet-50 model: facebook/detr-resnet-50 task: detect + source: hf size: 167.30 MB last_accessed: 2026-03-05 14:22:01 ``` +**Example — pip backend**: + +```bash +mataserver show easyocr + + model: easyocr + task: ocr + source: pip + size: — + last_accessed: — + pip_packages: easyocr + installed: yes +``` + +**Example — Tesseract (with system binary check)**: + +```bash +mataserver show tesseract + + model: tesseract + task: ocr + source: pip + size: — + last_accessed: — + pip_packages: 
pytesseract + installed: yes + system_binary: tesseract (yes) +``` + **Exit codes**: | Code | Meaning | @@ -176,7 +227,7 @@ mataserver rm | ---------- | ----------------------------------------------- | | `MODEL_ID` | HuggingFace repo ID to remove from the registry | -**Example**: +**Example — HuggingFace model**: ```bash mataserver rm facebook/detr-resnet-50 @@ -184,6 +235,14 @@ Removed 'facebook/detr-resnet-50' from the registry. Note: model weights on disk (HF cache) were not deleted. ``` +**Example — pip backend**: + +```bash +mataserver rm easyocr +Removed 'easyocr' from the registry. +Note: pip packages were not uninstalled. Remove manually if needed. +``` + **Exit codes**: | Code | Meaning | @@ -282,6 +341,59 @@ mataserver 0.6.0 --- +## OCR Backends + +The `ocr` task supports two categories of model backend: + +### HuggingFace OCR models + +These are pulled like any other model — weights are downloaded into the HuggingFace cache and `source` is recorded as `"hf"`: + +| Model ID | Notes | +| ---------------------------------- | -------------------------------------- | +| `stepfun-ai/GOT-OCR-2.0-hf` | GOT-OCR2 — general-purpose OCR | +| `microsoft/trocr-base-printed` | TrOCR — optimised for printed text | +| `microsoft/trocr-base-handwritten` | TrOCR — optimised for handwritten text | + +### Pip-based OCR backends + +These are installed as Python packages (and optionally require a system binary). `source` is recorded as `"pip"`. They do **not** occupy space in the HuggingFace cache. 
+ +| Backend name | Pip packages installed | System binary required | Notes | +| ------------ | --------------------------- | ---------------------- | --------------------------------------- | +| `easyocr` | `easyocr` | None | Supports 80+ languages | +| `paddleocr` | `paddlepaddle`, `paddleocr` | None | High accuracy; larger install | +| `tesseract` | `pytesseract` | `tesseract-ocr` | Requires system binary (see note below) | + +> **Tesseract system binary**: `mataserver pull tesseract --task ocr` installs the `pytesseract` Python wrapper but **not** the `tesseract-ocr` binary itself. Install it separately: +> +> - **Debian/Ubuntu**: `apt-get install -y tesseract-ocr` +> - **macOS**: `brew install tesseract` +> - **Windows**: Download from [UB-Mannheim/tesseract](https://github.com/UB-Mannheim/tesseract/wiki) +> +> If the binary is not found at pull time, a warning is printed but the backend is still registered. + +### Checking backend status + +```bash +mataserver show easyocr + source: pip + pip_packages: easyocr + installed: yes + +mataserver show tesseract + source: pip + pip_packages: pytesseract + installed: yes + system_binary: tesseract (yes) +``` + +### Removing a pip backend + +`mataserver rm ` removes the registration entry only. The pip packages are **not** uninstalled automatically; remove them manually with `pip uninstall ` if desired. + +--- + ## Authentication Most endpoints require a Bearer token in the `Authorization` header. @@ -381,24 +493,36 @@ List all models that are currently installed (i.e. 
present in the HuggingFace ca { "model": "PekingU/rtdetr_v2_r101vd", "task": "detect", + "source": "hf", "state": "idle", "size_mb": 421.0, "memory_mb": 512.0, "loaded_at": 1709550000.0, "last_used": 1709553600.0 + }, + { + "model": "easyocr", + "task": "ocr", + "source": "pip", + "state": "unloaded", + "size_mb": null, + "memory_mb": null, + "loaded_at": null, + "last_used": null } ] ``` -| Field | Type | Description | -| ----------- | -------------- | --------------------------------------------------- | -| `model` | string | HuggingFace repo ID (e.g. `"org/name"`) | -| `task` | string | Inference task (`detect`, `segment`, `classify`, …) | -| `state` | string | Current lifecycle state (see table below) | -| `size_mb` | number \| null | On-disk size in MB from the HuggingFace cache | -| `memory_mb` | number \| null | Allocated memory in MB (null when unloaded) | -| `loaded_at` | number \| null | Unix timestamp of when the model was loaded | -| `last_used` | number \| null | Unix timestamp of the most recent inference call | +| Field | Type | Description | +| ----------- | -------------- | ------------------------------------------------------------- | +| `model` | string | HuggingFace repo ID or pip backend name | +| `task` | string | Inference task (`detect`, `segment`, `classify`, …) | +| `source` | string | `"hf"` for HuggingFace models, `"pip"` for pip-based backends | +| `state` | string | Current lifecycle state (see table below) | +| `size_mb` | number \| null | On-disk size in MB from the HF cache; `null` for pip backends | +| `memory_mb` | number \| null | Allocated memory in MB (null when unloaded) | +| `loaded_at` | number \| null | Unix timestamp of when the model was loaded | +| `last_used` | number \| null | Unix timestamp of the most recent inference call | Model `state` values: @@ -444,7 +568,7 @@ curl -H "Authorization: Bearer $KEY" \ ### `POST /v1/models/pull` -Download a model from HuggingFace into the default HuggingFace cache 
(`~/.cache/huggingface`) and register it with the server. The operation runs asynchronously; the 202 response is returned once the download completes. +Download or install a model and register it with the server. Supports both HuggingFace models (downloaded into `~/.cache/huggingface`) and pip-based OCR backends (installed into the active Python environment). The 202 response is returned once the operation completes. **Request body**: @@ -455,10 +579,10 @@ Download a model from HuggingFace into the default HuggingFace cache (`~/.cache/ } ``` -| Field | Type | Description | -| ------- | ------ | --------------------------------------------- | -| `model` | string | HuggingFace repo ID (e.g. `"org/model-name"`) | -| `task` | string | Inference task (`detect`, `segment`, …) | +| Field | Type | Description | +| ------- | ------ | ----------------------------------------------------------------------------- | +| `model` | string | HuggingFace repo ID (`"org/model-name"`) or pip backend name (`"easyocr"`, …) | +| `task` | string | Inference task (`detect`, `segment`, `ocr`, …) | **Response `202 Accepted`**: @@ -468,19 +592,26 @@ Download a model from HuggingFace into the default HuggingFace cache (`~/.cache/ **Error responses**: -| Code | Condition | -| ---- | ----------------------------------------------------------- | -| 400 | Pull failed (network error, model not found on HuggingFace) | -| 409 | A pull for the same model is already in progress | -| 500 | Unexpected server error | +| Code | Condition | +| ---- | ------------------------------------------------------------------------------ | +| 400 | Pull failed (network error, model not found, task mismatch, pip install error) | +| 409 | A pull for the same model is already in progress | +| 500 | Unexpected server error | -**Example**: +**Examples**: ```bash +# HuggingFace model curl -X POST http://localhost:8110/v1/models/pull \ -H "Authorization: Bearer $KEY" \ -H "Content-Type: application/json" \ -d '{"model": 
"PekingU/rtdetr_v2_r101vd", "task": "detect"}' + +# Pip-based OCR backend +curl -X POST http://localhost:8110/v1/models/pull \ + -H "Authorization: Bearer $KEY" \ + -H "Content-Type: application/json" \ + -d '{"model": "easyocr", "task": "ocr"}' ``` --- diff --git a/docs/deployment.md b/docs/deployment.md index 0d50712..31a7005 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -14,6 +14,7 @@ This guide covers production deployment of MATASERVER using Docker, Docker Compo 6. [Environment Variable Configuration](#6-environment-variable-configuration) 7. [Health Check & Monitoring](#7-health-check--monitoring) 8. [Data Directory & Volume Management](#8-data-directory--volume-management) +9. [Pre-installing OCR Backends](#9-pre-installing-ocr-backends) --- @@ -491,12 +492,123 @@ docker run --rm \ ### Disk Space Planning -| Item | Typical Size | -| ---------------------------------- | ----------------------------- | -| Server image | ~2-2.5 GB | -| Small vision model (ONNX) | 20-100 MB | -| Large VLM | 2-10 GB | -| Blob cache (HuggingFace downloads) | Varies; can be cleared safely | +| Item | Typical Size | +| ----------------------------------- | ----------------------------- | +| Server image | ~2-2.5 GB | +| Small vision model (ONNX) | 20-100 MB | +| Large VLM | 2-10 GB | +| Blob cache (HuggingFace downloads) | Varies; can be cleared safely | +| EasyOCR weights (auto-downloaded) | ~500 MB | +| PaddleOCR weights (auto-downloaded) | ~100-500 MB | + +--- + +## 9. Pre-installing OCR Backends + +MATASERVER supports three pip-based OCR backends — `easyocr`, `paddleocr`, and `tesseract` — in addition to HuggingFace OCR models such as `stepfun-ai/GOT-OCR-2.0-hf` and `microsoft/trocr-base-printed`. + +Pip backends are installed at runtime via `mataserver pull --task ocr`, which runs `pip install` inside the container. In production container images this is suboptimal: + +- It adds latency on the first `pull` request (pip resolves and downloads packages). 
+- It requires outbound internet access from the container at runtime. +- Packages installed at runtime are not captured in an image layer, so they are lost whenever the container is recreated. + +The recommended approach is to **pre-bake** the pip packages into the image at build time. + +### Supported OCR Backends + +| Backend | Source | pip packages | System dependency | +| ----------- | ------ | ------------------------ | --------------------- | +| `easyocr` | pip | `easyocr` | none | +| `paddleocr` | pip | `paddlepaddle paddleocr` | none | +| `tesseract` | pip | `pytesseract` | `tesseract-ocr` (apt) | + +HuggingFace OCR models use the standard `mataserver pull --task ocr` path and are not affected by this section. + +### Pre-baking OCR Backends into a Docker Image + +Extend the runtime stage of the Dockerfile with the backends you need: + +```dockerfile +# ── EasyOCR ────────────────────────────────────────────────────────────────── +RUN pip install --no-cache-dir easyocr +RUN mataserver pull easyocr --task ocr + +# ── PaddleOCR ──────────────────────────────────────────────────────────────── +RUN pip install --no-cache-dir paddlepaddle paddleocr +RUN mataserver pull paddleocr --task ocr + +# ── Tesseract ──────────────────────────────────────────────────────────────── +# Step 1: install system binary (apt) +RUN apt-get update && apt-get install -y --no-install-recommends tesseract-ocr \ + && rm -rf /var/lib/apt/lists/* +# Step 2: install Python bindings +RUN pip install --no-cache-dir pytesseract +# Step 3: register the backend +RUN mataserver pull tesseract --task ocr +``` + +The `mataserver pull` commands register each backend in the model registry baked into the image. The container starts with the backends already available — no internet access is needed at runtime. + +> **Note**: EasyOCR and PaddleOCR download their model weights on first inference (not during `pull`). If you need fully air-gapped operation, pre-download the weights by running a warmup inference during the build stage.
+ +### Dockerfile: Commented Pre-install Block + +For reference, see the commented block near the end of the provided `Dockerfile`: + +```dockerfile +# Optional: Pre-install OCR backends +# Uncomment the backends you need: +# RUN pip install --no-cache-dir easyocr +# RUN pip install --no-cache-dir paddlepaddle paddleocr +# RUN apt-get update && apt-get install -y --no-install-recommends tesseract-ocr \ +# && rm -rf /var/lib/apt/lists/* \ +# && pip install --no-cache-dir pytesseract +``` + +### Tesseract System Dependency + +Tesseract requires the `tesseract-ocr` system binary. Install it: + +**Ubuntu / Debian (bare metal or inside a Dockerfile)**: + +```bash +sudo apt-get update && sudo apt-get install -y tesseract-ocr +``` + +**Alpine Linux**: + +```bash +apk add --no-cache tesseract-ocr +``` + +The Python binding `pytesseract` is a thin wrapper around the binary — it will fail at inference time if `tesseract` is not on `PATH`. + +After installing the system binary, register the backend: + +```bash +mataserver pull tesseract --task ocr +``` + +If the binary is not found when `mataserver pull tesseract` runs, a warning is logged but the pull succeeds (the system binary is checked again at inference time by the MATA adapter). 
+ +### Verifying the Pre-installed Backends + +After building your custom image, confirm the backends are registered: + +```bash +docker run --rm mataserver-custom mataserver list +``` + +Expected output: + +``` +MODEL TASK SOURCE SIZE (MB) +----------------------------------------- +easyocr ocr pip — +paddleocr ocr pip — +tesseract ocr pip — +``` Monitor available space: diff --git a/examples/images/a01-122-02.jpg b/examples/images/a01-122-02.jpg new file mode 100644 index 0000000..ce1e481 Binary files /dev/null and b/examples/images/a01-122-02.jpg differ diff --git a/examples/rest_infer.py b/examples/rest_infer.py index 24aaa2d..6db380b 100644 --- a/examples/rest_infer.py +++ b/examples/rest_infer.py @@ -8,6 +8,9 @@ python examples/rest_infer.py --image examples/images/coco_cat_remote.jpg --task detect --model PekingU/rtdetr_r18vd python examples/rest_infer.py --image examples/images/coco_cat_remote.jpg --task classify --model google/efficientnet-b0 python examples/rest_infer.py --image examples/images/coco_cat_remote.jpg --task segment --model facebook/mask2former-swin-tiny-coco-instance + python examples/rest_infer.py --image examples/images/coco_cat_remote.jpg --task depth --model depth-anything/Depth-Anything-V2-Small-hf + + python examples/rest_infer.py --image examples/images/a01-122-02.jpg --task ocr --model easyocr # Zero-shot detection with text prompts python examples/rest_infer.py --image examples/images/coco_cat_remote.jpg --task detect --model google/owlv2-base-patch16-ensemble --prompts "cat,remote,dog,car" @@ -15,8 +18,7 @@ """ import argparse -import base64 -import json +import base64 from pathlib import Path import requests @@ -87,17 +89,30 @@ def main(image_path: Path, task: str | None, model: str | None, prompts: str | N for s in segments[:5]: print(f" [{s['confidence']:.2f}] {s['label']} bbox={s['bbox']}") + elif t == "depth": + depths = result.get("depth_map") or [] + print(f"Depth map shape: {depths}") + + elif t == "ocr": + ocr_result = 
result.get("text") or [] + print(f" OCR text results: {ocr_result}") + print(f" Full response keys: {list(result.keys())}") if __name__ == "__main__": - parser = argparse.ArgumentParser(description="REST inference — detect / classify / segment") + parser = argparse.ArgumentParser(description="REST inference — detect / classify / segment / depth / ocr") parser.add_argument("--image", type=Path, required=True, help="Path to an image file") - parser.add_argument("--task", choices=["detect", "classify", "segment"], default=None, - help="Run a single task (default: run all three)") + parser.add_argument( + "--task", + choices=["detect", "classify", "segment", "depth", "ocr"], + default=None, + help="Run a single task (default: run all three)", + ) parser.add_argument("--model", default=None, help="Override the model ref") - parser.add_argument("--prompts", default=None, - help="Comma-separated text prompts for zero-shot tasks") + parser.add_argument( + "--prompts", default=None, help="Comma-separated text prompts for zero-shot tasks" + ) parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", type=int, default=8110) args = parser.parse_args() diff --git a/examples/rest_vlm.py b/examples/rest_vlm.py index cbabd5b..2b32b46 100644 --- a/examples/rest_vlm.py +++ b/examples/rest_vlm.py @@ -2,6 +2,7 @@ Usage: python examples/rest_vlm.py --image examples/images/coco_cat_remote.jpg --prompt "What do you see?" + python examples/rest_vlm.py --image examples/images/coco_cat_remote.jpg --prompt "List all objects you can identify in this image." 
--output_mode "detect" # With extra generation options python examples/rest_vlm.py --image examples/images/coco_cat_remote.jpg \ @@ -16,7 +17,7 @@ import requests BASE_URL = "http://127.0.0.1:8110/v1" -DEFAULT_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct" +DEFAULT_MODEL = "Qwen/Qwen3-VL-2B-Instruct" def infer_vlm( @@ -25,6 +26,7 @@ def infer_vlm( prompt: str, max_tokens: int | None = None, temperature: float | None = None, + output_mode: str | None = None, ) -> dict: """POST /v1/infer for a VLM task and return the parsed response.""" image_b64 = base64.b64encode(image_path.read_bytes()).decode() @@ -34,6 +36,8 @@ def infer_vlm( params["max_tokens"] = max_tokens if temperature is not None: params["temperature"] = temperature + if output_mode is not None: + params["output_mode"] = output_mode payload = { "model": model, @@ -53,13 +57,14 @@ def main( prompt: str, max_tokens: int | None, temperature: float | None, + output_mode: str, ) -> None: - print(f"\n--- VLM Inference ---") + print("\n--- VLM Inference ---") print(f" Model : {model}") print(f" Prompt : {prompt!r}") - + print(f" Output Mode : {output_mode}") try: - result = infer_vlm(model, image_path, prompt, max_tokens, temperature) + result = infer_vlm(model, image_path, prompt, max_tokens, temperature, output_mode) except requests.HTTPError as exc: print(f" ERROR {exc.response.status_code}: {exc.response.text}") return @@ -75,6 +80,7 @@ def main( parser.add_argument("--model", default=DEFAULT_MODEL, help="VLM model ref") parser.add_argument("--max-tokens", type=int, default=None, help="Max tokens to generate") parser.add_argument("--temperature", type=float, default=None, help="Sampling temperature") + parser.add_argument("--output_mode", default="text", help="Output mode of the VLM") parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", type=int, default=8110) args = parser.parse_args() @@ -85,4 +91,4 @@ def main( print(f"ERROR: image not found: {args.image}") raise SystemExit(1) - 
main(args.image, args.model, args.prompt, args.max_tokens, args.temperature) + main(args.image, args.model, args.prompt, args.max_tokens, args.temperature, args.output_mode) diff --git a/examples/ws_video_infer.py b/examples/ws_video_infer.py index 77970d6..cdbb910 100644 --- a/examples/ws_video_infer.py +++ b/examples/ws_video_infer.py @@ -30,7 +30,7 @@ # --------------------------------------------------------------------------- # Binary frame wire format (must match mataserver/streaming/protocol.py) # --------------------------------------------------------------------------- -HEADER_FORMAT = ">IdB" # big-endian: uint32 frame_id, float64 timestamp, uint8 encoding +HEADER_FORMAT = ">IdB" # big-endian: uint32 frame_id, float64 timestamp, uint8 encoding HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 13 bytes ENCODING_JPEG = 0 @@ -62,7 +62,9 @@ async def stream_video( headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} # 1. Create a streaming session - print(f"\n[1/3] Creating session model={model!r} task={task!r} frame_policy={frame_policy!r}") + print( + f"\n[1/3] Creating session model={model!r} task={task!r} frame_policy={frame_policy!r}" + ) async with aiohttp.ClientSession(headers=headers) as http: async with http.post( f"{base_url}/v1/sessions", @@ -98,7 +100,6 @@ async def stream_video( ws_session = aiohttp.ClientSession() try: async with ws_session.ws_connect(ws_url) as ws: - # Background task: receive and print inference results async def receive_loop() -> None: nonlocal received, drops, errors @@ -109,7 +110,9 @@ async def receive_loop() -> None: drops += 1 continue if "error" in payload: - print(f" [frame {payload.get('frame_id', '?')}] ERROR: {payload['error']}") + print( + f" [frame {payload.get('frame_id', '?')}] ERROR: {payload['error']}" + ) errors += 1 else: fid = payload.get("frame_id", "?") @@ -155,7 +158,7 @@ async def receive_loop() -> None: except asyncio.CancelledError: pass - print(f"\n Sent : {sent} frames in {elapsed:.2f}s 
({sent/elapsed:.1f} fps)") + print(f"\n Sent : {sent} frames in {elapsed:.2f}s ({sent / elapsed:.1f} fps)") print(f" Received: {received} results | {drops} dropped | {errors} errors") finally: cap.release() @@ -187,7 +190,9 @@ async def receive_loop() -> None: parser.add_argument("--fps-limit", type=float, default=0, help="Max send fps (0 = native)") parser.add_argument("--api-key", default=None, help="API key (omit if auth_mode=none)") parser.add_argument( - "--frame-policy", choices=["latest", "queue"], default="latest", + "--frame-policy", + choices=["latest", "queue"], + default="latest", help="Frame-handling policy (default: latest)", ) parser.add_argument("--host", default="127.0.0.1") @@ -198,14 +203,16 @@ async def receive_loop() -> None: print(f"ERROR: video not found: {args.video}") raise SystemExit(1) - asyncio.run(stream_video( - host=args.host, - port=args.port, - model=args.model, - task=args.task, - video_path=args.video, - max_frames=args.max_frames, - fps_limit=args.fps_limit, - frame_policy=args.frame_policy, - api_key=args.api_key, - )) + asyncio.run( + stream_video( + host=args.host, + port=args.port, + model=args.model, + task=args.task, + video_path=args.video, + max_frames=args.max_frames, + fps_limit=args.fps_limit, + frame_policy=args.frame_policy, + api_key=args.api_key, + ) + ) diff --git a/mataserver/api/v1/models.py b/mataserver/api/v1/models.py index 713c77c..ae8d882 100644 --- a/mataserver/api/v1/models.py +++ b/mataserver/api/v1/models.py @@ -6,6 +6,8 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from mataserver.api.deps import verify_api_key +from mataserver.core import pip_installer +from mataserver.core.backend_catalog import lookup as catalog_lookup from mataserver.core.pull import download_model from mataserver.schemas.requests import ( PullRequest, @@ -44,13 +46,13 @@ async def get_model(model_id: str, request: Request): @router.post("/models/pull", status_code=status.HTTP_202_ACCEPTED) async def 
pull_model(body: PullRequest, request: Request): - """Download a model from HuggingFace into the local cache and register it. + """Download or install a model and register it. - The model is downloaded via ``huggingface_hub.snapshot_download()`` which - stores weights in the standard HuggingFace cache (``~/.cache/huggingface`` - by default, or the path set by ``HF_HUB_CACHE``). The model ID and task - are then persisted in the server's registry so the model appears in - ``GET /v1/models`` and can be loaded for inference. + For HuggingFace models, weights are downloaded via + ``huggingface_hub.snapshot_download()``. For cataloged pip backends + (e.g. ``easyocr``, ``paddleocr``, ``tesseract``), the required packages + are installed via pip. In both cases the model ID and task are persisted + in the registry so the model appears in ``GET /v1/models``. """ if body.model in _pulling: raise HTTPException( @@ -60,17 +62,41 @@ async def pull_model(body: PullRequest, request: Request): _pulling.add(body.model) try: - try: - await asyncio.to_thread(download_model, body.model) - except Exception as exc: - logger.exception("Failed to download model %s", body.model) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Download failed for {body.model!r}: {exc}", - ) from exc + catalog_entry = catalog_lookup(body.model) + + if catalog_entry is not None: + # Pip-based backend + if body.task != catalog_entry.task: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"{catalog_entry.display_name} only supports task " + f"{catalog_entry.task!r}, got {body.task!r}" + ), + ) + try: + await asyncio.to_thread(pip_installer.install_packages, catalog_entry.pip_packages) + except Exception as exc: + logger.exception("Failed to install %s", body.model) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Install failed for {body.model!r}: {exc}", + ) from exc + source = "pip" + else: + # HuggingFace model — existing 
path + try: + await asyncio.to_thread(download_model, body.model) + except Exception as exc: + logger.exception("Failed to download model %s", body.model) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Download failed for {body.model!r}: {exc}", + ) from exc + source = "hf" registry = request.app.state.registry - await registry.register(body.model, body.task) + await registry.register(body.model, body.task, source=source) finally: _pulling.discard(body.model) diff --git a/mataserver/core/backend_catalog.py b/mataserver/core/backend_catalog.py new file mode 100644 index 0000000..eac293a --- /dev/null +++ b/mataserver/core/backend_catalog.py @@ -0,0 +1,63 @@ +"""Static catalog of non-HuggingFace OCR backends. + +Maps short names (e.g. "easyocr") to pip installation metadata. +Models NOT in this catalog are assumed to be HuggingFace repos +and follow the existing snapshot_download() path. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class CatalogEntry: + """Installation metadata for a pip-based backend.""" + + task: str + pip_packages: tuple[str, ...] 
+ verify_import: str + display_name: str + system_binary: str | None = None + estimated_memory_mb: float = 256.0 + + +_CATALOG: dict[str, CatalogEntry] = { + "easyocr": CatalogEntry( + task="ocr", + pip_packages=("easyocr",), + verify_import="easyocr", + display_name="EasyOCR", + estimated_memory_mb=300.0, + ), + "paddleocr": CatalogEntry( + task="ocr", + pip_packages=("paddlepaddle", "paddleocr"), + verify_import="paddleocr", + display_name="PaddleOCR", + estimated_memory_mb=400.0, + ), + "tesseract": CatalogEntry( + task="ocr", + pip_packages=("pytesseract",), + verify_import="pytesseract", + display_name="Tesseract OCR", + system_binary="tesseract", + estimated_memory_mb=100.0, + ), +} + + +def lookup(model: str) -> CatalogEntry | None: + """Return catalog entry if *model* is a known pip backend, else ``None``.""" + return _CATALOG.get(model) + + +def is_cataloged(model: str) -> bool: + """Return ``True`` if *model* is a known pip backend.""" + return model in _CATALOG + + +def get_source_type(model: str) -> str: + """Return ``"pip"`` if *model* is a cataloged backend, else ``"hf"``.""" + return "pip" if model in _CATALOG else "hf" diff --git a/mataserver/core/model_loader.py b/mataserver/core/model_loader.py index 8706017..fb4c0cd 100644 --- a/mataserver/core/model_loader.py +++ b/mataserver/core/model_loader.py @@ -17,13 +17,20 @@ def _estimate_memory_mb(model: str) -> float: - """Estimate model memory (MB) from its HuggingFace cache entry. + """Estimate model memory (MB) from catalog or HuggingFace cache entry. - Uses ``huggingface_hub.scan_cache_dir()`` to look up the on-disk size of - the model and applies a 1.2× overhead multiplier to account for runtime - buffers. Falls back to ``_DEFAULT_MEMORY_MB`` when the model is not yet - in the HF cache or the library is unavailable. + For pip-backed backends (e.g. easyocr, tesseract), returns the + ``estimated_memory_mb`` from the backend catalog. 
For HuggingFace models, + uses ``huggingface_hub.scan_cache_dir()`` and applies a 1.2× overhead + multiplier. Falls back to ``_DEFAULT_MEMORY_MB`` when the model is not + found in either source. """ + from mataserver.core.backend_catalog import lookup # noqa: PLC0415 + + catalog_entry = lookup(model) + if catalog_entry is not None: + return catalog_entry.estimated_memory_mb + try: info = scan_cache_dir() for repo in info.repos: @@ -92,7 +99,11 @@ async def load( # For segment task, request polygon output from the adapter so # callers receive plain coordinates instead of opaque RLE blobs. - adapter_kwargs: dict = {"device": device} + # Pip-backed OCR backends (e.g. easyocr) don't accept a 'device' + # kwarg — their Reader.__init__() uses gpu=True/False internally. + from mataserver.core.backend_catalog import lookup as _catalog_lookup # noqa: PLC0415 + + adapter_kwargs: dict = {} if _catalog_lookup(model) is not None else {"device": device} if task == "segment": adapter_kwargs["use_polygon"] = True diff --git a/mataserver/core/models.py b/mataserver/core/models.py index d136a4a..6fdea27 100644 --- a/mataserver/core/models.py +++ b/mataserver/core/models.py @@ -103,12 +103,34 @@ async def _show_model(model: str, data_dir: Path) -> dict[str, Any] | None: if reg_entry is None: return None - cache_map = ModelRegistry._get_hf_cache_map() - repo = cache_map.get(model) - - return { + source = reg_entry.get("source", "hf") + info: dict[str, Any] = { "model": model, "task": reg_entry["task"], - "size_mb": round(repo.size_on_disk / (1024 * 1024), 2) if repo else None, - "last_accessed": repo.last_accessed if repo else None, + "source": source, } + + if source == "hf": + cache_map = ModelRegistry._get_hf_cache_map() + repo = cache_map.get(model) + info["size_mb"] = round(repo.size_on_disk / (1024 * 1024), 2) if repo else None + info["last_accessed"] = repo.last_accessed if repo else None + else: + # Pip-backed: enrich with catalog info + from mataserver.core.backend_catalog 
import lookup # noqa: PLC0415 + from mataserver.core.pip_installer import ( # noqa: PLC0415 + check_system_binary, + verify_import, + ) + + entry = lookup(model) + info["size_mb"] = None + info["last_accessed"] = None + if entry: + info["pip_packages"] = list(entry.pip_packages) + info["installed"] = verify_import(entry.verify_import) + if entry.system_binary: + info["system_binary"] = entry.system_binary + info["binary_found"] = check_system_binary(entry.system_binary) + + return info diff --git a/mataserver/core/pip_installer.py b/mataserver/core/pip_installer.py new file mode 100644 index 0000000..88f9f02 --- /dev/null +++ b/mataserver/core/pip_installer.py @@ -0,0 +1,46 @@ +"""Pip package installer for non-HuggingFace backends. + +Security: Only installs packages explicitly listed in the backend catalog. +Never passes raw user input to pip. +""" + +import importlib.util +import logging +import shutil +import subprocess +import sys + +logger = logging.getLogger(__name__) + + +def install_packages(packages: tuple[str, ...]) -> None: + """Install pip packages into the current Python environment. + + Runs ``sys.executable -m pip install `` via subprocess. + Raises ``RuntimeError`` on non-zero exit code. + + Args: + packages: Tuple of package names from the backend catalog. 
+ """ + cmd = [sys.executable, "-m", "pip", "install", *packages] + logger.info("Installing packages: %s", ", ".join(packages)) + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0: + logger.error("pip install failed:\n%s", result.stderr) + raise RuntimeError(f"pip install failed for {', '.join(packages)}: {result.stderr.strip()}") + logger.info("Successfully installed: %s", ", ".join(packages)) + + +def verify_import(module_name: str) -> bool: + """Return ``True`` if *module_name* is importable.""" + return importlib.util.find_spec(module_name) is not None + + +def check_system_binary(name: str) -> bool: + """Return ``True`` if *name* is found on the system PATH.""" + return shutil.which(name) is not None diff --git a/mataserver/core/pull.py b/mataserver/core/pull.py index 41ae4e5..dcbe2d7 100644 --- a/mataserver/core/pull.py +++ b/mataserver/core/pull.py @@ -19,6 +19,8 @@ from huggingface_hub import snapshot_download +from mataserver.core import pip_installer +from mataserver.core.backend_catalog import lookup as catalog_lookup from mataserver.models.registry import ModelRegistry from mataserver.schemas.requests import SUPPORTED_TASKS @@ -70,13 +72,37 @@ def pull_model(model: str, task: str, data_dir: Path) -> None: f"Unsupported task {task!r}. 
Choose from: {', '.join(sorted(SUPPORTED_TASKS))}" ) - download_model(model) - asyncio.run(_register(model, task, data_dir)) - logger.info("Registered %s (task=%s)", model, task) - - -async def _register(model: str, task: str, data_dir: Path) -> None: - """Load the registry from *data_dir* and persist *model* → *task*.""" + catalog_entry = catalog_lookup(model) + if catalog_entry is not None: + if task != catalog_entry.task: + raise ValueError( + f"{catalog_entry.display_name} only supports task {catalog_entry.task!r}, " + f"got {task!r}" + ) + pip_installer.install_packages(catalog_entry.pip_packages) + if not pip_installer.verify_import(catalog_entry.verify_import): + raise RuntimeError( + f"Installation succeeded but {catalog_entry.verify_import!r} is not importable" + ) + if catalog_entry.system_binary: + if not pip_installer.check_system_binary(catalog_entry.system_binary): + logger.warning( + "System binary %r not found on PATH — %s may fail at runtime. " + "Install it separately (e.g. 
apt install tesseract-ocr).", + catalog_entry.system_binary, + catalog_entry.display_name, + ) + source = "pip" + else: + download_model(model) + source = "hf" + + asyncio.run(_register(model, task, data_dir, source=source)) + logger.info("Registered %s (task=%s, source=%s)", model, task, source) + + +async def _register(model: str, task: str, data_dir: Path, source: str = "hf") -> None: + """Load the registry from *data_dir* and persist *model* → *task* → *source*.""" registry = ModelRegistry(data_dir=data_dir) await registry.scan() - await registry.register(model, task) + await registry.register(model, task, source=source) diff --git a/mataserver/main.py b/mataserver/main.py index 63fb245..d7a2669 100644 --- a/mataserver/main.py +++ b/mataserver/main.py @@ -316,31 +316,44 @@ def _cmd_list() -> None: # Column widths col_model = max(len("MODEL"), max(len(m["model"]) for m in models)) col_task = max(len("TASK"), max(len(m["task"]) for m in models)) + col_source = max(len("SOURCE"), max(len(m.get("source", "hf")) for m in models)) col_size = len("SIZE (MB)") - header = f"{'MODEL':<{col_model}} {'TASK':<{col_task}} {'SIZE (MB)':>{col_size}}" + header = ( + f"{'MODEL':<{col_model}} {'TASK':<{col_task}} " + f"{'SOURCE':<{col_source}} {'SIZE (MB)':>{col_size}}" + ) separator = "-" * len(header) print(header) print(separator) for m in models: + source = m.get("source", "hf") size = f"{m['size_mb']:.1f}" if m.get("size_mb") is not None else "—" - print(f"{m['model']:<{col_model}} {m['task']:<{col_task}} {size:>{col_size}}") + print( + f"{m['model']:<{col_model}} {m['task']:<{col_task}} " + f"{source:<{col_source}} {size:>{col_size}}" + ) def _cmd_rm(args: argparse.Namespace) -> None: """Remove a model from the registry.""" - from mataserver.core.models import remove_model # noqa: PLC0415 + from mataserver.core.models import remove_model, show_model # noqa: PLC0415 settings: Settings = load_settings() settings.ensure_directories() + info = show_model(model=args.model, 
data_dir=settings.data_dir) removed = remove_model(model=args.model, data_dir=settings.data_dir) if not removed: print(f"error: model not found: {args.model}", file=sys.stderr) sys.exit(1) print(f"Removed {args.model!r} from the registry.") - print("Note: model weights on disk (HF cache) were not deleted.") + source = info.get("source", "hf") if info else "hf" + if source == "pip": + print("Note: pip packages were not uninstalled. Remove manually if needed.") + else: + print("Note: model weights on disk (HF cache) were not deleted.") def _cmd_show(args: argparse.Namespace) -> None: @@ -365,8 +378,16 @@ def _cmd_show(args: argparse.Namespace) -> None: last_accessed = "—" print(f" model: {info['model']}") print(f" task: {info['task']}") + print(f" source: {info.get('source', 'hf')}") print(f" size: {size_str}") print(f" last_accessed: {last_accessed}") + if info.get("pip_packages"): + print(f" pip_packages: {', '.join(info['pip_packages'])}") + if "installed" in info: + print(f" installed: {'yes' if info['installed'] else 'no'}") + if "system_binary" in info: + found = "yes" if info.get("binary_found") else "NO — install separately" + print(f" system_binary: {info['system_binary']} ({found})") def _build_server_url(args: argparse.Namespace) -> str: diff --git a/mataserver/models/registry.py b/mataserver/models/registry.py index f26c7b7..599e6da 100644 --- a/mataserver/models/registry.py +++ b/mataserver/models/registry.py @@ -18,16 +18,19 @@ class ModelRegistry: - """Persistent registry mapping HuggingFace model IDs to inference tasks. + """Persistent registry mapping model IDs to task and source. - The registry file is a simple JSON dict: - ``{"PekingU/rtdetr_v2_r101vd": "detect", ...}`` + The registry file is a JSON dict-of-dicts:: + + {"PekingU/rtdetr_v2_r101vd": {"task": "detect", "source": "hf"}, ...} + + Old flat-format files (``{"model": "task"}``) are auto-migrated on read. 
""" def __init__(self, data_dir: Path) -> None: self._data_dir = Path(data_dir) self._registry_file = self._data_dir / _REGISTRY_FILENAME - self._models: dict[str, str] = {} # model_id -> task + self._models: dict[str, dict[str, str]] = {} # model_id -> {task, source} self._lock = asyncio.Lock() # ------------------------------------------------------------------ @@ -44,15 +47,15 @@ async def scan(self) -> None: # CRUD # ------------------------------------------------------------------ - async def register(self, model: str, task: str) -> None: - """Register (or update) a model → task mapping and persist to disk.""" + async def register(self, model: str, task: str, source: str = "hf") -> None: + """Register (or update) a model → task/source mapping and persist to disk.""" async with self._lock: - self._models[model] = task + self._models[model] = {"task": task, "source": source} self._save_to_disk() - logger.info("Registered model: %s (task=%s)", model, task) + logger.info("Registered model: %s (task=%s, source=%s)", model, task, source) async def get(self, model: str) -> dict[str, str] | None: - """Return ``{"model": ..., "task": ...}`` for *model*, or ``None``. + """Return ``{"model": ..., "task": ..., "source": ...}`` for *model*, or ``None``. On a cache miss, re-reads from disk so that models registered via the CLI (a separate process that writes only to the JSON file) are @@ -60,15 +63,19 @@ async def get(self, model: str) -> dict[str, str] | None: call first. 
""" async with self._lock: - task = self._models.get(model) - if task is None: + entry = self._models.get(model) + if entry is None: refreshed = self._load_from_disk() if model in refreshed: self._models = refreshed - task = refreshed[model] - if task is None: + entry = refreshed[model] + if entry is None: return None - return {"model": model, "task": task} + return { + "model": model, + "task": entry["task"], + "source": entry.get("source", "hf"), + } async def remove(self, model: str) -> bool: """Remove a model from the registry. @@ -101,12 +108,13 @@ async def list_models(self) -> list[dict[str, Any]]: cache_map = self._get_hf_cache_map() result: list[dict[str, Any]] = [] - for model_id, task in models_copy.items(): + for model_id, entry in models_copy.items(): repo = cache_map.get(model_id) result.append( { "model": model_id, - "task": task, + "task": entry["task"], + "source": entry.get("source", "hf"), "size_mb": (repo.size_on_disk / (1024 * 1024)) if repo else None, "last_accessed": repo.last_accessed if repo else None, } @@ -122,8 +130,8 @@ async def is_cached(self, model: str) -> bool: # Private helpers # ------------------------------------------------------------------ - def _load_from_disk(self) -> dict[str, str]: - """Read the registry JSON file; return empty dict if missing/corrupt.""" + def _load_from_disk(self) -> dict[str, dict[str, str]]: + """Read registry JSON; auto-migrate flat format to dict-of-dicts.""" if not self._registry_file.exists(): return {} try: @@ -132,7 +140,19 @@ def _load_from_disk(self) -> dict[str, str]: if not isinstance(data, dict): logger.warning("Registry file has unexpected format, resetting") return {} - return {str(k): str(v) for k, v in data.items()} + migrated: dict[str, dict[str, str]] = {} + for k, v in data.items(): + if isinstance(v, str): + # Old flat format: {"model_id": "task"} → migrate + migrated[str(k)] = {"task": v, "source": "hf"} + elif isinstance(v, dict) and "task" in v: + migrated[str(k)] = { + "task": 
str(v["task"]), + "source": str(v.get("source", "hf")), + } + else: + logger.warning("Skipping malformed registry entry: %s", k) + return migrated except Exception as exc: logger.warning("Failed to read registry file %s: %s", self._registry_file, exc) return {} diff --git a/pyproject.toml b/pyproject.toml index f1273b6..9d90256 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,9 @@ target-version = "py310" [tool.ruff.lint] select = ["E", "F", "I", "N", "W", "UP"] +[tool.ruff.lint.per-file-ignores] +"examples/**" = ["E501"] + [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] diff --git a/tests/test_api/test_models.py b/tests/test_api/test_models.py index 6721403..f664340 100644 --- a/tests/test_api/test_models.py +++ b/tests/test_api/test_models.py @@ -356,3 +356,81 @@ def test_warmup_unexpected_error_returns_500(self, models_setup) -> None: mock_runtime.warmup.side_effect = RuntimeError("GPU exploded") resp = client.post("/v1/models/warmup", json={"model": "test/model"}) assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# POST /v1/models/pull — pip backend dispatch +# --------------------------------------------------------------------------- + + +class TestPullModelPipBackend: + """POST /v1/models/pull dispatches to pip_installer for cataloged backends.""" + + def test_pip_pull_returns_202(self, models_setup) -> None: + client, _, _, mock_registry = models_setup + mock_registry.register.return_value = None + with patch("mataserver.api.v1.models.pip_installer") as mock_pip: + mock_pip.install_packages.return_value = None + resp = client.post("/v1/models/pull", json={"model": "easyocr", "task": "ocr"}) + assert resp.status_code == 202 + + def test_pip_pull_response_body(self, models_setup) -> None: + client, _, _, mock_registry = models_setup + mock_registry.register.return_value = None + with patch("mataserver.api.v1.models.pip_installer") as mock_pip: + 
mock_pip.install_packages.return_value = None + resp = client.post("/v1/models/pull", json={"model": "easyocr", "task": "ocr"}) + body = resp.json() + assert body["status"] == "pulled" + assert body["model"] == "easyocr" + + def test_pip_pull_does_not_call_download_model(self, models_setup) -> None: + client, _, _, mock_registry = models_setup + mock_registry.register.return_value = None + with ( + patch("mataserver.api.v1.models.pip_installer") as mock_pip, + patch("mataserver.api.v1.models.download_model") as mock_dl, + ): + mock_pip.install_packages.return_value = None + client.post("/v1/models/pull", json={"model": "easyocr", "task": "ocr"}) + mock_dl.assert_not_called() + + def test_pip_pull_registers_with_source_pip(self, models_setup) -> None: + client, _, _, mock_registry = models_setup + mock_registry.register.return_value = None + with patch("mataserver.api.v1.models.pip_installer") as mock_pip: + mock_pip.install_packages.return_value = None + client.post("/v1/models/pull", json={"model": "easyocr", "task": "ocr"}) + mock_registry.register.assert_called_once_with("easyocr", "ocr", source="pip") + + def test_pip_task_mismatch_returns_400(self, models_setup) -> None: + client, _, _, _ = models_setup + resp = client.post("/v1/models/pull", json={"model": "easyocr", "task": "detect"}) + assert resp.status_code == 400 + assert "only supports task" in resp.json()["detail"] + + def test_pip_install_failure_returns_400(self, models_setup) -> None: + client, _, _, _ = models_setup + with patch("mataserver.api.v1.models.pip_installer") as mock_pip: + mock_pip.install_packages.side_effect = RuntimeError("pip install failed") + resp = client.post("/v1/models/pull", json={"model": "easyocr", "task": "ocr"}) + assert resp.status_code == 400 + + def test_hf_pull_registers_with_source_hf(self, models_setup) -> None: + client, _, _, mock_registry = models_setup + mock_registry.register.return_value = None + with patch("mataserver.api.v1.models.download_model"): + 
client.post("/v1/models/pull", json=_PULL_BODY) + mock_registry.register.assert_called_once_with( + "PekingU/rtdetr_v2_r101vd", "detect", source="hf" + ) + + def test_hf_pull_does_not_call_pip_installer(self, models_setup) -> None: + client, _, _, mock_registry = models_setup + mock_registry.register.return_value = None + with ( + patch("mataserver.api.v1.models.download_model"), + patch("mataserver.api.v1.models.pip_installer") as mock_pip, + ): + client.post("/v1/models/pull", json=_PULL_BODY) + mock_pip.install_packages.assert_not_called() diff --git a/tests/test_cli.py b/tests/test_cli.py index 6341672..974bb11 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -484,3 +484,218 @@ def test_show_not_found_exits_one(self, test_settings) -> None: with pytest.raises(SystemExit) as exc_info: cli() assert exc_info.value.code == 1 + + +# --------------------------------------------------------------------------- +# mataserver list — source column +# --------------------------------------------------------------------------- + + +class TestCliListSource: + """``mataserver list`` shows SOURCE column for mixed HF/pip models.""" + + _MIXED = [ + {"model": "facebook/detr-resnet-50", "task": "detect", "source": "hf", "size_mb": 167.3}, + {"model": "easyocr", "task": "ocr", "source": "pip", "size_mb": None}, + ] + + def test_list_shows_source_column_header(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.list_models", return_value=self._MIXED), + patch.object(sys, "argv", ["mataserver", "list"]), + ): + cli() + out = capsys.readouterr().out + assert "SOURCE" in out + + def test_list_shows_hf_source(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.list_models", 
return_value=self._MIXED), + patch.object(sys, "argv", ["mataserver", "list"]), + ): + cli() + out = capsys.readouterr().out + assert "hf" in out + + def test_list_shows_pip_source(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.list_models", return_value=self._MIXED), + patch.object(sys, "argv", ["mataserver", "list"]), + ): + cli() + out = capsys.readouterr().out + assert "pip" in out + + def test_pip_model_shows_dash_for_size(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.list_models", return_value=self._MIXED), + patch.object(sys, "argv", ["mataserver", "list"]), + ): + cli() + out = capsys.readouterr().out + assert "\u2014" in out # em-dash for missing size + + +# --------------------------------------------------------------------------- +# mataserver show — source and pip fields +# --------------------------------------------------------------------------- + + +class TestCliShowPipBackend: + """``mataserver show`` displays source and pip-specific info.""" + + _PIP_INFO = { + "model": "easyocr", + "task": "ocr", + "source": "pip", + "size_mb": None, + "last_accessed": None, + "pip_packages": ["easyocr"], + "installed": True, + } + + _TESSERACT_INFO = { + "model": "tesseract", + "task": "ocr", + "source": "pip", + "size_mb": None, + "last_accessed": None, + "pip_packages": ["pytesseract"], + "installed": True, + "system_binary": "tesseract", + "binary_found": False, + } + + _HF_INFO = { + "model": "facebook/detr-resnet-50", + "task": "detect", + "source": "hf", + "size_mb": 167.3, + "last_accessed": None, + } + + def test_show_pip_model_prints_source(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + 
patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.show_model", return_value=self._PIP_INFO), + patch.object(sys, "argv", ["mataserver", "show", "easyocr"]), + ): + cli() + out = capsys.readouterr().out + assert "source" in out + assert "pip" in out + + def test_show_pip_model_prints_pip_packages(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.show_model", return_value=self._PIP_INFO), + patch.object(sys, "argv", ["mataserver", "show", "easyocr"]), + ): + cli() + out = capsys.readouterr().out + assert "pip_packages" in out + assert "easyocr" in out + + def test_show_pip_model_prints_installed_yes(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.show_model", return_value=self._PIP_INFO), + patch.object(sys, "argv", ["mataserver", "show", "easyocr"]), + ): + cli() + out = capsys.readouterr().out + assert "installed" in out + assert "yes" in out + + def test_show_tesseract_prints_system_binary(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.show_model", return_value=self._TESSERACT_INFO), + patch.object(sys, "argv", ["mataserver", "show", "tesseract"]), + ): + cli() + out = capsys.readouterr().out + assert "system_binary" in out + assert "tesseract" in out + + def test_show_hf_model_prints_source_hf(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.show_model", return_value=self._HF_INFO), + patch.object(sys, "argv", ["mataserver", "show", 
"facebook/detr-resnet-50"]), + ): + cli() + out = capsys.readouterr().out + assert "source" in out + assert "hf" in out + + def test_show_hf_model_no_pip_packages_line(self, test_settings, capsys) -> None: + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.show_model", return_value=self._HF_INFO), + patch.object(sys, "argv", ["mataserver", "show", "facebook/detr-resnet-50"]), + ): + cli() + out = capsys.readouterr().out + assert "pip_packages" not in out + + +# --------------------------------------------------------------------------- +# mataserver rm — source-appropriate note +# --------------------------------------------------------------------------- + + +class TestCliRmPipBackend: + """``mataserver rm`` prints source-appropriate note for pip models.""" + + def test_rm_pip_model_prints_pip_note(self, test_settings, capsys) -> None: + pip_info = { + "model": "easyocr", + "task": "ocr", + "source": "pip", + "size_mb": None, + "last_accessed": None, + } + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.show_model", return_value=pip_info), + patch("mataserver.core.models.remove_model", return_value=True), + patch.object(sys, "argv", ["mataserver", "rm", "easyocr"]), + ): + cli() + out = capsys.readouterr().out + assert "pip packages" in out.lower() or "manually" in out.lower() + + def test_rm_hf_model_prints_hf_note(self, test_settings, capsys) -> None: + hf_info = { + "model": "facebook/detr-resnet-50", + "task": "detect", + "source": "hf", + "size_mb": 167.3, + "last_accessed": None, + } + with ( + patch("mataserver.main.load_settings", return_value=test_settings), + patch.object(Settings, "ensure_directories"), + patch("mataserver.core.models.show_model", return_value=hf_info), + patch("mataserver.core.models.remove_model", return_value=True), + 
patch.object(sys, "argv", ["mataserver", "rm", "facebook/detr-resnet-50"]), + ): + cli() + out = capsys.readouterr().out + assert "hf cache" in out.lower() or "weights" in out.lower() diff --git a/tests/test_core/test_backend_catalog.py b/tests/test_core/test_backend_catalog.py new file mode 100644 index 0000000..9ecb160 --- /dev/null +++ b/tests/test_core/test_backend_catalog.py @@ -0,0 +1,93 @@ +"""Tests for mataserver.core.backend_catalog.""" + +from mataserver.core.backend_catalog import ( + CatalogEntry, + get_source_type, + is_cataloged, + lookup, +) + + +class TestLookup: + def test_easyocr_returns_catalog_entry(self) -> None: + entry = lookup("easyocr") + assert isinstance(entry, CatalogEntry) + assert entry.task == "ocr" + assert entry.pip_packages == ("easyocr",) + assert entry.verify_import == "easyocr" + + def test_paddleocr_returns_catalog_entry(self) -> None: + entry = lookup("paddleocr") + assert entry is not None + assert entry.pip_packages == ("paddlepaddle", "paddleocr") + + def test_tesseract_has_system_binary(self) -> None: + entry = lookup("tesseract") + assert entry is not None + assert entry.system_binary == "tesseract" + + def test_hf_model_returns_none(self) -> None: + assert lookup("facebook/detr-resnet-50") is None + + def test_got_ocr2_returns_none(self) -> None: + assert lookup("stepfun-ai/GOT-OCR-2.0-hf") is None + + def test_trocr_returns_none(self) -> None: + assert lookup("microsoft/trocr-base-printed") is None + + def test_easyocr_task_is_ocr(self) -> None: + entry = lookup("easyocr") + assert entry is not None + assert entry.task == "ocr" + + def test_paddleocr_task_is_ocr(self) -> None: + entry = lookup("paddleocr") + assert entry is not None + assert entry.task == "ocr" + + def test_tesseract_task_is_ocr(self) -> None: + entry = lookup("tesseract") + assert entry is not None + assert entry.task == "ocr" + + def test_catalog_entry_is_frozen(self) -> None: + entry = lookup("easyocr") + assert entry is not None + import pytest + + 
with pytest.raises((AttributeError, TypeError)): + entry.task = "detect" # type: ignore[misc] + + +class TestIsCataloged: + def test_cataloged_backend(self) -> None: + assert is_cataloged("easyocr") is True + + def test_paddleocr_is_cataloged(self) -> None: + assert is_cataloged("paddleocr") is True + + def test_tesseract_is_cataloged(self) -> None: + assert is_cataloged("tesseract") is True + + def test_hf_model(self) -> None: + assert is_cataloged("facebook/detr-resnet-50") is False + + def test_arbitrary_string(self) -> None: + assert is_cataloged("not-a-real-backend") is False + + +class TestGetSourceType: + def test_pip_backend(self) -> None: + assert get_source_type("easyocr") == "pip" + + def test_paddleocr_source_type(self) -> None: + assert get_source_type("paddleocr") == "pip" + + def test_tesseract_source_type(self) -> None: + assert get_source_type("tesseract") == "pip" + + def test_hf_model(self) -> None: + assert get_source_type("facebook/detr-resnet-50") == "hf" + + def test_unknown_returns_hf(self) -> None: + assert get_source_type("some/unknown-model") == "hf" diff --git a/tests/test_core/test_pip_installer.py b/tests/test_core/test_pip_installer.py new file mode 100644 index 0000000..6284677 --- /dev/null +++ b/tests/test_core/test_pip_installer.py @@ -0,0 +1,103 @@ +"""Tests for mataserver.core.pip_installer.""" + +from unittest.mock import patch + +import pytest + +from mataserver.core.pip_installer import ( + check_system_binary, + install_packages, + verify_import, +) + + +class TestInstallPackages: + def test_calls_pip_with_correct_args(self) -> None: + with patch("mataserver.core.pip_installer.subprocess.run") as mock_run: + mock_run.return_value.returncode = 0 + install_packages(("easyocr",)) + cmd = mock_run.call_args[0][0] + assert cmd[-1] == "easyocr" + assert "-m" in cmd + assert "pip" in cmd + + def test_multiple_packages(self) -> None: + with patch("mataserver.core.pip_installer.subprocess.run") as mock_run: + 
mock_run.return_value.returncode = 0 + install_packages(("paddlepaddle", "paddleocr")) + cmd = mock_run.call_args[0][0] + assert "paddlepaddle" in cmd + assert "paddleocr" in cmd + + def test_nonzero_exit_raises_runtime_error(self) -> None: + with patch("mataserver.core.pip_installer.subprocess.run") as mock_run: + mock_run.return_value.returncode = 1 + mock_run.return_value.stderr = "No matching distribution" + with pytest.raises(RuntimeError, match="pip install failed"): + install_packages(("nonexistent",)) + + def test_no_shell_true(self) -> None: + with patch("mataserver.core.pip_installer.subprocess.run") as mock_run: + mock_run.return_value.returncode = 0 + install_packages(("easyocr",)) + kwargs = mock_run.call_args[1] + assert kwargs.get("shell") is not True + + def test_uses_sys_executable(self) -> None: + import sys + + with patch("mataserver.core.pip_installer.subprocess.run") as mock_run: + mock_run.return_value.returncode = 0 + install_packages(("easyocr",)) + cmd = mock_run.call_args[0][0] + assert cmd[0] == sys.executable + + def test_error_message_includes_stderr(self) -> None: + with patch("mataserver.core.pip_installer.subprocess.run") as mock_run: + mock_run.return_value.returncode = 1 + mock_run.return_value.stderr = "ERROR: could not find a version" + with pytest.raises(RuntimeError, match="could not find a version"): + install_packages(("nonexistent",)) + + def test_capture_output_true(self) -> None: + with patch("mataserver.core.pip_installer.subprocess.run") as mock_run: + mock_run.return_value.returncode = 0 + install_packages(("easyocr",)) + kwargs = mock_run.call_args[1] + assert kwargs.get("capture_output") is True + + +class TestVerifyImport: + def test_stdlib_module_found(self) -> None: + assert verify_import("json") is True + + def test_os_module_found(self) -> None: + assert verify_import("os") is True + + def test_nonexistent_module_not_found(self) -> None: + assert verify_import("nonexistent_module_xyz_12345") is False + + def 
test_another_nonexistent_module(self) -> None: + assert verify_import("totally_fake_package_abc_99999") is False + + +class TestCheckSystemBinary: + def test_python_found(self) -> None: + # whether "python" is on PATH varies by platform; assert the bool contract instead of a fixed value + assert isinstance(check_system_binary("python"), bool) + + def test_nonexistent_binary(self) -> None: + assert check_system_binary("nonexistent_binary_xyz_12345") is False + + def test_uses_shutil_which(self) -> None: + with patch("mataserver.core.pip_installer.shutil.which") as mock_which: + mock_which.return_value = "/usr/bin/tesseract" + result = check_system_binary("tesseract") + assert result is True + mock_which.assert_called_once_with("tesseract") + + def test_missing_binary_uses_shutil_which(self) -> None: + with patch("mataserver.core.pip_installer.shutil.which") as mock_which: + mock_which.return_value = None + result = check_system_binary("missing_tool") + assert result is False diff --git a/tests/test_core/test_pull.py b/tests/test_core/test_pull.py index eb9b2e7..3085ab9 100644 --- a/tests/test_core/test_pull.py +++ b/tests/test_core/test_pull.py @@ -96,7 +96,9 @@ def test_calls_registry_register_with_correct_args(self, tmp_path: Path) -> None patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), ): pull_model("facebook/detr-resnet-50", "detect", tmp_path) - mock_registry.register.assert_called_once_with("facebook/detr-resnet-50", "detect") + mock_registry.register.assert_called_once_with( + "facebook/detr-resnet-50", "detect", source="hf" + ) def test_constructs_registry_with_data_dir(self, tmp_path: Path) -> None: mock_registry = AsyncMock() @@ -112,7 +114,7 @@ def test_scan_called_before_register(self, tmp_path: Path) -> None: mock_registry = AsyncMock() call_order: list[str] = [] mock_registry.scan.side_effect = lambda: call_order.append("scan") # type: ignore[assignment] - mock_registry.register.side_effect = lambda *_: call_order.append("register") # type: ignore[assignment]
+ mock_registry.register.side_effect = lambda *_, **__: call_order.append("register") # type: ignore[assignment] with ( patch("mataserver.core.pull.snapshot_download"), patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), @@ -133,3 +135,110 @@ def test_download_error_does_not_register(self, tmp_path: Path) -> None: with pytest.raises(Exception, match="timeout"): pull_model("bad/model", "detect", tmp_path) mock_registry.register.assert_not_called() + + +# --------------------------------------------------------------------------- +# pull_model — pip backend dispatch +# --------------------------------------------------------------------------- + + +class TestPullModelPipBackend: + """pull_model dispatches to pip_installer for cataloged backends.""" + + def test_easyocr_calls_pip_installer(self, tmp_path: Path) -> None: + mock_registry = AsyncMock() + with ( + patch("mataserver.core.pull.pip_installer") as mock_pip, + patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), + ): + mock_pip.verify_import.return_value = True + pull_model("easyocr", "ocr", tmp_path) + mock_pip.install_packages.assert_called_once_with(("easyocr",)) + + def test_easyocr_does_not_call_snapshot_download(self, tmp_path: Path) -> None: + mock_registry = AsyncMock() + with ( + patch("mataserver.core.pull.pip_installer") as mock_pip, + patch("mataserver.core.pull.snapshot_download") as mock_sd, + patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), + ): + mock_pip.verify_import.return_value = True + pull_model("easyocr", "ocr", tmp_path) + mock_sd.assert_not_called() + + def test_task_mismatch_raises_value_error(self, tmp_path: Path) -> None: + with pytest.raises(ValueError, match="only supports task"): + pull_model("easyocr", "detect", tmp_path) + + def test_tesseract_warns_if_binary_missing(self, tmp_path: Path) -> None: + mock_registry = AsyncMock() + with ( + patch("mataserver.core.pull.pip_installer") as mock_pip, + 
patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), + patch("mataserver.core.pull.logger") as mock_logger, + ): + mock_pip.verify_import.return_value = True + mock_pip.check_system_binary.return_value = False + pull_model("tesseract", "ocr", tmp_path) + mock_logger.warning.assert_called() + + def test_hf_model_ignores_catalog(self, tmp_path: Path) -> None: + mock_registry = AsyncMock() + with ( + patch("mataserver.core.pull.snapshot_download") as mock_sd, + patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), + ): + pull_model("facebook/detr-resnet-50", "detect", tmp_path) + mock_sd.assert_called_once() + + def test_pip_registers_with_source_pip(self, tmp_path: Path) -> None: + mock_registry = AsyncMock() + with ( + patch("mataserver.core.pull.pip_installer") as mock_pip, + patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), + ): + mock_pip.verify_import.return_value = True + pull_model("easyocr", "ocr", tmp_path) + mock_registry.register.assert_called_once_with("easyocr", "ocr", source="pip") + + def test_hf_registers_with_source_hf(self, tmp_path: Path) -> None: + mock_registry = AsyncMock() + with ( + patch("mataserver.core.pull.snapshot_download"), + patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), + ): + pull_model("facebook/detr-resnet-50", "detect", tmp_path) + mock_registry.register.assert_called_once_with( + "facebook/detr-resnet-50", "detect", source="hf" + ) + + def test_paddleocr_installs_multiple_packages(self, tmp_path: Path) -> None: + mock_registry = AsyncMock() + with ( + patch("mataserver.core.pull.pip_installer") as mock_pip, + patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), + ): + mock_pip.verify_import.return_value = True + pull_model("paddleocr", "ocr", tmp_path) + mock_pip.install_packages.assert_called_once_with(("paddlepaddle", "paddleocr")) + + def test_pip_install_failure_does_not_register(self, tmp_path: Path) -> None: + 
mock_registry = AsyncMock() + with ( + patch("mataserver.core.pull.pip_installer") as mock_pip, + patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), + ): + mock_pip.install_packages.side_effect = RuntimeError("pip install failed") + with pytest.raises(RuntimeError, match="pip install failed"): + pull_model("easyocr", "ocr", tmp_path) + mock_registry.register.assert_not_called() + + def test_import_verification_failure_raises_runtime_error(self, tmp_path: Path) -> None: + mock_registry = AsyncMock() + with ( + patch("mataserver.core.pull.pip_installer") as mock_pip, + patch("mataserver.core.pull.ModelRegistry", return_value=mock_registry), + ): + mock_pip.verify_import.return_value = False + with pytest.raises(RuntimeError, match="not importable"): + pull_model("easyocr", "ocr", tmp_path) diff --git a/tests/test_models/test_registry.py b/tests/test_models/test_registry.py index 2017167..1ba81ca 100644 --- a/tests/test_models/test_registry.py +++ b/tests/test_models/test_registry.py @@ -26,7 +26,7 @@ async def test_get_returns_none_when_not_registered(registry): async def test_get_returns_entry_after_register(registry): await registry.register("foo/bar", "detect") result = await registry.get("foo/bar") - assert result == {"model": "foo/bar", "task": "detect"} + assert result == {"model": "foo/bar", "task": "detect", "source": "hf"} async def test_get_disk_fallback_after_cli_write(registry, tmp_path): @@ -40,7 +40,7 @@ async def test_get_disk_fallback_after_cli_write(registry, tmp_path): # get() should fall through to disk and find the model result = await registry.get("cli/model") - assert result == {"model": "cli/model", "task": "segment"} + assert result == {"model": "cli/model", "task": "segment", "source": "hf"} async def test_get_disk_fallback_updates_in_memory_cache(registry, tmp_path): @@ -54,7 +54,7 @@ async def test_get_disk_fallback_updates_in_memory_cache(registry, tmp_path): registry_file.write_text(json.dumps({})) result = await 
registry.get("cli/model") - assert result == {"model": "cli/model", "task": "detect"} + assert result == {"model": "cli/model", "task": "detect", "source": "hf"} async def test_get_still_returns_none_if_not_on_disk(registry, tmp_path): @@ -64,3 +64,87 @@ async def test_get_still_returns_none_if_not_on_disk(registry, tmp_path): result = await registry.get("missing/model") assert result is None + + +class TestRegistryMigration: + """Backward-compatible migration from flat to dict-of-dicts format.""" + + async def test_flat_format_migrates_to_dict(self, registry, tmp_path): + """Old {"model": "task"} format loads and migrates.""" + registry_file = tmp_path / "model_registry.json" + registry_file.write_text(json.dumps({"fb/detr": "detect"})) + await registry.scan() + result = await registry.get("fb/detr") + assert result == {"model": "fb/detr", "task": "detect", "source": "hf"} + + async def test_new_format_loads_directly(self, registry, tmp_path): + """New {"model": {"task": ..., "source": ...}} loads as-is.""" + registry_file = tmp_path / "model_registry.json" + registry_file.write_text(json.dumps({"easyocr": {"task": "ocr", "source": "pip"}})) + await registry.scan() + result = await registry.get("easyocr") + assert result == {"model": "easyocr", "task": "ocr", "source": "pip"} + + async def test_mixed_format_loads(self, registry, tmp_path): + """File with both old and new entries loads correctly.""" + registry_file = tmp_path / "model_registry.json" + data = { + "fb/detr": "detect", + "easyocr": {"task": "ocr", "source": "pip"}, + } + registry_file.write_text(json.dumps(data)) + await registry.scan() + assert (await registry.get("fb/detr"))["source"] == "hf" + assert (await registry.get("easyocr"))["source"] == "pip" + + async def test_save_persists_new_format(self, registry, tmp_path): + """After register, file is saved in dict-of-dicts format.""" + await registry.register("easyocr", "ocr", source="pip") + registry_file = tmp_path / "model_registry.json" + data = 
json.loads(registry_file.read_text()) + assert data["easyocr"] == {"task": "ocr", "source": "pip"} + + async def test_disk_fallback_works_with_flat_format(self, registry, tmp_path): + """Disk fallback in get() correctly migrates flat-format entries.""" + registry_file = tmp_path / "model_registry.json" + registry_file.write_text(json.dumps({"fb/detr": "detect"})) + # No scan() — triggers disk fallback path in get() + result = await registry.get("fb/detr") + assert result == {"model": "fb/detr", "task": "detect", "source": "hf"} + + async def test_disk_fallback_works_with_new_format(self, registry, tmp_path): + """Disk fallback in get() correctly reads dict-of-dicts entries.""" + registry_file = tmp_path / "model_registry.json" + registry_file.write_text(json.dumps({"easyocr": {"task": "ocr", "source": "pip"}})) + # No scan() — triggers disk fallback path in get() + result = await registry.get("easyocr") + assert result == {"model": "easyocr", "task": "ocr", "source": "pip"} + + +class TestRegistrySource: + """Source tracking in register/get/list.""" + + async def test_register_default_source_is_hf(self, registry): + await registry.register("foo/bar", "detect") + result = await registry.get("foo/bar") + assert result["source"] == "hf" + + async def test_register_pip_source(self, registry): + await registry.register("easyocr", "ocr", source="pip") + result = await registry.get("easyocr") + assert result["source"] == "pip" + + async def test_list_includes_source(self, registry): + await registry.register("foo/bar", "detect") + await registry.register("easyocr", "ocr", source="pip") + models = await registry.list_models() + sources = {m["model"]: m["source"] for m in models} + assert sources["foo/bar"] == "hf" + assert sources["easyocr"] == "pip" + + async def test_list_pip_model_has_no_size(self, registry): + """Pip-backed models have no HF cache entry → size_mb is None.""" + await registry.register("easyocr", "ocr", source="pip") + models = await 
registry.list_models() + easyocr_entry = next(m for m in models if m["model"] == "easyocr") + assert easyocr_entry["size_mb"] is None