From 18053dfb98646ede0ccf3cdd1b9608b50313e44e Mon Sep 17 00:00:00 2001 From: Bimantoro Maesa Date: Mon, 9 Mar 2026 19:04:07 +0700 Subject: [PATCH 1/2] Implement Valkey/Redis exporter and graph nodes & Appearance-Based ReID Tracking. - Implement comprehensive tests for the Valkey/Redis exporter, covering export and load functionalities for various result types including VisionResult, ClassifyResult, DepthResult, and OCRResult. - Introduce tests for ValkeyStore and ValkeyLoad graph nodes, ensuring correct artifact handling, key templating, and integration within graph execution. - Validate serialization and deserialization processes, including support for custom serializers and TTL parameters. - Ensure robust error handling for missing keys and unsupported result types. --- .github/copilot-instructions.md | 51 +- CHANGELOG.md | 67 ++ QUICKSTART.md | 69 ++ QUICK_REFERENCE.md | 115 ++- README.md | 71 +- docs/GRAPH_API_REFERENCE.md | 279 +++++++- docs/VALIDATION_GUIDE.md | 123 +++- docs/VALKEY_GUIDE.md | 791 +++++++++++++++++++++ examples/graph/README.md | 17 +- examples/graph/valkey_pipeline.py | 396 +++++++++++ examples/graph/valkey_rtsp_pipeline.py | 173 +++++ examples/track/cross_camera_reid.py | 320 +++++++++ examples/track/reid_tracking.py | 254 +++++++ pyproject.toml | 14 +- src/mata/__init__.py | 2 +- src/mata/adapters/__init__.py | 4 + src/mata/adapters/reid_adapter.py | 345 +++++++++ src/mata/adapters/tracking_adapter.py | 96 ++- src/mata/api.py | 23 +- src/mata/core/exporters/__init__.py | 4 + src/mata/core/exporters/valkey_exporter.py | 245 +++++++ src/mata/core/graph/validator.py | 5 + src/mata/core/model_loader.py | 99 ++- src/mata/core/model_registry.py | 43 ++ src/mata/core/types.py | 54 ++ src/mata/nodes/__init__.py | 7 + src/mata/nodes/filter.py | 14 +- src/mata/nodes/valkey_load.py | 99 +++ src/mata/nodes/valkey_store.py | 133 ++++ src/mata/trackers/__init__.py | 3 + src/mata/trackers/bot_sort.py | 4 +- src/mata/trackers/byte_tracker.py | 13 +- src/mata/trackers/configs/botsort.yaml | 24 +- src/mata/trackers/reid_bridge.py | 179 +++++ tests/test_reid_adapter.py | 694 ++++++++++++++++++ tests/test_reid_bridge.py | 491 +++++++++++++ tests/test_tracking_reid.py | 778 ++++++++++++++++++++ tests/test_transformation_nodes.py | 4 +- tests/test_universal_loader.py | 16 +- tests/test_valkey_config.py | 285 ++++++++ tests/test_valkey_exporter.py | 672 +++++++++++++++++ tests/test_valkey_nodes.py | 442 ++++++++++++ 42 files changed, 7434 insertions(+), 84 deletions(-) create mode 100644 docs/VALKEY_GUIDE.md create mode 100644 examples/graph/valkey_pipeline.py create mode 100644 examples/graph/valkey_rtsp_pipeline.py create mode 100644 examples/track/cross_camera_reid.py create mode 100644 examples/track/reid_tracking.py create mode 100644 src/mata/adapters/reid_adapter.py create mode 100644 src/mata/core/exporters/valkey_exporter.py create mode 100644 src/mata/nodes/valkey_load.py create mode 100644 src/mata/nodes/valkey_store.py create mode 100644 src/mata/trackers/reid_bridge.py create mode 100644 tests/test_reid_adapter.py create mode 100644 tests/test_reid_bridge.py create mode 100644 tests/test_tracking_reid.py create mode 100644 tests/test_valkey_config.py create mode 100644 tests/test_valkey_exporter.py create mode 100644 tests/test_valkey_nodes.py diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index cdc6106..bca6712 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -2,7 +2,7 @@ ## Architecture Overview -MATA is 
a **task-centric, model-agnostic** computer vision framework with a llama.cpp-inspired universal loader. As of v1.9.0, it features a unified adapter system supporting multiple tasks and runtimes, plus a fully vendored ByteTrack/BotSort tracking system and an OCR evaluation pipeline. +MATA is a **task-centric, model-agnostic** computer vision framework with a llama.cpp-inspired universal loader. As of v1.9.2, it features a unified adapter system supporting multiple tasks and runtimes, a fully vendored ByteTrack/BotSort tracking system with appearance-based ReID (single-camera and cross-camera via Valkey), and an OCR evaluation pipeline. **Universal Loading (v1.5.2+):** @@ -12,9 +12,11 @@ mata.load("classify", "./model.onnx") # Local ONNX file mata.load("segment", "fast-model") # Config alias mata.load("depth", "depth-anything/Depth-Anything-V2-Small-hf") mata.load("track", "facebook/detr-resnet-50", tracker="botsort") # 🆕 v1.8.0 +mata.load("track", "facebook/detr-resnet-50", tracker="botsort", + reid_model="openai/clip-vit-base-patch32") # 🆕 v1.9.2 ReID ``` -**Object Tracking (v1.8.0):** +**Object Tracking (v1.8.0+):** ```python # One-liner video/stream tracking @@ -31,6 +33,18 @@ for result in mata.track("rtsp://cam/stream", # Persistent per-frame tracking tracker = mata.load("track", "facebook/detr-resnet-50", tracker="bytetrack") result = tracker.update(frame, persist=True) # YOLO-like pattern + +# Appearance-based ReID (v1.9.2+) — BotSort only +results = mata.track("video.mp4", + model="facebook/detr-resnet-50", + reid_model="openai/clip-vit-base-patch32") + +# Cross-camera ReID via Valkey (v1.9.2+) +from mata.trackers import ReIDBridge +bridge = ReIDBridge("valkey://localhost:6379", camera_id="cam-1") +results = mata.track("rtsp://cam/stream", model="...", + reid_model="openai/clip-vit-base-patch32", + reid_bridge=bridge, stream=True) ``` **Zero-Shot Capabilities:** @@ -56,18 +70,24 @@ Task Adapters (HuggingFace/ONNX/TorchScript/PyTorch) VisionResult (Unified result: bbox + mask + track_id + embedding) ↓ Runtime (PyTorch/ONNX Runtime/TorchScript) - ↓ ↓ ↓ -Export System Tracking Layer Evaluation Layer (v1.8.1+) -(JSON/CSV/ (v1.8.0) ↓ -Image/Crops) ↓ Validator (detect/segment/classify/depth/ocr) - TrackingAdapter ↓ - ↓ Metrics (DetMetrics/SegMetrics/ClassifyMetrics/ - Vendored Trackers DepthMetrics/OCRMetrics) ← v1.9.0 - (BYTETracker/ ↓ - BOTSORT) Printer + DatasetLoader - ↓ (no external dep) (COCO/COCO-Text JSON) - KalmanFilter + IoU + ↓ ↓ ↓ +Export System Tracking Layer Evaluation Layer (v1.8.1+) +(JSON/CSV/ (v1.8.0) ↓ +Image/Crops) ↓ Validator (detect/segment/classify/depth/ocr) + TrackingAdapter ↓ + ↓ Metrics (DetMetrics/SegMetrics/ClassifyMetrics/ + Vendored Trackers DepthMetrics/OCRMetrics) ← v1.9.0 + (BYTETracker/BOTSORT) ↓ + ↓ (no external dep) Printer + DatasetLoader + KalmanFilter + IoU (COCO/COCO-Text JSON) matching + GMC + ↓ + ReID Layer (🆕 v1.9.2) + ↓ ↓ + ReIDAdapter ReIDBridge (cross-camera) + ↓ ↓ + HuggingFace Valkey embedding store + ONNX (publish/query/TTL eviction) ``` **Key Design Pattern:** Task contracts over model specifics - all adapters implement the same `predict()` interface returning task-specific results (VisionResult for detect/segment, ClassifyResult, DepthResult). 
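A minimal sketch of that contract (illustrative only — it assumes `predict()` accepts an image path the same way `mata.run()` does; adapter internals are not shown in this patch):

```python
import mata

detector = mata.load("detect", "facebook/detr-resnet-50")  # HuggingFace source, PyTorch runtime
onnx_det = mata.load("detect", "./model.onnx")              # local ONNX file, ONNX Runtime

# Same contract regardless of source or runtime: predict() -> VisionResult,
# whose instances carry bbox/mask plus track_id and embedding (both stay None
# until tracking / ReID are involved).
result = detector.predict("frame.jpg")
print(len(result.instances))
```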
@@ -150,6 +170,11 @@ pytest tests/test_tracking_visualization.py -v # Visualization/export (103 tests pytest tests/test_video_io.py -v # Video I/O utilities (56 tests) pytest tests/test_track_node.py -v # Track graph node (39 tests) +# ReID test suites (v1.9.2) +pytest tests/test_reid_adapter.py -v # ReID adapter unit tests (40+ tests) +pytest tests/test_tracking_reid.py -v # TrackingAdapter + API integration (25+ tests) +pytest tests/test_reid_bridge.py -v # Cross-camera bridge (15+ tests) + # VLM tool-calling test suites (v1.7.0) pytest tests/test_tool_schema.py -v # Tool schema (33 tests) pytest tests/test_tool_registry.py -v # Tool registry (49 tests) diff --git a/CHANGELOG.md b/CHANGELOG.md index 475ab5d..8241115 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,73 @@ Versions follow [Semantic Versioning](https://semver.org/). --- +## [1.9.2] Beta Release - 2026-03-09 + +### Added + +**Valkey / Redis Graph Pipeline Storage** + +- `export_valkey(result, url, key, ttl, serializer)` — serializes any MATA result type to a Valkey or Redis key with optional TTL; supports `json` (default) and `msgpack` serializers +- `load_valkey(url, key, result_type="auto")` — deserializes a stored result back to the original type; auto-detects `VisionResult`, `ClassifyResult`, `DepthResult`, and `OCRResult` from stored payload structure +- `publish_valkey(result, url, channel, serializer)` — fire-and-forget Pub/Sub broadcast; returns subscriber count +- `_parse_valkey_uri()` helper supporting `valkey://host:port/key`, `valkey://host:port/db/key`, and `redis://user:pass@host:port/db/key` formats +- `ValkeyStore` graph sink node — pass-through sink that writes an artifact to Valkey during graph execution; supports `{node}` and `{timestamp}` key template placeholders, TTL, and serializer selection +- `ValkeyLoad` graph source node — source node with `inputs={}` that loads a stored result from Valkey and injects it into the graph as a typed artifact +- `valkey://` and `redis://` URI scheme dispatch added to all six `result.save()` methods (`VisionResult`, `DetectResult`, `SegmentResult`, `ClassifyResult`, `DepthResult`, `OCRResult`) — existing file-based paths are fully unaffected +- `ModelRegistry.get_valkey_connection(name="default")` — reads named Valkey connection profiles from the `storage.valkey` section of `.mata/models.yaml` or `~/.mata/models.yaml`; resolves `password_env` from environment variables; raises `ModelNotFoundError` for unknown connection names +- YAML `storage.valkey.` config schema with `url`, `db`, `ttl`, `password_env`, and `tls` fields +- Optional dependency groups: `pip install mata[valkey]` → `valkey>=6.0.0`; `pip install mata[redis]` → `redis>=5.0.0`; both added to the `dev` extras group +- `export_valkey`, `load_valkey`, and `publish_valkey` exported from `mata.core.exporters` +- `ValkeyStore` and `ValkeyLoad` exported from `mata.nodes` +- 89 new tests: 42 exporter tests (`test_valkey_exporter.py`), 33 graph node tests (`test_valkey_nodes.py`), 14 config and pub/sub tests (`test_valkey_config.py`) + +**Documentation** + +- `docs/VALKEY_GUIDE.md` — full integration guide covering installation, basic usage, graph pipeline integration, YAML configuration, streaming patterns, Pub/Sub architecture, security (TLS, `password_env`, SSRF prevention, key sanitization), performance tuning (serializer choice, TTL strategies, connection pooling, async patterns), and top-5 troubleshooting issues +- `docs/GRAPH_API_REFERENCE.md` — new "Storage Nodes" section with full parameter tables 
for `ValkeyStore` and `ValkeyLoad` +- `README.md` — Valkey added to Key Features list and Optional Dependencies table +- `QUICKSTART.md` — new "Valkey / Redis Result Storage" section with annotated code examples +- `QUICK_REFERENCE.md` — new "Valkey/Redis Storage Quick Reference (v1.9)" section with cheatsheet + +**Appearance-Based ReID Tracking** + +- `mata.track(..., reid_model="org/model")` — activate appearance-based re-identification for BotSort by supplying any HuggingFace image encoder ID or local `.onnx` path +- `ReIDAdapter` — abstract base class for appearance feature extractors; L2-normalised embedding output; lazy-loaded to keep startup cost zero when ReID is unused +- `HuggingFaceReIDAdapter` — ViT / CLIP / AutoModel architecture auto-detection (CLIP image encoder, ViT/DeiT/Swin/BEiT pooler output, generic AutoModel mean-pooling); all `transformers` imports lazy +- `ONNXReIDAdapter` — ONNX Runtime ReID extractor; auto-detects NCHW/NHWC input layout from model metadata; supports CPU and CUDA execution providers +- `TrackingAdapter.update()` now extracts detection crops, batch-encodes them through the ReID encoder, and injects embeddings into `BOTSORT` — activating the appearance distance branch in `get_dists()`; `Instance.embedding` populated in output `VisionResult` +- `mata.track()` extended with `reid_model: str | None` and `with_reid: bool = False` kwargs; `with_reid=True` without `reid_model` raises `ValueError` +- Config alias support: `reid_model` and `with_reid` keys can be declared in `.mata/models.yaml` under a `track:` alias; runtime kwargs always take precedence +- `ReIDBridge` — cross-camera appearance store backed by Valkey/Redis; publishes L2-normalised embeddings keyed by `reid:{camera_id}:{track_id}`; `query()` returns nearest matches above cosine-similarity threshold from other cameras; uses `scan_iter` (production-safe, non-blocking); TTL-based auto-eviction; `msgpack` binary serialisation +- `TrackingAdapter.__init__()` extended with `reid_bridge: ReIDBridge | None`; after each `update()` confirmed tracks with embeddings are published automatically; `ConnectionError` caught and logged, never raised +- `mata.track()` / `mata.load("track", ...)` extended with `reid_bridge` kwarg; forwarded to `TrackingAdapter` +- `ReIDAdapter`, `HuggingFaceReIDAdapter`, `ONNXReIDAdapter` exported from `mata.adapters` +- `ReIDBridge` exported from `mata.trackers` +- `src/mata/trackers/configs/botsort.yaml` — commented `reid_model` / `with_reid` documentation block added (v1.9.2+) +- 80+ new tests: `test_reid_adapter.py` (ReID adapter unit tests), `test_tracking_reid.py` (TrackingAdapter + API integration), `test_reid_bridge.py` (cross-camera bridge) +- `examples/track/reid_tracking.py` — basic single-camera ReID tracking example script +- `examples/track/cross_camera_reid.py` — cross-camera ReID via Valkey example script + +**Documentation** + +- `docs/VALKEY_GUIDE.md` — full integration guide covering installation, basic usage, graph pipeline integration, YAML configuration, streaming patterns, Pub/Sub architecture, security (TLS, `password_env`, SSRF prevention, key sanitization), performance tuning (serializer choice, TTL strategies, connection pooling, async patterns), and top-5 troubleshooting issues +- `docs/GRAPH_API_REFERENCE.md` — new "Storage Nodes" section with full parameter tables for `ValkeyStore` and `ValkeyLoad` +- `README.md` — Valkey added to Key Features list and Optional Dependencies table; ReID tracking section added with single-camera and cross-camera 
usage examples +- `QUICKSTART.md` — new "Valkey / Redis Result Storage" section with annotated code examples +- `QUICK_REFERENCE.md` — new "Valkey/Redis Storage Quick Reference (v1.9)" section with cheatsheet +- `docs/VALIDATION_GUIDE.md` — ReID tracking validation notes added + +### Changed + +- `mata.nodes.__all__` extended with `ValkeyStore` and `ValkeyLoad` +- `mata.core.exporters.__init__` extended with `export_valkey`, `load_valkey`, `publish_valkey` +- `mata.track()` signature extended with `reid_model`, `with_reid`, `reid_bridge` kwargs (backward-compatible defaults) +- `TrackingAdapter.__init__()` extended with `reid_encoder`, `reid_bridge` kwargs (both default to `None`; zero overhead when unused) +- `BOTSORT.get_dists()` appearance-distance branch now reachable when `encoder` is set via `reid_encoder` +- ByteTrack vs BotSort ReID comparison table in `README.md` updated to reflect v1.9.2 BotSort support + +--- + ## [1.9.1] - 2026-03-08 ### Changed diff --git a/QUICKSTART.md b/QUICKSTART.md index 44edcda..1f636fb 100644 --- a/QUICKSTART.md +++ b/QUICKSTART.md @@ -366,6 +366,75 @@ print(f"mAP@50-95: {metrics.box.map:.3f}") All four tasks are supported — detection, segmentation, classification, and depth. See the [Validation Guide](docs/VALIDATION_GUIDE.md) for dataset setup, full API reference, and metrics details. +## Valkey / Redis Result Storage + +Persist any MATA result to [Valkey](https://valkey.io/) or Redis for distributed pipelines, cross-process sharing, or caching. + +### Install + +```bash +pip install mata[valkey] # valkey-py (recommended) +pip install mata[redis] # redis-py (alternative) +``` + +### Save and load a result + +```bash +# Start a local Valkey server (or use an existing Redis server) +docker run -d -p 6379:6379 valkey/valkey:latest +``` + +```python +import mata + +# Run detection and save result to Valkey +result = mata.run("detect", "image.jpg", model="PekingU/rtdetr_r18vd") +result.save("valkey://localhost:6379/detections:frame_001") + +# Load it back later (in a different process or service) +from mata.core.exporters import load_valkey +loaded = load_valkey(url="valkey://localhost:6379", key="detections:frame_001") +print(loaded) # equivalent VisionResult +``` + +### Use in a graph pipeline with `ValkeyStore` / `ValkeyLoad` + +```python +import mata +from mata.nodes import Detect, Filter, ValkeyStore, ValkeyLoad, Fuse +from mata.core.graph import Graph + +detector = mata.load("detect", "PekingU/rtdetr_r18vd") + +# Pipeline A — detect and persist +store_graph = ( + Graph() + .then(Detect(using="detr", out="dets")) + .then(Filter(src="dets", score_gt=0.4, out="filtered")) + .then(ValkeyStore( + src="filtered", + url="valkey://localhost:6379", + key="pipeline:detections:{timestamp}", + ttl=60, # expires after 60 s + )) +) +mata.infer("frame.jpg", graph=store_graph, providers={"detr": detector}) + +# Pipeline B — load and annotate (in a separate service) +load_graph = ( + Graph() + .then(ValkeyLoad( + url="valkey://localhost:6379", + key="pipeline:detections:latest", + out="dets", + )) + .then(Fuse(detections="dets", out="annotated")) +) +result = mata.infer("frame.jpg", graph=load_graph, providers={}) +``` + +See the [Graph API Reference](docs/GRAPH_API_REFERENCE.md#storage-nodes) for full parameter documentation. + ## Next Steps 1. 
**Read the full documentation**: [README.md](README.md) diff --git a/QUICK_REFERENCE.md b/QUICK_REFERENCE.md index f9d22b6..ad69ecc 100644 --- a/QUICK_REFERENCE.md +++ b/QUICK_REFERENCE.md @@ -15,6 +15,7 @@ | [Object Tracking](#-object-tracking-quick-reference-v18) | v1.8 | | [OCR / Text Extraction](#-ocr--text-extraction-quick-reference-v19) | v1.9 | | [Evaluation](#-evaluation-quick-reference-v18) | v1.8 | +| [Valkey/Redis Storage](#-valkeyredis-storage-quick-reference-v19) | v1.9 | --- @@ -1583,7 +1584,117 @@ metrics = mata.val( --- -**Version:** 1.8.1 -**Date:** February 20, 2026 +## 🗄️ Valkey/Redis Storage Quick Reference (v1.9) + +### Installation + +```bash +pip install mata[valkey] # valkey-py (recommended) +pip install mata[redis] # redis-py (alternative) +``` + +### `result.save()` — URI scheme + +```python +# Any result type supports valkey:// and redis:// URIs directly in save() +result.save("valkey://localhost:6379/my_key") # basic +result.save("valkey://localhost:6379/0/my_key") # with DB number +result.save("redis://localhost:6379/my_key") # redis-py fallback +result.save("valkey://localhost:6379/my_key", ttl=300) # with TTL (seconds) +``` + +### Low-level exporter + +```python +from mata.core.exporters import export_valkey, load_valkey, publish_valkey + +# Store +export_valkey(result, url="valkey://localhost:6379", key="my_key", ttl=3600) + +# Load +loaded = load_valkey(url="valkey://localhost:6379", key="my_key") + +# Load with explicit result type (skip auto-detection) +loaded = load_valkey(url="valkey://localhost:6379", key="my_key", + result_type="vision") # or "classify", "depth", "ocr" + +# Pub/Sub publish (fire-and-forget) +n_receivers = publish_valkey(result, url="valkey://localhost:6379", + channel="detections:stream") +``` + +### URI formats + +| Format | Example | +| ------ | ------- | +| Basic | `valkey://localhost:6379/key` | +| With DB | `valkey://localhost:6379/0/key` | +| With auth | `valkey://user:pass@host:6379/0/key` | +| Redis | `redis://localhost:6379/key` | +| Redis TLS | `rediss://host:6379/key` | + +### Graph nodes: `ValkeyStore` / `ValkeyLoad` + +```python +from mata.nodes import ValkeyStore, ValkeyLoad + +# Sink node — store artifact and pass it through unchanged +ValkeyStore( + src="filtered", # artifact name in graph context + url="valkey://localhost:6379", + key="pipeline:{node}:{timestamp}", # {node} and {timestamp} placeholders + ttl=3600, # optional TTL in seconds + serializer="json", # "json" (default) or "msgpack" + out="filtered", # optional override for output name +) + +# Source node — load artifact as graph entry point +ValkeyLoad( + url="valkey://localhost:6379", + key="upstream:detections:latest", + result_type="auto", # or "vision", "classify", "depth", "ocr" + out="dets", +) +``` + +### Auto-detection of result type + +| Key in stored data | Detected type | Output artifact | +| ------------------ | ------------- | --------------- | +| `instances` | `vision` | `Detections` | +| `predictions` | `classify` | `Classifications` | +| `depth` | `depth` | `DepthMap` | +| `regions` | `ocr` | _(raw dict)_ | + +### Named connections (YAML config) + +```yaml +# .mata/models.yaml +storage: + valkey: + default: + url: "valkey://localhost:6379" + db: 0 + ttl: 3600 + production: + url: "valkey://prod-cluster:6379" + password_env: "VALKEY_PASSWORD" # read from env, never stored in plaintext + tls: true +``` + +```python +from mata.core.model_registry import ModelRegistry + +registry = ModelRegistry() +conn = 
registry.get_valkey_connection("production") +# → {"url": "valkey://...", "password": "", "tls": True} +``` + +**Documentation:** [Graph API Reference — Storage Nodes](docs/GRAPH_API_REFERENCE.md#storage-nodes) + +--- + +**Version:** 1.9.0 +**Date:** March 9, 2026 **Status:** ✅ Production Ready ```` diff --git a/README.md b/README.md index b93274f..1dcd12a 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,9 @@ MATA focuses on **stable task contracts** and **pluggable runtimes**, allowing y - **Vision-Language Models**: Image captioning, VQA, and visual understanding with Qwen3-VL and more - **Multi-Format Runtime**: PyTorch ✅ | ONNX Runtime ✅ | TorchScript ✅ | Torchvision ✅ | TensorRT (planned) - **Graph System** (v1.6): Multi-task workflows with `mata.infer()`, parallel execution, conditional branching, and video tracking -- **Object Tracking** (v1.8): `mata.track()` — Video/stream tracking with vendored ByteTrack and BotSort, persistent track IDs, trajectory trails, and CSV/JSON export +- **Object Tracking** (v1.8+): `mata.track()` — Video/stream tracking with vendored ByteTrack and BotSort, persistent track IDs, trajectory trails, CSV/JSON export, and appearance-based ReID (v1.9.2) - **OCR / Text Extraction** (v1.9): `mata.run("ocr", ...)` — extract printed and handwritten text using GOT-OCR2, TrOCR, EasyOCR, PaddleOCR, or Tesseract with per-region confidence and bounding boxes +- **Valkey/Redis Result Storage** (v1.9): persist any result to Valkey/Redis with `result.save("valkey://host/key")` or via `ValkeyStore`/`ValkeyLoad` graph nodes — enables distributed pipelines and cross-process result sharing - **Validation & Evaluation**: `mata.val()` — mAP/accuracy/depth metrics against COCO, ImageNet, or custom datasets - **Export & Visualization**: Save as JSON/CSV/image overlays/crops with dual backends (PIL/matplotlib) - **Task-First API**: Specify what you want (detect, segment, classify, depth, ocr, vlm), not which model to use @@ -98,6 +99,10 @@ pip install onnxruntime-gpu # GPU # For publication-quality visualizations pip install matplotlib + +# For Valkey/Redis result storage +pip install mata[valkey] # valkey-py client (recommended) +pip install mata[redis] # redis-py client (alternative) ``` ## 🚀 Quick Start @@ -339,7 +344,7 @@ result = mata.infer(graph=graph, video="video.mp4", providers={...}) | Speed | Faster | Slightly slower | | Accuracy | Good | Better (especially for panning cameras) | | Default | No | **Yes** (MATA default, matches Ultralytics) | -| ReID | ❌ v1.8 | ❌ v1.8 (planned for v1.9) | +| ReID | ❌ No | ✅ v1.9.2 (`reid_model=` kwarg) | **Configuration via YAML:** @@ -360,6 +365,68 @@ models: tracker = mata.load("track", "highway-cam") ``` +#### Appearance-Based ReID (New in v1.9.2) + +Enable appearance re-identification with BotSort to recover track IDs after occlusion or re-entry: + +```python +# Pass any HuggingFace image encoder (ViT, CLIP, OSNet, etc.) 
as a ReID model +results = mata.track( + "video.mp4", + model="facebook/detr-resnet-50", + tracker="botsort", + reid_model="openai/clip-vit-base-patch32", # appearance encoder + conf=0.3, + save=True, +) +``` + +ONNX models are also supported for production deployment: + +```python +# Use a local .onnx ReID model for low-latency inference +results = mata.track( + "video.mp4", + model="facebook/detr-resnet-50", + reid_model="osnet_x1_0.onnx", # local ONNX ReID model +) +``` + +ReID can also be declared in a config alias: + +```yaml +# .mata/models.yaml +models: + track: + smart-cam: + source: "facebook/detr-resnet-50" + tracker: botsort + reid_model: "openai/clip-vit-base-patch32" + tracker_config: + track_high_thresh: 0.6 + appearance_thresh: 0.25 +``` + +```python +tracker = mata.load("track", "smart-cam") # ReID loaded automatically +``` + +**Cross-camera re-identification** via Valkey: + +```python +from mata.trackers import ReIDBridge + +# Camera 1 — publish embeddings +bridge = ReIDBridge("valkey://localhost:6379", camera_id="cam-1") +results = mata.track("rtsp://cam1/stream", model="detr", + reid_model="openai/clip-vit-base-patch32", + reid_bridge=bridge, stream=True) + +# Camera 2 — query nearest identity +bridge2 = ReIDBridge("valkey://localhost:6379", camera_id="cam-2") +# Embeddings from cam-1 are queryable cross-camera with cosine similarity +``` + See [Tracking Examples](examples/track/) | [Quick Reference](QUICK_REFERENCE.md#object-tracking) ### OCR / Text Extraction (New in v1.9) diff --git a/docs/GRAPH_API_REFERENCE.md b/docs/GRAPH_API_REFERENCE.md index 65ff361..d1d6099 100644 --- a/docs/GRAPH_API_REFERENCE.md +++ b/docs/GRAPH_API_REFERENCE.md @@ -9,16 +9,17 @@ 1. [Public API](#public-api) 2. [Artifacts](#artifacts) 3. [Built-in Nodes](#built-in-nodes) -4. [Graph Builder](#graph-builder) -5. [Schedulers](#schedulers) -6. [Execution Context](#execution-context) -7. [Providers & Protocols](#providers--protocols) -8. [Conditional Execution](#conditional-execution) -9. [Temporal / Video](#temporal--video) -10. [Observability](#observability) -11. [DSL Helpers](#dsl-helpers) -12. [Presets](#presets) -13. [Converters & Utilities](#converters--utilities) +4. [Storage Nodes](#storage-nodes) +5. [Graph Builder](#graph-builder) +6. [Schedulers](#schedulers) +7. [Execution Context](#execution-context) +8. [Providers & Protocols](#providers--protocols) +9. [Conditional Execution](#conditional-execution) +10. [Temporal / Video](#temporal--video) +11. [Observability](#observability) +12. [DSL Helpers](#dsl-helpers) +13. [Presets](#presets) +14. [Converters & Utilities](#converters--utilities) --- @@ -860,6 +861,264 @@ Uses `torchvision.ops.nms` internally. --- +## Storage Nodes + +Storage nodes connect graph pipelines to [Valkey](https://valkey.io/) (or wire-compatible Redis) for distributed result sharing, cross-pipeline caching, and event-driven architectures. + +```python +from mata.nodes import ValkeyStore, ValkeyLoad +``` + +**Installation:** + +```bash +pip install mata[valkey] # valkey-py client +pip install mata[redis] # redis-py client (fallback) +``` + +Both nodes lazy-import the client library — `import mata` succeeds without either installed. An `ImportError` with an actionable message is raised only when a storage node actually executes. 
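If you prefer to fail fast at startup rather than at the first node execution, a small stdlib-only guard (a sketch, not a MATA API) can check for either client up front:

```python
import importlib.util

# Neither client is needed to import mata; this check only matters when the
# graph actually contains a ValkeyStore / ValkeyLoad node.
if not any(importlib.util.find_spec(pkg) for pkg in ("valkey", "redis")):
    raise RuntimeError(
        "No Valkey/Redis client found — install with "
        "`pip install mata[valkey]` or `pip install mata[redis]`"
    )
```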
+ +--- + +### URI Scheme Format + +Storage nodes and `result.save()` accept Valkey/Redis URIs in the following formats: + +| Format | Example | +| -------------- | --------------------------------------- | +| Basic | `valkey://localhost:6379/my_key` | +| With DB number | `valkey://localhost:6379/0/my_key` | +| With password | `valkey://user:pass@host:6379/0/my_key` | +| Redis fallback | `redis://localhost:6379/my_key` | +| TLS (Redis) | `rediss://host:6379/my_key` | + +The **key** is the last path segment (or everything after `db/` when a numeric DB is present). Passwords in URIs are passed through to the client and are **never logged**. + +**Direct save from result objects:** + +```python +# Any MATA result type supports valkey:// URIs in save() +result.save("valkey://localhost:6379/pipeline:detections:latest") +result.save("valkey://localhost:6379/0/detections:frame_042") # DB 0 +result.save("redis://localhost:6379/detections") # redis-py +``` + +--- + +### Key Template Syntax + +`ValkeyStore` accepts a `key` parameter that supports safe placeholder substitution: + +| Placeholder | Resolved value | Example output | +| ------------- | ---------------------------- | -------------- | +| `{node}` | Node's `name` attribute | `ValkeyStore` | +| `{timestamp}` | Unix epoch (integer seconds) | `1741478400` | + +Placeholders are resolved with `str.format()` using only these two predefined variables — user-controlled input is **never** interpolated directly. + +```python +ValkeyStore( + src="filtered", + url="valkey://localhost:6379", + key="pipeline:{node}:{timestamp}", # → "pipeline:ValkeyStore:1741478400" + ttl=3600, +) +``` + +--- + +### `ValkeyStore` + +Sink node that writes an artifact to Valkey during graph execution. The artifact passes through unchanged, so downstream nodes can still consume it. + +```python +ValkeyStore( + src: str, + url: str, + key: str, + ttl: int | None = None, + serializer: str = "json", + out: str | None = None, +) +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +| ------------ | ------------- | ------------- | -------------------------------------------------------------- | +| `src` | `str` | _required_ | Name of the input artifact in the graph context | +| `url` | `str` | _required_ | Valkey/Redis connection URL (see URI formats above) | +| `key` | `str` | _required_ | Key name or template (`{node}`, `{timestamp}` supported) | +| `ttl` | `int \| None` | `None` | Key expiration in seconds; `None` = no expiry | +| `serializer` | `str` | `"json"` | `"json"` (default) or `"msgpack"` (requires `msgpack` package) | +| `out` | `str \| None` | same as `src` | Output artifact name (pass-through) | + +**I/O:** + +| I/O | Name | Type | +| ------ | ---------- | ---------------------- | +| Input | `artifact` | `Artifact` (any) | +| Output | `artifact` | `Artifact` (unchanged) | + +**Supported artifact types:** `Detections`, `Masks`, `Classifications`, `DepthMap`. Other `Artifact` subclasses are stored as-is (best-effort). 
+ +**Example:** + +```python +from mata.nodes import Detect, Filter, ValkeyStore +from mata.core.graph import Graph + +graph = ( + Graph() + .then(Detect(using="detr", out="dets")) + .then(Filter(src="dets", score_gt=0.5, out="filtered")) + .then(ValkeyStore( + src="filtered", + url="valkey://localhost:6379", + key="pipeline:detections:{timestamp}", + ttl=3600, + )) +) + +result = mata.infer("frame.jpg", graph=graph, providers={"detr": detector}) +# result.filtered is still available — ValkeyStore is a pass-through sink +``` + +--- + +### `ValkeyLoad` + +Source node that loads a previously stored result from Valkey and injects it into the graph as a typed artifact. Use this as an **entry node** to consume results produced by another pipeline. + +```python +ValkeyLoad( + url: str, + key: str, + result_type: str = "auto", + out: str = "loaded", +) +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +| ------------- | ----- | ---------- | --------------------------------------------------------- | +| `url` | `str` | _required_ | Valkey/Redis connection URL | +| `key` | `str` | _required_ | Key name to load from | +| `result_type` | `str` | `"auto"` | `"auto"`, `"vision"`, `"classify"`, `"depth"`, or `"ocr"` | +| `out` | `str` | `"loaded"` | Output artifact name in the graph context | + +**I/O:** + +| I/O | Name | Type | +| ------ | ---------------------- | ---------- | +| Input | _(none — source node)_ | — | +| Output | `artifact` | `Artifact` | + +**Auto-detection logic** (`result_type="auto"`): + +| Key present in stored data | Detected type | Output artifact | +| -------------------------- | ------------- | ----------------- | +| `instances` | `vision` | `Detections` | +| `predictions` | `classify` | `Classifications` | +| `depth` | `depth` | `DepthMap` | +| `regions` | `ocr` | _(raw dict)_ | + +**Raises:** + +- `KeyError` — key does not exist in Valkey +- `ValueError` — stored data cannot be auto-detected or `result_type` is unrecognized +- `ImportError` — valkey/redis client not installed + +**Example:** + +```python +from mata.nodes import ValkeyLoad, Filter, Fuse +from mata.core.graph import Graph + +graph = ( + Graph() + .then(ValkeyLoad( + url="valkey://localhost:6379", + key="upstream:detections:latest", + result_type="vision", + out="dets", + )) + .then(Filter(src="dets", score_gt=0.7, out="filtered")) + .then(Fuse(detections="filtered")) +) + +result = mata.infer("frame.jpg", graph=graph, providers={}) +``` + +--- + +### Complete Store → Load Pipeline Example + +This pattern enables two independent pipelines to share detection results through Valkey: + +```python +import mata +from mata.nodes import Detect, Filter, ValkeyStore, ValkeyLoad, Fuse +from mata.core.graph import Graph + +detector = mata.load("detect", "facebook/detr-resnet-50") + +# --- Pipeline A: run detection and persist to Valkey --- +store_graph = ( + Graph() + .then(Detect(using="detr", out="dets")) + .then(Filter(src="dets", score_gt=0.4, out="filtered")) + .then(ValkeyStore( + src="filtered", + url="valkey://localhost:6379", + key="shared:detections:latest", + ttl=60, # expires after 60 seconds + )) +) + +mata.infer("frame_001.jpg", graph=store_graph, providers={"detr": detector}) + +# --- Pipeline B: load persisted results and annotate --- +load_graph = ( + Graph() + .then(ValkeyLoad( + url="valkey://localhost:6379", + key="shared:detections:latest", + out="dets", + )) + .then(Fuse(detections="dets", out="annotated")) +) + +annotated = mata.infer("frame_001.jpg", graph=load_graph, 
providers={}) +``` + +**Named connections via YAML config:** + +```yaml +# .mata/models.yaml +storage: + valkey: + default: + url: "valkey://localhost:6379" + db: 0 + ttl: 3600 + production: + url: "valkey://prod-cluster:6379" + password_env: "VALKEY_PASSWORD" # read from env var, never stored in plaintext + db: 1 + tls: true +``` + +```python +from mata.core.model_registry import ModelRegistry + +registry = ModelRegistry() +conn = registry.get_valkey_connection("production") # resolves password from env +``` + +--- + ## Graph Builder ### `Graph` diff --git a/docs/VALIDATION_GUIDE.md b/docs/VALIDATION_GUIDE.md index 75f5833..a3af987 100644 --- a/docs/VALIDATION_GUIDE.md +++ b/docs/VALIDATION_GUIDE.md @@ -1128,4 +1128,125 @@ from mata.eval.metrics import COCO_IOU_THRESHOLDS 5. **Single-IoU confusion matrix:** The `ConfusionMatrix` operates at a single IoU threshold (default 0.45) and confidence threshold (default 0.25), regardless of the `iou` parameter passed to `mata.val()`. 6. **OCR recognition-only:** `OCRMetrics` computes recognition metrics only (CER, WER, exact-match accuracy). Text detection metrics (H-mean precision/recall/F1 on bounding-box matching) and end-to-end evaluation (combined detection + recognition) are not yet supported. Pass `mode="e2e"` is reserved for a future release. -7. **OCR image-level comparison:** Ground-truth transcriptions for all text regions in an image are concatenated with spaces before CER/WER computation. This avoids the hard problem of pairing predicted regions to GT regions (which requires IoU matching), but means per-region error rates are not available. \ No newline at end of file +7. **OCR image-level comparison:** Ground-truth transcriptions for all text regions in an image are concatenated with spaces before CER/WER computation. This avoids the hard problem of pairing predicted regions to GT regions (which requires IoU matching), but means per-region error rates are not available. + +--- + +## ReID Tracking Notes (v1.9.2) + +Appearance-Based Re-Identification (ReID) in MATA enhances BotSort's track-recovery capability after occlusion or target re-entry. This section covers how to enable ReID, inspect its outputs, and reason about tracking quality. + +### What ReID Adds + +Without ReID, BotSort associates detections to tracks using two cues: +1. **IoU** — spatial overlap between predicted and detected bounding boxes +2. **GMC** — global motion compensation (sparse optical flow) for camera motion + +With ReID enabled (`reid_model=...`), a third cue is added: +3. **Cosine appearance distance** — L2-normalised embedding vectors extracted from detection crops are compared against cached track features (`smooth_feat`) + +This allows BotSort to re-associate tracks even when the predicted position drifts significantly due to occlusion gaps. 
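As a sketch of the arithmetic behind cue 3 (illustrative, not BotSort's internal code): for unit-norm vectors, cosine similarity is simply a dot product, and the appearance distance is its complement:

```python
import numpy as np

def appearance_distance(track_feat: np.ndarray, det_feat: np.ndarray) -> float:
    """Cosine appearance distance between two L2-normalised embeddings of shape (D,)."""
    return 1.0 - float(np.dot(track_feat, det_feat))  # 0.0 = identical appearance, 2.0 = opposite
```

Lower distance means more similar appearance; BotSort combines this cue with the IoU and motion cues during association.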
+ +### Enabling ReID + +```python +import mata + +# Single-camera tracking with ReID +results = mata.track( + "video.mp4", + model="facebook/detr-resnet-50", + tracker="botsort", + reid_model="openai/clip-vit-base-patch32", # any HF image encoder + conf=0.3, +) + +# Inspect per-instance embedding vectors +for frame_result in results: + for inst in frame_result.instances: + print(f"Track #{inst.track_id} embedding shape: " + f"{inst.embedding.shape if inst.embedding is not None else 'N/A'}") +``` + +ONNX models are also supported: + +```python +results = mata.track( + "video.mp4", + model="facebook/detr-resnet-50", + reid_model="osnet_x1_0.onnx", # local ONNX ReID model +) +``` + +### Inspecting Embedding Quality + +Each tracked instance with an active ReID encoder will have `Instance.embedding` populated with an L2-normalised float32 vector of shape `(D,)`: + +```python +import numpy as np + +for result in results: + for inst in result.instances: + if inst.embedding is not None: + emb = inst.embedding + assert abs(np.linalg.norm(emb) - 1.0) < 1e-5, "Not unit norm" + print(f"Track #{inst.track_id}: {emb.shape}, norm={np.linalg.norm(emb):.4f}") +``` + +### Cross-Camera ReID with `ReIDBridge` + +`ReIDBridge` publishes confirmed-track embeddings to a shared Valkey store so independent tracker instances can resolve the same physical identity across feeds. + +```python +from mata.trackers import ReIDBridge + +# Camera A +bridge_a = ReIDBridge( + "valkey://localhost:6379", + camera_id="cam-a", + ttl=300, # embeddings expire after 5 min + similarity_thresh=0.25, # cosine similarity cutoff +) + +# mata.track() with reid_bridge: each confirmed track is published after update() +for result in mata.track( + "rtsp://cam-a/stream", + model="facebook/detr-resnet-50", + reid_model="openai/clip-vit-base-patch32", + reid_bridge=bridge_a, + stream=True, +): + active = [i for i in result.instances if i.track_id is not None] + print(f"Active tracks: {len(active)}") + +# Camera B — query nearest identity from cam-a +bridge_b = ReIDBridge("valkey://localhost:6379", camera_id="cam-b") +query_embedding = ... # np.ndarray shape (D,), L2-normalised +matches = bridge_b.query(query_embedding, exclude_camera="cam-b", top_k=1) +if matches: + print(f"Best cross-camera match: {matches[0]}") + # {'track_id': 7, 'camera_id': 'cam-a', 'similarity': 0.83, ...} +``` + +### ReID Validation Tips + +| Scenario | Recommended Approach | +| -------- | -------------------- | +| Verify embeddings are populated | Check `inst.embedding is not None` after `update()` | +| Measure track-recovery rate | Count frames where a lost track recovers its original ID | +| Tune appearance threshold | Adjust `appearance_thresh` in `tracker_config` (BotSort default: 0.25) | +| Reduce false re-associations | Increase `reid_model` → use a more discriminative encoder (e.g., OSNet vs CLIP) | +| GPU inference for ReID | Pass `device="cuda"` at `mata.load("track", ..., device="cuda")` | +| ONNX production deployment | Export your ReID model to ONNX and pass the `.onnx` path as `reid_model` | + +### Known Limitations (ReID) + +1. **BotSort only:** ReID is integrated into BotSort's `get_dists()` method. ByteTrack does not support appearance-distance matching — `reid_model` is silently ignored when `tracker="bytetrack"`. + +2. **No detection-level alignment:** ReID embeddings are computed for all detections that pass the confidence threshold, not only those that fail IoU association. For very dense scenes this may increase latency. 
Future work: skip ReID for IoU-matched detections (40–60% latency reduction). + +3. **Cross-camera ID namespace:** Each tracker process maintains an independent `STrack._count` — cross-camera track IDs are not globally unique. `ReIDBridge` resolves this at the application layer by storing `(camera_id, track_id)` pairs. + +4. **Embedding warm-up:** BotSort's `smooth_feat` is a running average that stabilises after ~5 frames. Track re-association quality may be lower for newly initialised tracks. + +5. **Valkey dependency for `ReIDBridge`:** `ReIDBridge` requires `pip install mata[valkey]` (or `mata[redis]`). If the server is unreachable, `publish()` / `query()` log a warning and return gracefully — tracking continues unaffected. \ No newline at end of file diff --git a/docs/VALKEY_GUIDE.md b/docs/VALKEY_GUIDE.md new file mode 100644 index 0000000..c0074c3 --- /dev/null +++ b/docs/VALKEY_GUIDE.md @@ -0,0 +1,791 @@ +# MATA Valkey/Redis Integration Guide + +**Version**: 1.9.2 +**Last Updated**: March 9, 2026 +**Status**: ✅ Production Ready + +--- + +## Table of Contents + +1. [Installation & Setup](#1-installation--setup) +2. [Basic Usage (save/load)](#2-basic-usage-saveload) +3. [Graph Pipeline Integration](#3-graph-pipeline-integration) +4. [YAML Configuration (Named Connections)](#4-yaml-configuration-named-connections) +5. [Streaming Patterns (Per-Frame Tracking)](#5-streaming-patterns-per-frame-tracking) +6. [Pub/Sub Event-Driven Architecture](#6-pubsub-event-driven-architecture) +7. [Security](#7-security) +8. [Performance Tuning](#8-performance-tuning) +9. [Troubleshooting](#9-troubleshooting) + +--- + +## 1. Installation & Setup + +### Install the client library + +MATA supports [Valkey](https://valkey.io/) (the open-source Redis fork) and Redis via optional extras: + +```bash +pip install mata[valkey] # valkey-py >= 6.0.0 (recommended) +pip install mata[redis] # redis-py >= 5.0.0 (alternative, wire-compatible) +``` + +Both clients are optional — `import mata` succeeds without either installed. An `ImportError` with an actionable message is raised only when a storage operation is actually executed. + +### Start a local Valkey server + +The quickest way to get a local server running is with Docker: + +```bash +docker run -d --name valkey-server -p 6379:6379 valkey/valkey:latest +``` + +Or install natively: + +```bash +# macOS +brew install valkey + +# Ubuntu / Debian +apt-get install valkey-server +``` + +Verify connectivity: + +```bash +valkey-cli ping # → PONG +# or +redis-cli ping # works against Valkey too +``` + +### Supported URI schemes + +| Scheme | Client | Notes | +| ----------------------------------- | --------- | ------------------------ | +| `valkey://host:port/key` | valkey-py | Recommended | +| `valkey://host:port/0/key` | valkey-py | With DB number | +| `redis://host:port/key` | redis-py | Wire-compatible fallback | +| `redis://user:pass@host:port/0/key` | redis-py | With credentials | +| `rediss://host:port/key` | redis-py | TLS-encrypted connection | + +--- + +## 2. 
Basic Usage (save/load) + +### Save any result to Valkey + +All MATA result types (`VisionResult`, `ClassifyResult`, `DepthResult`, `OCRResult`, `DetectResult`, `SegmentResult`) support a `valkey://` URI directly in their `save()` method: + +```python +import mata + +# Run detection +result = mata.run("detect", "photo.jpg", model="PekingU/rtdetr_r18vd", threshold=0.4) + +# Save to Valkey — same API as saving to a file +result.save("valkey://localhost:6379/detections:frame_001") + +# With a TTL (expires after 5 minutes) +result.save("valkey://localhost:6379/detections:latest", ttl=300) + +# With a DB number +result.save("valkey://localhost:6379/1/detections:frame_001") +``` + +The same pattern works for all tasks: + +```python +depth = mata.run("depth", "scene.jpg", model="depth-anything/Depth-Anything-V2-Small-hf") +depth.save("valkey://localhost:6379/depth:latest") + +classes = mata.run("classify", "cat.jpg", model="microsoft/resnet-50") +classes.save("valkey://localhost:6379/classify:latest") + +text = mata.run("ocr", "scan.png", model="ucaslcl/GOT-OCR2_0") +text.save("valkey://localhost:6379/ocr:document_001") +``` + +### Load a result back + +```python +from mata.core.exporters import load_valkey + +# Auto-detect result type from stored data +result = load_valkey(url="valkey://localhost:6379", key="detections:frame_001") +print(type(result)) # + +# Explicit result type (faster, skips auto-detection) +result = load_valkey( + url="valkey://localhost:6379", + key="detections:frame_001", + result_type="vision", # "vision", "classify", "depth", "ocr" +) +``` + +### Low-level exporter API + +For more control, use the exporter functions directly: + +```python +from mata.core.exporters import export_valkey, load_valkey + +# Export with all options +export_valkey( + result=detection_result, + url="valkey://localhost:6379", + key="pipeline:output", + ttl=3600, # 1 hour TTL + serializer="json", # "json" (default) or "msgpack" +) + +# Load back +loaded = load_valkey(url="valkey://localhost:6379", key="pipeline:output") +``` + +### Round-trip example + +```python +import mata +from mata.core.exporters import export_valkey, load_valkey + +# Step 1: run inference +result = mata.run("detect", "image.jpg", model="PekingU/rtdetr_r18vd") + +# Step 2: persist +export_valkey(result, url="valkey://localhost:6379", key="my:result", ttl=600) + +# Step 3: load in another process / service +loaded = load_valkey(url="valkey://localhost:6379", key="my:result") + +# Results are equivalent +assert len(loaded.instances) == len(result.instances) +``` + +--- + +## 3. Graph Pipeline Integration + +### `ValkeyStore` — sink node + +`ValkeyStore` writes an artifact to Valkey during graph execution and passes it through unchanged, so downstream nodes can still consume it. 
+ +```python +import mata +from mata.nodes import Detect, Filter, ValkeyStore +from mata.core.graph import Graph + +detector = mata.load("detect", "PekingU/rtdetr_r18vd") + +graph = ( + Graph() + .then(Detect(using="detr", out="dets")) + .then(Filter(src="dets", score_gt=0.4, out="filtered")) + .then(ValkeyStore( + src="filtered", + url="valkey://localhost:6379", + key="pipeline:detections:{timestamp}", # {timestamp} is Unix epoch + ttl=3600, + )) + # Downstream nodes still see "filtered" — ValkeyStore is a pass-through +) + +result = mata.infer("frame.jpg", graph=graph, providers={"detr": detector}) +print(result.filtered) # still accessible after store +``` + +**Key template placeholders:** + +| Placeholder | Resolved value | +| ------------- | ----------------------------------------- | +| `{node}` | Node's `name` attribute (`"ValkeyStore"`) | +| `{timestamp}` | Unix epoch (integer seconds at run time) | + +Only these two placeholders are supported — user data is never interpolated. + +### `ValkeyLoad` — source node + +`ValkeyLoad` loads a stored result from Valkey and injects it as the first artifact in a graph. Use this to build pipelines that consume results produced by another service. + +```python +from mata.nodes import ValkeyLoad, Filter, Fuse +from mata.core.graph import Graph + +graph = ( + Graph() + .then(ValkeyLoad( + url="valkey://localhost:6379", + key="upstream:detections:latest", + result_type="vision", # "auto" also works + out="dets", + )) + .then(Filter(src="dets", score_gt=0.6, out="hi_conf")) + .then(Fuse(detections="hi_conf", out="annotated")) +) + +result = mata.infer("frame.jpg", graph=graph, providers={}) +``` + +### Complete cross-pipeline example + +This pattern enables two independent services to share detection results through Valkey: + +```python +import mata +from mata.nodes import Detect, Filter, ValkeyStore, ValkeyLoad, Fuse +from mata.core.graph import Graph + +detector = mata.load("detect", "PekingU/rtdetr_r18vd") + +# ── Service A: Camera ingestion ────────────────────────────────────────── +store_graph = ( + Graph() + .then(Detect(using="detr", out="dets")) + .then(Filter(src="dets", score_gt=0.3, out="filtered")) + .then(ValkeyStore( + src="filtered", + url="valkey://prod-cluster:6379", + key="cam01:detections:latest", + ttl=10, # fresh for 10 s; overwritten each frame + )) +) + +for frame in camera_frames(): + mata.infer(frame, graph=store_graph, providers={"detr": detector}) + + +# ── Service B: Downstream analytics (separate process / machine) ───────── +load_graph = ( + Graph() + .then(ValkeyLoad( + url="valkey://prod-cluster:6379", + key="cam01:detections:latest", + out="dets", + )) + .then(Filter(src="dets", score_gt=0.7, out="hi_conf")) + .then(Fuse(detections="hi_conf", out="annotated")) +) + +annotated = mata.infer(latest_frame, graph=load_graph, providers={}) +``` + +--- + +## 4. 
YAML Configuration (Named Connections) + +Instead of hard-coding URLs in your code, define named connection profiles in `.mata/models.yaml` (project-local) or `~/.mata/models.yaml` (user-global): + +```yaml +# .mata/models.yaml +models: + detect: + rtdetr-fast: + source: "PekingU/rtdetr_r18vd" + threshold: 0.4 + +# Storage section — new in v1.9.0 +storage: + valkey: + default: + url: "valkey://localhost:6379" + db: 0 + ttl: 3600 + + staging: + url: "valkey://staging-host:6379" + db: 1 + ttl: 600 + + production: + url: "valkey://prod-cluster:6379" + password_env: "VALKEY_PASSWORD" # ← resolved from environment variable + db: 0 + tls: true + ttl: 86400 +``` + +### Retrieve a connection profile + +```python +from mata.core.model_registry import ModelRegistry + +registry = ModelRegistry() + +# Get the default connection +conn = registry.get_valkey_connection() # name="default" + +# Get a named connection +conn = registry.get_valkey_connection("production") + +# conn is a plain dict — pass to export_valkey as **kwargs +# { "url": "valkey://prod-cluster:6379", "password": "", "tls": True, "ttl": 86400 } +``` + +### Password management with `password_env` + +**Never store passwords in YAML.** Use the `password_env` key to reference an environment variable: + +```yaml +production: + url: "valkey://prod-cluster:6379" + password_env: "VALKEY_PASSWORD" # the env-var NAME, not the value +``` + +At runtime, `ModelRegistry.get_valkey_connection()` resolves `os.environ["VALKEY_PASSWORD"]` and replaces `password_env` with `password` in the returned dict. If the variable is not set, the `password` key is omitted entirely (no error). + +```bash +# Set before running your application +export VALKEY_PASSWORD="s3cr3tP@ssw0rd" +``` + +The env-var name itself is never logged or returned — only the resolved password is passed to the client. + +--- + +## 5. 
Streaming Patterns (Per-Frame Tracking) + +For real-time video pipelines, write each frame's tracking results to Valkey with a rolling TTL so the key always holds the latest state: + +```python +import mata +import time + +tracker = mata.load("track", "PekingU/rtdetr_r18vd", tracker="botsort") + +VALKEY_URL = "valkey://localhost:6379" +ROLLING_KEY = "track:cam01:latest" +TTL = 5 # seconds; overwritten each frame, auto-expires if feed drops + +cap = cv2.VideoCapture("rtsp://camera/stream") +while True: + ret, frame = cap.read() + if not ret: + break + + result = tracker.update(frame, persist=True) + + # Overwrite the rolling key — downstream consumers always read "latest" + result.save(f"{VALKEY_URL}/{ROLLING_KEY}", ttl=TTL) + +cap.release() +``` + +### Per-frame keyed history + +For audit trails, write each frame to a unique key using the frame index or timestamp: + +```python +for frame_idx, frame in enumerate(video_frames("recording.mp4")): + result = mata.run("detect", frame, model="PekingU/rtdetr_r18vd") + result.save(f"valkey://localhost:6379/recording:frame:{frame_idx:06d}", ttl=3600) +``` + +Retrieve a specific frame later: + +```python +from mata.core.exporters import load_valkey + +frame_result = load_valkey( + url="valkey://localhost:6379", + key="recording:frame:000042", +) +``` + +### Using `ValkeyStore` in `mata.track()` stream mode + +```python +import mata +from mata.nodes import ValkeyStore +from mata.core.graph import Graph + +# Build a small post-processing graph around each tracking result +store_node = ValkeyStore( + src="tracked", + url="valkey://localhost:6379", + key="track:{timestamp}", + ttl=30, +) + +for result in mata.track( + "rtsp://cam/stream", + model="PekingU/rtdetr_r18vd", + tracker="bytetrack", + stream=True, +): + # Export directly from the result object in stream mode + result.save("valkey://localhost:6379/track:stream:latest", ttl=5) +``` + +--- + +## 6. Pub/Sub Event-Driven Architecture + +Use `publish_valkey()` to broadcast results to real-time subscribers without holding state. Messages are fire-and-forget: if no subscriber is listening, the message is dropped. 
+ +```python +from mata.core.exporters import publish_valkey + +result = mata.run("detect", "frame.jpg", model="PekingU/rtdetr_r18vd") + +n = publish_valkey( + result=result, + url="valkey://localhost:6379", + channel="detections:stream", # Pub/Sub channel name +) +print(f"Delivered to {n} subscriber(s)") +``` + +### Subscriber (separate process) + +```python +import valkey +import json + +client = valkey.from_url("valkey://localhost:6379") +pubsub = client.pubsub() +pubsub.subscribe("detections:stream") + +for message in pubsub.listen(): + if message["type"] != "message": + continue + + data = json.loads(message["data"]) + # Reconstruct the result object if needed + from mata.core.exporters.valkey_exporter import _deserialize_result, _detect_result_type + result_type = _detect_result_type(data) + result = _deserialize_result(data, result_type) + process(result) +``` + +### Event-driven pipeline with Pub/Sub and ValkeyStore combined + +For patterns where you need both real-time notifications **and** persistent storage: + +```python +from mata.core.exporters import export_valkey, publish_valkey + +result = mata.run("detect", "frame.jpg", model="PekingU/rtdetr_r18vd") + +# Persist for later retrieval (with TTL) +export_valkey(result, url="valkey://localhost:6379", key="det:latest", ttl=60) + +# Broadcast to live subscribers +publish_valkey(result, url="valkey://localhost:6379", channel="det:events") +``` + +**Channel naming guidelines:** + +- Use hierarchical names separated by `:` — e.g., `cam01:detections`, `pipeline:alerts` +- Do not interpolate user-controlled strings directly into channel names without sanitization +- Pattern subscriptions (`PSUBSCRIBE cam*:detections`) work natively with Valkey pub/sub + +--- + +## 7. Security + +### TLS connections + +For production connections over untrusted networks, use TLS. With `redis-py` you can use the `rediss://` scheme (note the double `s`): + +```python +result.save("rediss://prod-host:6380/detections:latest") +``` + +With a named config: + +```yaml +production: + url: "valkey://prod-host:6380" + tls: true + password_env: "VALKEY_PASSWORD" +``` + +The `tls: true` flag is passed through to the client as `ssl=True` when calling `from_url()`. + +### Credentials — never log, never hard-code + +**Never put passwords in:** + +- Source code +- YAML config files (use `password_env` instead) +- Log messages + +MATA enforces this in the exporter layer: the raw connection URL is never passed to the logger. If the URL contains a password segment (e.g., `valkey://user:pass@host/key`), only the key name is logged, not the URL. 
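If your own application also logs connection targets, a small redaction helper (a sketch, not part of MATA) keeps the credential segment out of your log lines as well:

```python
from urllib.parse import urlparse

def redact_valkey_url(url: str) -> str:
    """Drop any user:pass@ segment before a URL reaches application logs."""
    p = urlparse(url)
    port = f":{p.port}" if p.port else ""
    return f"{p.scheme}://{p.hostname or ''}{port}{p.path}"

# redact_valkey_url("valkey://admin:secret@host:6379/0/key") -> "valkey://host:6379/0/key"
```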
+ +Bad practice: + +```python +# ❌ Hard-coded credentials in source +result.save("valkey://admin:secret@host:6379/key") +``` + +Correct practice: + +```python +# ✅ Credentials from environment variable via named connection +import os +from mata.core.model_registry import ModelRegistry + +registry = ModelRegistry() +conn = registry.get_valkey_connection("production") +# The password came from os.environ["VALKEY_PASSWORD"] — never written in code +result.save(conn["url"] + "/my_key", password=conn.get("password")) +``` + +### SSRF prevention + +When your application exposes an API that accepts Valkey URIs from users, validate them before use: + +```python +import ipaddress +from urllib.parse import urlparse + +ALLOWED_VALKEY_HOSTS = {"valkey-internal", "localhost", "127.0.0.1"} + +def safe_valkey_uri(uri: str) -> str: + """Validate that a Valkey URI points to an allowed host.""" + parsed = urlparse(uri) + host = parsed.hostname or "" + + # Reject private/loopback IPs if they're not explicitly allowed + try: + addr = ipaddress.ip_address(host) + if addr.is_private and host not in ALLOWED_VALKEY_HOSTS: + raise ValueError(f"Disallowed Valkey host: {host!r}") + except ValueError: + pass # not an IP address — proceed + + if host not in ALLOWED_VALKEY_HOSTS: + raise ValueError(f"Valkey host '{host}' not in allowlist") + + return uri +``` + +Never pass externally-supplied URLs to `export_valkey()` or `load_valkey()` without validation. + +### Key name sanitization + +Key names derived from user input should be sanitized to prevent overwriting critical keys: + +```python +import re + +def safe_key(user_input: str, prefix: str = "user") -> str: + """Sanitize a user-supplied string for use as a Valkey key segment.""" + # Allow only alphanumeric, hyphens, underscores + clean = re.sub(r"[^a-zA-Z0-9_\-]", "_", user_input) + return f"{prefix}:{clean}" +``` + +--- + +## 8. Performance Tuning + +### Serializer choice: JSON vs msgpack + +| | `json` (default) | `msgpack` | +| ----------- | -------------------------- | ----------------------------------------- | +| Format | UTF-8 text | binary | +| Size | larger (Base64 for arrays) | ~30–60% smaller for numeric arrays | +| Speed | fast | faster for large payloads | +| Dependency | stdlib | `pip install msgpack` | +| Readability | human-readable | binary (use `msgpack.unpackb` to inspect) | + +Use `msgpack` for high-throughput tracking pipelines with dense bounding-box arrays: + +```python +export_valkey( + result, + url="valkey://localhost:6379", + key="track:latest", + serializer="msgpack", + ttl=10, +) +``` + +> **Note:** `DepthResult` stores a `(H, W)` float array that can be several MB as JSON. Use `msgpack` or consider downsampling the depth map before storing. + +### TTL strategies + +| Scenario | Recommended TTL | +| ----------------------------- | --------------------------------------------------------------- | +| Rolling latest frame | 5–30 s (overwritten each frame) | +| Frame history / audit trail | 1–24 h (depends on retention needs) | +| Cross-pipeline handoff | 60–300 s (long enough for downstream to pick up) | +| Development / debugging | `None` (inspect keys with `valkey-cli`) | +| Production with memory budget | 10–60 min + set `maxmemory-policy allkeys-lru` in Valkey config | + +### Connection pooling + +`valkey-py` and `redis-py` both maintain an internal connection pool. 
If you're calling `export_valkey()` in a tight loop, reuse the client instead of reconnecting each call:
+
+```python
+import valkey
+
+client = valkey.from_url("valkey://localhost:6379", max_connections=10)
+
+for i, result in enumerate(results):
+    data = result.to_json()
+    client.setex(f"frame:{i}", 30, data)
+```
+
+For the public API, you can pass a pre-built client via `**kwargs` if the exporter supports it, or use the node-level graph integration which reuses the connection within a single graph execution.
+
+### Async patterns
+
+MATA's core exporters are synchronous. For async use cases (e.g., FastAPI), run exports in a thread pool:
+
+```python
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from mata.core.exporters import export_valkey
+
+executor = ThreadPoolExecutor(max_workers=4)
+
+async def async_export(result, url, key, ttl=None):
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(
+        executor,
+        lambda: export_valkey(result, url=url, key=key, ttl=ttl),
+    )
+```
+
+---
+
+## 9. Troubleshooting
+
+### Issue 1: `ImportError` — no client installed
+
+**Symptom:**
+
+```
+ImportError: Valkey export requires 'valkey' or 'redis' package.
+Install with: pip install mata[valkey] or pip install mata[redis]
+```
+
+**Solution:**
+
+```bash
+pip install mata[valkey]  # or pip install mata[redis]
+```
+
+This error only occurs when a storage operation is actually called — `import mata` succeeds without either package.
+
+---
+
+### Issue 2: `ConnectionError` — server unreachable
+
+**Symptom:**
+
+```
+valkey.exceptions.ConnectionError: Error 111 connecting to localhost:6379. Connection refused.
+```
+
+**Cause:** No Valkey/Redis server is running, or the host/port is wrong.
+
+**Solution:**
+
+```bash
+# Start a local server with Docker
+docker run -d --name valkey-server -p 6379:6379 valkey/valkey:latest
+
+# Verify it's reachable
+valkey-cli -h localhost -p 6379 ping  # → PONG
+```
+
+Check that firewall rules allow the port if connecting to a remote host.
+
+---
+
+### Issue 3: `KeyError` — key not found on load
+
+**Symptom:**
+
+```
+KeyError: "Valkey key 'detections:frame_001' not found"
+```
+
+**Possible causes:**
+
+- The TTL expired before you loaded the key
+- The key was written to a different DB number
+- A typo in the key name
+
+**Solution:**
+
+```bash
+# Inspect keys matching a pattern
+valkey-cli keys "detections:*"
+
+# Check if a specific key exists and its TTL
+valkey-cli exists detections:frame_001
+valkey-cli ttl detections:frame_001  # -1 = no TTL, -2 = does not exist
+
+# Check the DB number (default is 0)
+valkey-cli -n 1 keys "*"
+```
+
+---
+
+### Issue 4: `ValueError` — cannot auto-detect result type
+
+**Symptom:**
+
+```
+ValueError: Cannot auto-detect result type from keys: ['foo', 'bar'].
+Specify result_type explicitly.
+```
+
+**Cause:** The stored JSON does not have the expected top-level keys (`instances`, `predictions`, `depth`, `regions`). This can happen if a non-MATA value was stored under the same key.
+
+**Solution:**
+
+```bash
+# Inspect the raw value
+valkey-cli get my_key | python -c "import sys, json; print(json.dumps(json.loads(sys.stdin.read()), indent=2))"
+```
+
+If the data is valid MATA JSON but from an older format, specify the type explicitly:
+
+```python
+result = load_valkey(url=URL, key="my_key", result_type="vision")
+```
+
+---
+
+### Issue 5: Large memory usage / OOM
+
+**Symptom:** Valkey server memory grows without bound; keys are never evicted. 
+ +**Cause:** TTL was not set, or the eviction policy is `noeviction` (default). + +**Solution:** + +Set a `maxmemory` limit and an eviction policy in your Valkey config (`valkey.conf`): + +``` +maxmemory 512mb +maxmemory-policy allkeys-lru +``` + +Or configure via `valkey-cli`: + +```bash +valkey-cli config set maxmemory 512mb +valkey-cli config set maxmemory-policy allkeys-lru +``` + +Always set a TTL on keys written by high-frequency pipelines: + +```python +result.save("valkey://localhost:6379/track:latest", ttl=30) +``` + +--- + +## See Also + +- [Graph API Reference — Storage Nodes](GRAPH_API_REFERENCE.md#storage-nodes) — full parameter reference for `ValkeyStore` and `ValkeyLoad` +- [QUICK_REFERENCE.md — Valkey section](../QUICK_REFERENCE.md#️-valkeyredis-storage-quick-reference-v19) — cheatsheet +- [Valkey official documentation](https://valkey.io/documentation/) +- [MATA Validation Guide](VALIDATION_GUIDE.md) diff --git a/examples/graph/README.md b/examples/graph/README.md index cf98aa0..73788ed 100644 --- a/examples/graph/README.md +++ b/examples/graph/README.md @@ -12,17 +12,18 @@ pip install -e ".[dev]" python examples/graph/simple_pipeline.py ``` -## Core Examples (5) +## Core Examples (6) These examples demonstrate the fundamental graph system capabilities: -| Example | Description | Key Features | -| ------------------------------------------- | -------------------------------------------- | ------------------------------------------------ | -| ✅ [simple_pipeline.py](simple_pipeline.py) | Detection > Filter > Segmentation > Fuse | `mata.infer()`, `Graph.then()`, basic pipeline | -| [parallel_tasks.py](parallel_tasks.py) | Parallel detection + classification + depth | `Graph.parallel()`, `ParallelScheduler`, speedup | -| [video_tracking.py](video_tracking.py) | Video processing with object tracking | `VideoProcessor`, `Track`, frame policies | -| [vlm_workflows.py](vlm_workflows.py) | VLM grounded detection & scene understanding | `VLMDetect`, `PromoteEntities`, VLM presets | -| [presets_demo.py](presets_demo.py) | Using pre-built graph presets | `grounding_dino_sam()`, `full_scene_analysis()` | +| Example | Description | Key Features | +| ------------------------------------------------- | -------------------------------------------- | ------------------------------------------------------------------ | +| ✅ [simple_pipeline.py](simple_pipeline.py) | Detection > Filter > Segmentation > Fuse | `mata.infer()`, `Graph.then()`, basic pipeline | +| [parallel_tasks.py](parallel_tasks.py) | Parallel detection + classification + depth | `Graph.parallel()`, `ParallelScheduler`, speedup | +| [video_tracking.py](video_tracking.py) | Video processing with object tracking | `VideoProcessor`, `Track`, frame policies | +| [vlm_workflows.py](vlm_workflows.py) | VLM grounded detection & scene understanding | `VLMDetect`, `PromoteEntities`, VLM presets | +| [presets_demo.py](presets_demo.py) | Using pre-built graph presets | `grounding_dino_sam()`, `full_scene_analysis()` | +| [valkey_pipeline.py](valkey_pipeline.py) | Valkey/Redis result storage & Pub/Sub | `ValkeyStore`, `ValkeyLoad`, `publish_valkey`, rolling stream keys | > **Advanced patterns** (custom nodes, conditional logic, provider integration) are documented with > full context in the [Graph Cookbook](../../docs/GRAPH_COOKBOOK.md). 
diff --git a/examples/graph/valkey_pipeline.py b/examples/graph/valkey_pipeline.py new file mode 100644 index 0000000..8ece4d6 --- /dev/null +++ b/examples/graph/valkey_pipeline.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +"""Valkey / Redis result storage in graph pipelines. + +Demonstrates five patterns: + + 1. **Basic save/load** — `result.save("valkey://...")` and `load_valkey()` + 2. **`ValkeyStore` sink node** — store mid-pipeline and continue downstream + 3. **Cross-pipeline handoff** — Service A stores, Service B loads via `ValkeyLoad` + 4. **Pub/Sub** — broadcast detection events to live subscribers + 5. **Streaming rolling key** — per-frame overwrite with short TTL for live feeds + +For RTSP stream tracking with Valkey see ``examples/graph/rtsp_pipeline.py``. + +All examples run in **mock mode** by default (no model downloads, no real Valkey +server required). When running against a real server, pass ``--real``. + +Requirements: + pip install mata[valkey] # or mata[redis] + +Usage: + # Mock mode — fully self-contained, no server needed + python examples/graph/valkey_pipeline.py + + # Real mode — requires a Valkey/Redis server on localhost:6379 + python examples/graph/valkey_pipeline.py --real + + # Specify a different server + python examples/graph/valkey_pipeline.py --real --url valkey://myhost:6379 +""" + +from __future__ import annotations + +import sys + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _url_from_args(default: str = "valkey://localhost:6379") -> str: + """Return --url CLI override, or the default.""" + for i, arg in enumerate(sys.argv): + if arg == "--url" and i + 1 < len(sys.argv): + return sys.argv[i + 1] + return default + + +# --------------------------------------------------------------------------- +# Mock providers +# --------------------------------------------------------------------------- + +def create_mock_providers(): + """Create a mock detector that returns a small fixed VisionResult.""" + from unittest.mock import Mock + + from mata.core.types import Instance, VisionResult + + mock_detector = Mock() + mock_detector.predict = Mock(return_value=VisionResult( + instances=[ + Instance(bbox=(50, 30, 220, 300), label=0, score=0.91, label_name="person"), + Instance(bbox=(300, 80, 480, 340), label=2, score=0.76, label_name="car"), + Instance(bbox=(10, 10, 40, 40), label=3, score=0.12, label_name="noise"), + ], + meta={"model": "mock-detector"}, + )) + return {"detector": mock_detector} + + +def create_real_providers(): + """Load a real RT-DETR detector from HuggingFace.""" + import mata + + print("Loading PekingU/rtdetr_r18vd from HuggingFace (this may take a moment)...") + detector = mata.load("detect", "PekingU/rtdetr_r18vd") + return {"detector": detector} + + +# --------------------------------------------------------------------------- +# Mock Valkey client (used when --real is NOT passed) +# --------------------------------------------------------------------------- + +class _MockValkeyClient: + """In-memory stand-in so the example runs without a real server.""" + + def __init__(self): + self._store: dict[str, bytes] = {} + self._channels: dict[str, list] = {} + + # Key/value ops + def set(self, key: str, value, ex=None): + self._store[key] = value if isinstance(value, bytes) else value.encode() + + def setex(self, key: str, ttl: int, value): + self.set(key, value) + + def get(self, key: str): + return 
self._store.get(key) + + def exists(self, key: str) -> int: + return 1 if key in self._store else 0 + + # Pub/Sub + def publish(self, channel: str, message) -> int: + subscribers = self._channels.get(channel, []) + return len(subscribers) + + +# --------------------------------------------------------------------------- +# Shared inference helper +# --------------------------------------------------------------------------- + +IMAGE_PATH = "examples/images/000000039769.jpg" + + +def _detect(providers: dict) -> "mata.core.types.VisionResult": # type: ignore[name-defined] + """Run detection using the pre-loaded adapter in *providers*. + + Calls ``providers["detector"].predict()`` directly so the same code works + with both mock adapters and real HuggingFace adapters loaded by + ``create_real_providers()``. ``mata.run()`` is intentionally not used here + because it only accepts a model string and loads its own adapter internally. + """ + from PIL import Image as PILImage + image = PILImage.open(IMAGE_PATH).convert("RGB") + return providers["detector"].predict(image) + + +# Monkey-patch the exporter to use the mock client when not in --real mode +_MOCK_CLIENT = _MockValkeyClient() + +def _patch_exporter_with_mock(): + """Replace _get_valkey_client() with our in-memory mock.""" + try: + import mata.core.exporters.valkey_exporter as _ve + _ve._get_valkey_client = lambda url, **kw: _MOCK_CLIENT # type: ignore[attr-defined] + except Exception: + pass # exporter not available — skip patching + + +# --------------------------------------------------------------------------- +# Example 1 — Basic save / load +# --------------------------------------------------------------------------- + +def example_basic_save_load(url: str, providers: dict): + """result.save('valkey://...') and load_valkey() round-trip.""" + print("\n" + "=" * 60) + print("Example 1: Basic save / load") + print("=" * 60) + + from mata.core.exporters import export_valkey, load_valkey + + # Run detection using the pre-loaded adapter (mock or real) + result = _detect(providers) + + print(f"Detections before save: {len(result.instances)} objects") + for inst in result.instances: + print(f" {inst.label_name:10s} score={inst.score:.2f}") + + # --- save via result.save() (valkey:// URI scheme) --- + KEY = "example1:detections:latest" + result.save(f"{url}/{KEY}", ttl=300) + print(f"\nSaved to key '{KEY}' (TTL=300s)") + + # --- also via explicit export_valkey() with msgpack --- + export_valkey( + result, + url=url, + key="example1:detections:msgpack", + ttl=300, + serializer="json", # "msgpack" is faster for large payloads + ) + print("Saved second copy with export_valkey(serializer='json')") + + # --- load back --- + loaded = load_valkey(url=url, key=KEY) + print(f"\nLoaded from key '{KEY}': {len(loaded.instances)} objects " + f"(type={type(loaded).__name__})") + assert len(loaded.instances) == len(result.instances), "Round-trip mismatch!" 
+ print("✓ Round-trip verified") + + +# --------------------------------------------------------------------------- +# Example 2 — ValkeyStore sink node in a graph +# --------------------------------------------------------------------------- + +def example_valkey_store_node(url: str, providers: dict): + """ValkeyStore writes mid-pipeline while passing the artifact downstream.""" + print("\n" + "=" * 60) + print("Example 2: ValkeyStore node in a graph pipeline") + print("=" * 60) + + import mata + from mata.core.graph import Graph + from mata.nodes import Detect, Filter, Fuse, ValkeyStore + + graph = ( + Graph("detection_with_store") + .then(Detect(using="detector", out="raw_dets")) + .then(Filter(src="raw_dets", score_gt=0.5, out="filtered")) + # ── Persist filtered results; {timestamp} resolves to Unix epoch ── + .then(ValkeyStore( + src="filtered", + url=url, + key="pipeline:filtered:{timestamp}", + ttl=600, + )) + # ── Downstream nodes still receive "filtered" unchanged ────────── + .then(Fuse(detections="filtered", out="final")) + ) + + result = mata.infer( + image=IMAGE_PATH, + graph=graph, + providers=providers, + ) + + print(f"Graph channels: {list(result.channels.keys())}") + if result.has_channel("final"): + final = result.get_channel("final") + if final.has_channel("detections"): + dets = final.get_channel("detections") + print(f"Final detections (post-store): {len(dets.instances)} objects") + for inst in dets.instances: + print(f" {inst.label_name:10s} score={inst.score:.2f}") + print("✓ ValkeyStore did not interrupt the downstream pipeline") + + +# --------------------------------------------------------------------------- +# Example 3 — Cross-pipeline handoff: Store A → Load B +# --------------------------------------------------------------------------- + +def example_cross_pipeline(url: str, providers: dict): + """Service A stores; Service B loads via ValkeyLoad and continues.""" + print("\n" + "=" * 60) + print("Example 3: Cross-pipeline handoff (ValkeyStore → ValkeyLoad)") + print("=" * 60) + + import mata + from mata.core.graph import Graph + from mata.nodes import Detect, Filter, Fuse, ValkeyLoad, ValkeyStore + + # ── Service A: camera ingestion pipeline ───────────────────────────── + SHARED_KEY = "cross_pipeline:cam01:latest" + + store_graph = ( + Graph("service_a_ingest") + .then(Detect(using="detector", out="dets")) + .then(Filter(src="dets", score_gt=0.3, out="filtered")) + .then(ValkeyStore( + src="filtered", + url=url, + key=SHARED_KEY, + ttl=15, # stays fresh for 15 s before auto-expiry + )) + ) + + print("Service A: running detection and storing results...") + mata.infer( + image=IMAGE_PATH, + graph=store_graph, + providers=providers, + ) + print(f" Stored filtered detections under '{SHARED_KEY}'") + + # ── Service B: downstream analytics pipeline ───────────────────────── + load_graph = ( + Graph("service_b_analytics") + .then(ValkeyLoad( + url=url, + key=SHARED_KEY, + result_type="auto", # auto-detect VisionResult + out="loaded_dets", + )) + .then(Filter(src="loaded_dets", score_gt=0.7, out="hi_conf")) + .then(Fuse(detections="hi_conf", out="analytics")) + ) + + print("\nService B: loading stored results and applying hi-conf filter...") + result_b = mata.infer( + image=IMAGE_PATH, + graph=load_graph, + providers={}, # no detection model needed in Service B + ) + + print(f" Channels: {list(result_b.channels.keys())}") + if result_b.has_channel("analytics"): + anal = result_b.get_channel("analytics") + if anal.has_channel("detections"): + hi = 
anal.get_channel("detections") + print(f" Hi-confidence objects: {len(hi.instances)}") + print("✓ Service B successfully consumed Service A's stored results") + + +# --------------------------------------------------------------------------- +# Example 4 — Pub/Sub broadcast +# --------------------------------------------------------------------------- + +def example_pubsub(url: str, providers: dict): + """publish_valkey() broadcasts a result to all active channel subscribers.""" + print("\n" + "=" * 60) + print("Example 4: Pub/Sub event broadcast") + print("=" * 60) + + from mata.core.exporters import publish_valkey + + result = _detect(providers) + + CHANNEL = "detections:events:cam01" + n = publish_valkey( + result=result, + url=url, + channel=CHANNEL, + serializer="json", + ) + print(f"Published detection event to channel '{CHANNEL}'") + print(f"Delivered to {n} subscriber(s)") + print("(In production, start a subscriber before publishing — " + "Pub/Sub is fire-and-forget.)") + + # ── combined: persist + broadcast ──────────────────────────────────── + print("\nCombined persist + broadcast pattern:") + from mata.core.exporters import export_valkey + export_valkey(result, url=url, key="det:latest", ttl=60) + n = publish_valkey(result, url=url, channel="det:events") + print(f" Persisted to 'det:latest' (TTL=60s)") + print(f" Broadcast to 'det:events' → {n} subscriber(s)") + print("✓ Pub/Sub example complete") + + +# --------------------------------------------------------------------------- +# Example 5 — Streaming per-frame rolling key +# --------------------------------------------------------------------------- + +def example_streaming_rolling_key(url: str, providers: dict): + """Simulate per-frame tracking with a rolling Valkey key (TTL reset each frame).""" + print("\n" + "=" * 60) + print("Example 5: Streaming / per-frame rolling key") + print("=" * 60) + + from mata.core.exporters import export_valkey, load_valkey + + ROLLING_KEY = "stream:track:latest" + HISTORY_PREFIX = "stream:track:history" + FRAMES = 5 + + print(f"Simulating {FRAMES} frames...") + for frame_idx in range(FRAMES): + result = _detect(providers) + # Rolling latest — overwritten every frame, expires if feed drops + export_valkey(result, url=url, key=ROLLING_KEY, ttl=10) + # Indexed history for audit trail + export_valkey(result, url=url, key=f"{HISTORY_PREFIX}:{frame_idx:06d}", ttl=3600) + print(f" Frame {frame_idx:02d}: stored {len(result.instances)} detections") + + # Read back the latest + latest = load_valkey(url=url, key=ROLLING_KEY) + print(f"\nLatest key '{ROLLING_KEY}': {len(latest.instances)} detections") + + # Read a specific historical frame + frame_2 = load_valkey(url=url, key=f"{HISTORY_PREFIX}:000002") + print(f"History frame 2: {len(frame_2.instances)} detections") + print("✓ Streaming rolling-key pattern complete") + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +def main(): + use_real = "--real" in sys.argv + url = _url_from_args() + + if use_real: + print(f"Running against real Valkey server: {url}") + print("Make sure valkey-py is installed: pip install mata[valkey]") + providers = create_real_providers() + else: + print("Running in MOCK mode (in-memory store, no server required)") + print("Pass --real to run against a real Valkey server.") + _patch_exporter_with_mock() + providers = create_mock_providers() + + example_basic_save_load(url, providers) + 
example_valkey_store_node(url, providers)
+    example_cross_pipeline(url, providers)
+    example_pubsub(url, providers)
+    example_streaming_rolling_key(url, providers)
+
+    print("\n" + "=" * 60)
+    print("All Valkey examples completed successfully.")
+    print("See docs/VALKEY_GUIDE.md for the full reference.")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/graph/valkey_rtsp_pipeline.py b/examples/graph/valkey_rtsp_pipeline.py
new file mode 100644
index 0000000..2ea2fcc
--- /dev/null
+++ b/examples/graph/valkey_rtsp_pipeline.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+"""RTSP stream tracking with Valkey result storage.
+
+Demonstrates per-frame BotSort object tracking on an RTSP stream where every
+frame's results are:
+
+    - Stored as an indexed snapshot ``rtsp:cam01:tracks:<frame_idx>`` (TTL 1 h)
+    - Written to a rolling latest key ``rtsp:cam01:tracks:latest`` (TTL 15 s)
+    - Broadcast on Pub/Sub channel ``rtsp:cam01:track:events``
+
+A round-trip verification load is performed after the loop.
+
+Requirements:
+    pip install mata[valkey]          # Valkey / Redis client
+    pip install mata opencv-python    # RTSP frame capture
+
+Usage:
+    python examples/graph/valkey_rtsp_pipeline.py
+
+    # Override the Valkey URL or RTSP source
+    python examples/graph/valkey_rtsp_pipeline.py \\
+        --url valkey://myhost:6379 \\
+        --rtsp rtsp://example:example@192.168.1.10:8554/Streaming/Channels/102
+
+"""
+
+from __future__ import annotations
+
+import sys
+
+
+# ---------------------------------------------------------------------------
+# CLI helpers
+# ---------------------------------------------------------------------------
+
+def _arg(flag: str, default: str) -> str:
+    """Return the value after *flag* in sys.argv, or *default*."""
+    for i, arg in enumerate(sys.argv):
+        if arg == flag and i + 1 < len(sys.argv):
+            return sys.argv[i + 1]
+    return default
+
+
+# ---------------------------------------------------------------------------
+# Provider factory
+# ---------------------------------------------------------------------------
+
+def create_tracker(rtsp_url: str):
+    """Load a real RT-DETR detector and wrap it with a BotSort tracker."""
+    import mata
+    from mata.adapters.tracking_adapter import TrackingAdapter
+
+    print("Loading PekingU/rtdetr_r18vd from HuggingFace (this may take a moment)...")
+    detector = mata.load("detect", "PekingU/rtdetr_r18vd")
+    tracker = TrackingAdapter(detector, tracker_config="botsort")
+    print(f"BotSort tracker ready — RTSP source: {rtsp_url}")
+    return tracker
+
+
+# ---------------------------------------------------------------------------
+# Main example
+# ---------------------------------------------------------------------------
+
+def run(url: str, rtsp_url: str, tracker, frames: int = 8):
+    """Run the per-frame tracking + Valkey storage loop.
+
+    Args:
+        url: Valkey server URL (e.g. ``valkey://localhost:6379``).
+        rtsp_url: RTSP stream URL to open with cv2.VideoCapture.
+        tracker: :class:`~mata.adapters.tracking_adapter.TrackingAdapter`.
+        frames: Maximum number of frames to process (0 = unlimited). 
+ """ + import cv2 + from PIL import Image as PILImage + + from mata.core.exporters import export_valkey, load_valkey, publish_valkey + + TRACK_KEY_PREFIX = "rtsp:cam01:tracks" + LATEST_KEY = "rtsp:cam01:tracks:latest" + CHANNEL = "rtsp:cam01:track:events" + + print(f"\nOpening RTSP stream: {rtsp_url}") + cap = cv2.VideoCapture(rtsp_url) + if not cap.isOpened(): + raise RuntimeError(f"Could not open RTSP stream: {rtsp_url}") + + print(f"Processing up to {frames} frame(s)...") + print(f" Indexed snapshots : {TRACK_KEY_PREFIX}: (TTL 1 h)") + print(f" Rolling latest : {LATEST_KEY} (TTL 15 s)") + print(f" Pub/Sub channel : {CHANNEL}") + print() + + track_counts: list[int] = [] + frame_idx = 0 + + try: + while frames == 0 or frame_idx < frames: + ret, bgr_frame = cap.read() + if not ret: + print("Stream ended or frame read failed — stopping.") + break + + pil_frame = PILImage.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)) + result = tracker.update(pil_frame, persist=True) + + active = [inst for inst in result.instances if inst.track_id is not None] + track_counts.append(len(active)) + + # Per-frame indexed snapshot — long TTL for audit / replay + export_valkey( + result, + url=url, + key=f"{TRACK_KEY_PREFIX}:{frame_idx:06d}", + ttl=3600, + ) + # Rolling latest — short TTL so it auto-expires if the stream drops + export_valkey(result, url=url, key=LATEST_KEY, ttl=15) + + # Broadcast to any live dashboards or alert subscribers + n_subs = publish_valkey(result, url=url, channel=CHANNEL, serializer="json") + + ids = [str(inst.track_id) for inst in active] + print(f" Frame {frame_idx:02d}: {len(active)} track(s) IDs={ids}" + f" → {n_subs} subscriber(s) notified") + + frame_idx += 1 + finally: + cap.release() + + if frame_idx == 0: + print("No frames were processed.") + return + + # ── Round-trip verification ────────────────────────────────────────────── + print() + latest = load_valkey(url=url, key=LATEST_KEY) + active_latest = [i for i in latest.instances if i.track_id is not None] + print(f"Latest key '{LATEST_KEY}': {len(active_latest)} active track(s)") + for inst in active_latest: + print(f" Track #{inst.track_id} {inst.label_name:<10} " + f"score={inst.score:.2f} bbox={inst.bbox}") + + mid_idx = frame_idx // 2 + mid_key = f"{TRACK_KEY_PREFIX}:{mid_idx:06d}" + mid = load_valkey(url=url, key=mid_key) + mid_ids = [str(i.track_id) for i in mid.instances if i.track_id is not None] + print(f"Mid-stream frame [{mid_idx}] ('{mid_key}'): " + f"{len(mid.instances)} instance(s) IDs={mid_ids}") + + avg = sum(track_counts) / len(track_counts) if track_counts else 0 + print(f"Average active tracks/frame: {avg:.1f} over {frame_idx} frames") + print("✓ RTSP tracking + Valkey storage example complete") + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +def main(): + url = _arg("--url", "valkey://localhost:6379") + rtsp_url = _arg("--rtsp", "rtsp://example:example@192.168.1.100:8554/Streaming/Channels/102") + frames = int(_arg("--frames", "8")) + + print(f"Valkey server : {url}") + print(f"RTSP source : {rtsp_url}") + print("Make sure valkey-py is installed: pip install mata[valkey]") + + tracker = create_tracker(rtsp_url) + run(url=url, rtsp_url=rtsp_url, tracker=tracker, frames=frames) + + +if __name__ == "__main__": + main() diff --git a/examples/track/cross_camera_reid.py b/examples/track/cross_camera_reid.py new file mode 100644 index 0000000..d7379c1 --- /dev/null +++ 
b/examples/track/cross_camera_reid.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +"""Cross-camera object re-identification via Valkey (v1.9.2). + +Demonstrates how to use ``ReIDBridge`` so that independent tracker instances +running on different camera feeds can resolve the same physical identity across +feeds using shared Valkey embedding storage. + +Architecture: + Camera A tracker + └─ mata.track(..., reid_bridge=bridge_a) + └─ bridge_a.publish(track_id, embedding) → Valkey + Camera B tracker + └─ mata.track(..., reid_bridge=bridge_b) + └─ bridge_b.query(embedding) ← Valkey ← cam-a embeddings + +Features demonstrated: +- Constructing ``ReIDBridge`` with camera_id, TTL, and similarity_thresh +- Publishing embeddings per track per frame (automatic via reid_bridge kwarg) +- Querying for cross-camera nearest-identity matches +- Mocked Valkey client (no real server required for the basic demo) +- Real Valkey/Redis server usage notes + +Usage (mock mode — no Valkey server required): + python examples/track/cross_camera_reid.py + +Usage (real Valkey server): + python examples/track/cross_camera_reid.py --valkey valkey://localhost:6379 + +Requirements: + pip install mata[valkey] transformers torch + # or: + pip install mata[redis] transformers torch +""" +from __future__ import annotations + +import argparse +import time +from typing import Any + + +# --------------------------------------------------------------------------- +# Mock Valkey client (used when no --valkey URL is supplied) +# --------------------------------------------------------------------------- + +class _MockValkeyClient: + """In-memory key-value store that mimics the subset of Valkey API used by ReIDBridge.""" + + def __init__(self) -> None: + self._store: dict[bytes | str, bytes] = {} + self._ttls: dict[bytes | str, float] = {} + + def set(self, key: str, value: bytes, ex: int | None = None) -> None: + self._store[key] = value + if ex is not None: + self._ttls[key] = time.time() + ex + + def get(self, key: str) -> bytes | None: + if key in self._ttls and time.time() > self._ttls[key]: + self._store.pop(key, None) + self._ttls.pop(key, None) + return self._store.get(key) + + def delete(self, *keys: str) -> int: + count = 0 + for k in keys: + if k in self._store: + del self._store[k] + self._ttls.pop(k, None) + count += 1 + return count + + def scan_iter(self, match: str = "*", count: int = 100): + """Yield keys matching a simple glob pattern (only ``*`` wildcard).""" + import fnmatch + active = [ + k for k, exp in list(self._ttls.items()) + if time.time() <= exp + ] + [ + k for k in self._store if k not in self._ttls + ] + for key in active: + k_str = key.decode() if isinstance(key, bytes) else key + if fnmatch.fnmatch(k_str, match): + yield key + + def ping(self) -> bool: + return True + + +def _build_mock_bridge(camera_id: str, similarity_thresh: float = 0.25, + shared_store: _MockValkeyClient | None = None): + """Build a ReIDBridge backed by a mock in-memory client.""" + from unittest.mock import patch + from mata.trackers.reid_bridge import ReIDBridge + + client = shared_store or _MockValkeyClient() + bridge = ReIDBridge.__new__(ReIDBridge) + bridge._client = client + bridge._camera_id = camera_id + bridge._ttl = 300 + bridge._similarity_thresh = similarity_thresh + bridge._prefix = "reid" + return bridge + + +# --------------------------------------------------------------------------- +# Example 1 — Publish / Query round-trip (mock) +# --------------------------------------------------------------------------- + +def 
run_publish_query_demo() -> None: + """Demonstrate publish / query semantics with a mocked Valkey client.""" + import numpy as np + + print("\n=== Example 1: Publish → Query round-trip (mocked Valkey) ===\n") + + shared_store = _MockValkeyClient() + + bridge_cam_a = _build_mock_bridge("cam-a", shared_store=shared_store) + bridge_cam_b = _build_mock_bridge("cam-b", shared_store=shared_store) + + # Camera A observes person with track_id=7 + emb_a = np.random.randn(128).astype(np.float32) + emb_a /= np.linalg.norm(emb_a) + + bridge_cam_a.publish( + track_id=7, + embedding=emb_a, + bbox=(120.0, 50.0, 200.0, 280.0), + label=0, + ) + print(f" cam-a: published track #7 (embedding norm={np.linalg.norm(emb_a):.4f})") + + # Camera B receives a detection with very similar appearance + # (in real usage this would be from a different camera angle of the same person) + noise = np.random.randn(128).astype(np.float32) * 0.05 + emb_b_query = emb_a + noise + emb_b_query /= np.linalg.norm(emb_b_query) + + matches = bridge_cam_b.query(emb_b_query, exclude_camera="cam-b", top_k=3) + + if matches: + best = matches[0] + print( + f" cam-b: best cross-camera match → " + f"camera={best['camera_id']!r} track_id={best['track_id']} " + f"similarity={best['similarity']:.4f}" + ) + print(f" ✅ Cross-camera identity resolved: cam-b detection → cam-a track #7\n") + else: + print(" ⚠ No match found (similarity below threshold)\n") + + # Demonstrate no self-match + self_matches = bridge_cam_a.query(emb_a, exclude_camera="cam-a") + assert not self_matches, "Should not return own camera embeddings" + print(" ✅ Self-camera exclusion works (no same-camera matches)\n") + + +# --------------------------------------------------------------------------- +# Example 2 — Multi-camera tracking loop (mock) +# --------------------------------------------------------------------------- + +def run_multi_camera_loop(num_frames: int = 5) -> None: + """Simulate two parallel camera trackers publishing to a shared store.""" + import numpy as np + from mata.adapters.tracking_adapter import TrackingAdapter + from mata.core.types import Instance, VisionResult + + print("\n=== Example 2: Two-camera tracking loop with shared ReID store ===\n") + + shared_store = _MockValkeyClient() + bridge_a = _build_mock_bridge("cam-a", shared_store=shared_store) + bridge_b = _build_mock_bridge("cam-b", shared_store=shared_store) + + def _make_mock_adapter(camera_label: str, track_id: int, + base_emb: "np.ndarray", bridge) -> TrackingAdapter: + """Build a TrackingAdapter mock that auto-publishes to the bridge.""" + from unittest.mock import Mock + + call_count = {"n": 0} + + def mock_update(image, **kw): + n = call_count["n"] + call_count["n"] += 1 + # Slightly perturb embedding each frame (EMA simulation) + emb = base_emb + np.random.randn(128).astype(np.float32) * 0.03 + emb /= np.linalg.norm(emb) + inst = Instance( + bbox=(100.0 + n * 2, 50.0, 190.0 + n * 2, 280.0), + label=0, score=0.90, label_name="person", + track_id=track_id, embedding=emb, + ) + result = VisionResult(instances=[inst], meta={"frame_idx": n, "camera": camera_label}) + # Simulate the publish step that TrackingAdapter.update() performs when reid_bridge is set + for i in result.instances: + if i.track_id is not None and i.embedding is not None: + bridge.publish(i.track_id, i.embedding, bbox=i.bbox, label=i.label) + return result + + adapter = Mock() + adapter.update = mock_update + return adapter + + # Two cameras tracking the same physical person with different track IDs + base_person_emb = 
np.random.randn(128).astype(np.float32) + base_person_emb /= np.linalg.norm(base_person_emb) + + adapter_a = _make_mock_adapter("cam-a", track_id=3, base_emb=base_person_emb, bridge=bridge_a) + adapter_b = _make_mock_adapter("cam-b", track_id=11, base_emb=base_person_emb, bridge=bridge_b) + + dummy_frame = np.zeros((480, 640, 3), dtype=np.uint8) + + for frame_idx in range(num_frames): + r_a = adapter_a.update(dummy_frame) + r_b = adapter_b.update(dummy_frame) + + # Camera B queries cross-camera matches for its track #11 + if r_b.instances: + emb_b = r_b.instances[0].embedding + matches = bridge_b.query(emb_b, exclude_camera="cam-b", top_k=1) + match_str = ( + f"→ cam-a track #{matches[0]['track_id']} " + f"(sim={matches[0]['similarity']:.3f})" + if matches else "→ no match" + ) + print( + f" Frame {frame_idx} | cam-b track #11 " + f"(norm={np.linalg.norm(emb_b):.4f}) {match_str}" + ) + + print("\n ✅ Multi-camera ReID loop complete\n") + + +# --------------------------------------------------------------------------- +# Example 3 — Real Valkey server usage notes +# --------------------------------------------------------------------------- + +def print_real_valkey_notes(valkey_url: str) -> None: + """Print real-server usage notes (or attempt a live demo if server available).""" + print("\n=== Example 3: Real Valkey server usage ===\n") + + print(f" Connection URL: {valkey_url}\n") + + try: + from mata.trackers.reid_bridge import ReIDBridge + bridge_test = ReIDBridge(valkey_url, camera_id="test-cam", ttl=10) + bridge_test._client.ping() + print(" ✅ Valkey server reachable — live demo mode\n") + + import numpy as np + + emb = np.random.randn(128).astype(np.float32) + emb /= np.linalg.norm(emb) + bridge_test.publish(track_id=99, embedding=emb, bbox=(0., 0., 100., 200.), label=0) + print(" Published test embedding for track #99") + + matches = bridge_test.query(emb, exclude_camera="other-cam", top_k=1) + print(f" Self-query (include own camera): would find {len(matches)} match(es)") + + count = bridge_test.clear(camera_id="test-cam") + print(f" Cleared {count} test key(s)") + print("\n ✅ Live Valkey round-trip successful\n") + + except Exception as exc: + print(f" ⚠ Could not connect to Valkey ({exc})") + print(" Running as notes-only mode\n") + + print(" Full ReIDBridge API reference:\n") + print(" from mata.trackers import ReIDBridge\n") + print(" bridge = ReIDBridge(") + print(' "valkey://localhost:6379",') + print(' camera_id="cam-front",') + print(" ttl=300, # embeddings expire after 5 minutes") + print(" similarity_thresh=0.25, # cosine similarity cutoff") + print(" )\n") + print(" # In your tracking loop (automatic when passing reid_bridge to mata.track()):") + print(" bridge.publish(track_id=42, embedding=emb, bbox=(x1,y1,x2,y2), label=0)\n") + print(" # Query from a different camera process:") + print(' matches = bridge.query(query_emb, exclude_camera="cam-front", top_k=1)') + print(" # matches = [{'track_id': 42, 'camera_id': 'cam-front', 'similarity': 0.87, ...}]\n") + print(" # Attach bridge to mata.track() for automatic publishing:") + print(" for result in mata.track(") + print(' "rtsp://cam-front/stream",') + print(' model="facebook/detr-resnet-50",') + print(' reid_model="openai/clip-vit-base-patch32",') + print(" reid_bridge=bridge,") + print(" stream=True,") + print(" ):") + print(" ...") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def _parse_args() 
-> argparse.Namespace: + p = argparse.ArgumentParser( + description="MATA cross-camera ReID via Valkey example (v1.9.2)" + ) + p.add_argument( + "--valkey", metavar="URL", + help="Valkey/Redis URL for live demo (e.g. valkey://localhost:6379). " + "Omit to run with mocked client.", + ) + return p.parse_args() + + +def main() -> None: + args = _parse_args() + + run_publish_query_demo() + run_multi_camera_loop() + print_real_valkey_notes(args.valkey or "valkey://localhost:6379") + + print("=" * 60) + print("Done.") + print("See examples/track/reid_tracking.py for single-camera ReID usage.") + + +if __name__ == "__main__": + main() diff --git a/examples/track/reid_tracking.py b/examples/track/reid_tracking.py new file mode 100644 index 0000000..fbd0e21 --- /dev/null +++ b/examples/track/reid_tracking.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +"""Single-camera object tracking with appearance-based ReID (v1.9.2). + +Demonstrates how to enable ReID in mata.track() to improve track-ID recovery +after occlusion or target re-entry — using BotSort's appearance-distance branch. + +Features demonstrated: +- ``mata.track()`` with ``reid_model`` kwarg +- Inspecting ``Instance.embedding`` vectors +- Low-level persistent tracking via ``TrackingAdapter`` with ``reid_encoder`` +- ReID config alias in ``.mata/models.yaml`` + +Usage (mock mode — no GPU or real models required): + python examples/track/reid_tracking.py + +Usage (real video + model): + python examples/track/reid_tracking.py --real examples/videos/cup.mp4 + +Requirements: + pip install mata transformers torch +""" +from __future__ import annotations + +import argparse + + +# --------------------------------------------------------------------------- +# Mock helpers (used when --real is not supplied) +# --------------------------------------------------------------------------- + +def _make_mock_detector(): + """Return a minimal detector mock that produces synthetic detections.""" + from unittest.mock import Mock + from mata.core.types import Instance, VisionResult + + frame_count = {"n": 0} + _LABEL_NAMES = {0: "person", 2: "car"} + + def mock_predict(image, **kwargs): + n = frame_count["n"] + frame_count["n"] += 1 + x = 80 + n * 3 # person drifts right each frame + return VisionResult( + instances=[ + Instance(bbox=(x, 50, x + 90, 290), label=0, + score=0.92, label_name="person"), + Instance(bbox=(350, 110, 510, 270), label=2, + score=0.85, label_name="car"), + ], + meta={"frame_idx": n}, + ) + + det = Mock() + det.predict = mock_predict + # Provide id2label so TrackingAdapter can resolve label names + det.id2label = _LABEL_NAMES + return det + + +def _make_mock_reid_encoder(embedding_dim: int = 128): + """Return a mock ReID encoder that produces random unit-norm embeddings.""" + import numpy as np + from unittest.mock import Mock + + def mock_predict(crops): + if not crops: + return np.empty((0, 0), dtype=np.float32) + n = len(crops) + raw = np.random.randn(n, embedding_dim).astype(np.float32) + norms = np.linalg.norm(raw, axis=1, keepdims=True) + return raw / np.where(norms == 0, 1.0, norms) + + encoder = Mock() + encoder.predict = mock_predict + return encoder + + +# --------------------------------------------------------------------------- +# Example 1 — mata.track() one-liner with reid_model +# --------------------------------------------------------------------------- + +def run_one_liner(video_path: str, *, real: bool = False) -> None: + """Show mata.track() one-liner API with reid_model.""" + print("\n=== Example 1: mata.track() 
one-liner with ReID ===\n") + + if not real: + import numpy as np + from mata.core.types import Instance, VisionResult + + print(" [mock] mata.track() would be called as:\n") + print(" results = mata.track(") + print(' "video.mp4",') + print(' model="facebook/detr-resnet-50",') + print(' tracker="botsort",') + print(' reid_model="openai/clip-vit-base-patch32",') + print(" conf=0.3,") + print(" save=False,") + print(" )\n") + print(" [mock] Simulating 10 frames of tracked objects...\n") + + for frame_idx in range(10): + emb = np.random.randn(128).astype(np.float32) + emb /= np.linalg.norm(emb) + result = VisionResult( + instances=[ + Instance( + bbox=(80 + frame_idx * 3, 50, 170 + frame_idx * 3, 290), + label=0, score=0.92, label_name="person", + track_id=1, embedding=emb, + ), + ], + meta={"frame_idx": frame_idx}, + ) + for inst in result.instances: + emb_str = ( + f"shape=({inst.embedding.shape[0]},) " + f"norm={np.linalg.norm(inst.embedding):.4f}" + if inst.embedding is not None else "None" + ) + print( + f" Frame {frame_idx:02d} | Track #{inst.track_id} " + f"{inst.label_name:<8} score={inst.score:.2f} | " + f"embedding {emb_str}" + ) + print("\n ✅ ReID embeddings populated in Instance.embedding\n") + return + + # Real mode — downloads model on first run + import mata + import numpy as np + + results = mata.track( + video_path, + model="facebook/detr-resnet-50", + tracker="botsort", + reid_model="openai/clip-vit-base-patch32", + conf=0.3, + save=False, + ) + + for frame_idx, result in enumerate(results): + for inst in result.instances: + emb_info = "" + if inst.embedding is not None: + emb_info = ( + f"embedding shape=({inst.embedding.shape[0]},) " + f"norm={np.linalg.norm(inst.embedding):.4f}" + ) + print( + f" Frame {frame_idx:02d} | Track #{inst.track_id} " + f"{inst.label_name:<10} score={inst.score:.2f} | {emb_info}" + ) + + +# --------------------------------------------------------------------------- +# Example 2 — Low-level TrackingAdapter with reid_encoder +# --------------------------------------------------------------------------- + +def run_low_level(num_frames: int = 5) -> None: + """Show low-level TrackingAdapter with reid_encoder.""" + import numpy as np + from mata.adapters.tracking_adapter import TrackingAdapter + + print("\n=== Example 2: Low-level TrackingAdapter with reid_encoder ===\n") + + mock_detector = _make_mock_detector() + mock_encoder = _make_mock_reid_encoder(embedding_dim=128) + + adapter = TrackingAdapter( + mock_detector, + tracker_config={"tracker_type": "botsort"}, + frame_rate=25, + reid_encoder=mock_encoder, + ) + print(f" TrackingAdapter created. reid_encoder set: {adapter._reid_encoder is not None}\n") + + for frame_idx in range(num_frames): + frame = np.zeros((480, 640, 3), dtype=np.uint8) # blank synthetic frame + result = adapter.update(frame) + + for inst in result.instances: + emb_str = "None" + if inst.embedding is not None: + emb_str = ( + f"shape=({inst.embedding.shape[0]},) " + f"norm={np.linalg.norm(inst.embedding):.4f}" + ) + label = str(inst.label_name) if inst.label_name is not None else "?" 
+ print( + f" Frame {frame_idx} | Track #{inst.track_id} " + f"{label:<8} | embedding {emb_str}" + ) + + print("\n ✅ Low-level ReID tracking complete\n") + + +# --------------------------------------------------------------------------- +# Example 3 — YAML config alias with reid_model +# --------------------------------------------------------------------------- + +def print_config_example() -> None: + """Print example .mata/models.yaml config for ReID-enabled tracking.""" + print("\n=== Example 3: Config alias with reid_model ===\n") + print(" Place this in .mata/models.yaml:\n") + config = """\ + models: + track: + smart-cam: + source: "facebook/detr-resnet-50" + tracker: botsort + reid_model: "openai/clip-vit-base-patch32" + frame_rate: 30 + tracker_config: + track_high_thresh: 0.6 + appearance_thresh: 0.25 + track_buffer: 60 +""" + print(config) + print(" Then load with a single call:\n") + print(' import mata') + print(' tracker = mata.load("track", "smart-cam") # ReID loaded automatically') + print(" result = tracker.update(frame)\n") + print(" ✅ Config alias with reid_model demonstrated\n") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="MATA ReID tracking example (v1.9.2)") + p.add_argument( + "--real", metavar="VIDEO", + help="Path to a real video file (downloads model on first run)", + ) + return p.parse_args() + + +def main() -> None: + args = _parse_args() + real = bool(args.real) + video = args.real or "video.mp4" + + run_one_liner(video, real=real) + run_low_level() + print_config_example() + + print("=" * 60) + print("Done.") + print("See examples/track/cross_camera_reid.py for Valkey cross-camera ReID.") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index afbfdb1..3cda02e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mata" -version = "1.9.0" +version = "1.9.2b1" description = "Model-Agnostic Task Architecture - A task-centric, model-agnostic framework for computer vision" readme = "README.md" requires-python = ">=3.10" @@ -103,13 +103,23 @@ ocr-all = [ "pytesseract>=0.3.10", ] +# Valkey/Redis result storage +valkey = [ + "valkey>=6.0.0", +] + +# Redis fallback (wire-compatible with Valkey) +redis = [ + "redis>=5.0.0", +] + # Full installation (all optional features except GPU ONNX) all = [ "mata[onnx,classification,eval,viz,segmentation,ocr]", ] dev = [ - "mata[all]", + "mata[all,valkey,redis]", "pytest>=7.4.0", "pytest-cov>=4.1.0", "black>=23.0.0", diff --git a/src/mata/__init__.py b/src/mata/__init__.py index 06e9205..225443f 100644 --- a/src/mata/__init__.py +++ b/src/mata/__init__.py @@ -24,7 +24,7 @@ >>> print(mata.list_models("detect")) """ -__version__ = "1.9.0" +__version__ = "1.9.2b1" from .api import get_model_info, infer, list_models, load, register_model, run, track, val, verbose from .core import ( diff --git a/src/mata/adapters/__init__.py b/src/mata/adapters/__init__.py index 58c037e..337c58d 100644 --- a/src/mata/adapters/__init__.py +++ b/src/mata/adapters/__init__.py @@ -16,6 +16,7 @@ from .pytorch_adapter import PyTorchDetectAdapter from .pytorch_base import PyTorchBaseAdapter from .pytorch_classify_adapter import PyTorchClassifyAdapter +from .reid_adapter import HuggingFaceReIDAdapter, ONNXReIDAdapter, ReIDAdapter from 
.torchscript_adapter import TorchScriptDetectAdapter from .torchscript_classify_adapter import TorchScriptClassifyAdapter from .torchvision_detect_adapter import TorchvisionDetectAdapter @@ -25,6 +26,9 @@ "BaseAdapter", "PyTorchBaseAdapter", "ONNXBaseAdapter", + "ReIDAdapter", + "HuggingFaceReIDAdapter", + "ONNXReIDAdapter", # Detection adapters "HuggingFaceDetectAdapter", "HuggingFaceDepthAdapter", diff --git a/src/mata/adapters/reid_adapter.py b/src/mata/adapters/reid_adapter.py new file mode 100644 index 0000000..91e60cf --- /dev/null +++ b/src/mata/adapters/reid_adapter.py @@ -0,0 +1,345 @@ +"""ReID feature extraction adapter. + +Internal adapter used by TrackingAdapter to extract appearance embeddings +from detection crops. Not a public task adapter — users access ReID +through mata.track(with_reid=True). +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any + +import numpy as np + +from mata.adapters.pytorch_base import PyTorchBaseAdapter +from mata.core.logging import get_logger + +logger = get_logger(__name__) + + +class ReIDAdapter(PyTorchBaseAdapter): + """Base class for ReID feature extraction adapters. + + Subclasses must implement: + _load_model() — Load the encoder model and preprocessor. + _extract_single(crop) — Extract embedding from one crop. + + The public interface is: + predict(crops) — Batch-extract embeddings from N crops. + Returns (N, D) float32 array, L2-normalised. + """ + + def __init__(self, model_id: str, device: str = "auto", **kwargs: Any) -> None: + super().__init__(device=device) + self.model_id = model_id + self._embedding_dim: int | None = None + self._load_model(**kwargs) + + @abstractmethod + def _load_model(self, **kwargs: Any) -> None: + """Load encoder weights and preprocessor.""" + + @abstractmethod + def _extract_single(self, crop: np.ndarray) -> np.ndarray: + """Extract raw embedding from a single BGR/RGB crop. + + Args: + crop: (H, W, 3) uint8 numpy array. + + Returns: + 1-D float32 embedding vector (unnormalised). + """ + + def predict(self, crops: list[np.ndarray]) -> np.ndarray: + """Batch-extract L2-normalised embeddings. + + Args: + crops: List of (H, W, 3) uint8 numpy arrays. + + Returns: + (N, D) float32 array, each row L2-normalised. + Returns empty (0, 0) array if crops is empty. + """ + if not crops: + return np.empty((0, 0), dtype=np.float32) + + embeddings = [] + for crop in crops: + emb = self._extract_single(crop) + embeddings.append(emb) + + result = np.stack(embeddings).astype(np.float32) + + # L2 normalise each row + norms = np.linalg.norm(result, axis=1, keepdims=True) + norms = np.where(norms > 1e-9, norms, 1.0) + result = result / norms + + self._embedding_dim = result.shape[1] + return result + + @property + def embedding_dim(self) -> int | None: + """Embedding dimensionality (available after first predict call).""" + return self._embedding_dim + + def info(self) -> dict[str, Any]: + return { + "type": "reid", + "model_id": self.model_id, + "embedding_dim": self._embedding_dim, + "device": str(self.device), + } + + +class ONNXReIDAdapter(ReIDAdapter): + """ONNX Runtime ReID feature extractor. + + Loads a .onnx file with a single image input and single embedding output. + Input shape: (1, 3, H, W) or (1, H, W, 3) — auto-detected from model metadata. + Output shape: (1, D) — embedding dimension read from output spec. + + Args: + model_id: Path to the .onnx file. + device: Ignored for ONNX (use ``providers`` kwarg instead). + providers: ONNX Runtime execution providers list. 
+ Defaults to ``["CPUExecutionProvider"]``. + + Example:: + + adapter = ONNXReIDAdapter("osnet.onnx") + embeddings = adapter.predict([crop1, crop2]) # (2, D) float32 + """ + + # ImageNet normalisation constants (RGB, [0, 1] range) + _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32) + _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32) + + def _load_model(self, **kwargs: Any) -> None: + import onnxruntime as ort + + providers = kwargs.get("providers", ["CPUExecutionProvider"]) + self._session = ort.InferenceSession(self.model_id, providers=providers) + + inp = self._session.get_inputs()[0] + self._input_name: str = inp.name + self._input_shape: list = inp.shape # e.g. [1, 3, 256, 128] + + self._layout: str = self._detect_layout(self._input_shape) + logger.info(f"Loaded ONNX ReID model: {self.model_id} " f"(input={self._input_shape}, layout={self._layout})") + + @staticmethod + def _detect_layout(shape: list) -> str: + """Detect input tensor layout (NCHW or NHWC) from ONNX input shape. + + Inspects the channel dimension (index 1 for NCHW, index 3 for NHWC). + When ambiguous, defaults to NCHW which is more common for ReID models. + + Args: + shape: ONNX input shape list, e.g. ``[1, 3, 256, 128]``. + + Returns: + ``"NCHW"`` or ``"NHWC"``. + """ + if len(shape) != 4: + return "NCHW" + + c_at_1 = shape[1] + c_at_3 = shape[3] + + # NHWC: channels at index 3 == 3, channels at index 1 is not 3 + if isinstance(c_at_3, int) and c_at_3 == 3: + if not (isinstance(c_at_1, int) and c_at_1 == 3): + return "NHWC" + # Both indices happen to be 3 — default NCHW (more common) + return "NCHW" + + # NCHW: channels at index 1 == 3 + return "NCHW" + + def _get_spatial_dims(self) -> tuple[int, int]: + """Return ``(height, width)`` expected by the model. + + Falls back to ``(256, 128)`` (a common ReID resolution) when the ONNX + model uses dynamic/symbolic dimensions. + """ + shape = self._input_shape + if self._layout == "NCHW": + h = shape[2] if isinstance(shape[2], int) and shape[2] > 0 else 256 + w = shape[3] if isinstance(shape[3], int) and shape[3] > 0 else 128 + else: # NHWC + h = shape[1] if isinstance(shape[1], int) and shape[1] > 0 else 256 + w = shape[2] if isinstance(shape[2], int) and shape[2] > 0 else 128 + return int(h), int(w) + + def _preprocess(self, crop: np.ndarray, height: int, width: int) -> np.ndarray: + """Resize, normalise, and reshape a crop into an ONNX input tensor. + + Args: + crop: ``(H, W, 3)`` uint8 RGB numpy array. + height: Target height. + width: Target width. + + Returns: + ``(1, C, H, W)`` or ``(1, H, W, C)`` float32 array. + """ + from PIL import Image + + pil_img = Image.fromarray(crop.astype(np.uint8)) + pil_img = pil_img.resize((width, height), Image.BILINEAR) + img = np.array(pil_img, dtype=np.float32) / 255.0 + + # ImageNet normalisation + img = (img - self._MEAN) / self._STD # (H, W, 3) + + if self._layout == "NCHW": + tensor = img.transpose(2, 0, 1)[np.newaxis] # (1, C, H, W) + else: + tensor = img[np.newaxis] # (1, H, W, C) + + return tensor.astype(np.float32) + + def _extract_single(self, crop: np.ndarray) -> np.ndarray: + """Run a single crop through the ONNX session and return raw embedding. + + Args: + crop: ``(H, W, 3)`` uint8 RGB numpy array. + + Returns: + 1-D float32 embedding vector (unnormalised). 
+ """ + h, w = self._get_spatial_dims() + tensor = self._preprocess(crop, h, w) + outputs = self._session.run(None, {self._input_name: tensor}) + # First output is typically (1, D); flatten to (D,) + return outputs[0].flatten().astype(np.float32) + + def info(self) -> dict[str, Any]: + return { + **super().info(), + "runtime": "onnx", + "layout": getattr(self, "_layout", None), + "input_shape": getattr(self, "_input_shape", None), + } + + +class HuggingFaceReIDAdapter(ReIDAdapter): + """HuggingFace-backed ReID feature extractor. + + Supports: + - AutoModel (generic feature extraction via last hidden state mean pooling) + - CLIPModel (image encoder branch, returns image embeddings) + - ViT/DeiT/Swin image models (pooler_output if available, else mean pool) + + Architecture auto-detection order: + 1. Model ID contains 'clip' → use CLIPModel image encoder + 2. Config model_type in ('vit', 'deit', 'swin') → use pooler_output + 3. Fallback → AutoModel + mean pooling of last_hidden_state + + All transformers imports are lazy — the library is not imported at module + import time. + """ + + # Architecture families that produce pooler_output via AutoModel + _POOLER_ARCHS = {"vit", "deit", "swin", "beit", "convnext", "mobilevit", "efficientnet"} + + def _load_model(self, **kwargs: Any) -> None: + """Load HuggingFace model with architecture auto-detection.""" + from PIL import Image as _PilImage # noqa: F401 — ensure PIL available + + self._arch = self._detect_architecture() + logger.info(f"Loading ReID encoder: {self.model_id} (arch={self._arch})") + + if self._arch == "clip": + self._load_clip(**kwargs) + else: + self._load_automodel(**kwargs) + + def _detect_architecture(self) -> str: + """Detect model architecture from model_id string, then config probe. + + Returns: + One of: "clip", "vit_pooler", "generic" + """ + model_id_lower = self.model_id.lower() + if "clip" in model_id_lower: + return "clip" + + # Probe config for model_type + try: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(self.model_id) + model_type = getattr(config, "model_type", "").lower() + if model_type in self._POOLER_ARCHS: + return "vit_pooler" + except Exception: + pass + + return "generic" + + def _load_clip(self, **kwargs: Any) -> None: + """Load CLIP model — use image encoder only.""" + from transformers import CLIPModel, CLIPProcessor + + self._processor = CLIPProcessor.from_pretrained(self.model_id) + self._model = CLIPModel.from_pretrained(self.model_id) + self._model.eval() + self._model.to(self.device) + + def _load_automodel(self, **kwargs: Any) -> None: + """Load generic AutoModel for feature extraction.""" + from transformers import AutoModel, AutoProcessor + + try: + self._processor = AutoProcessor.from_pretrained(self.model_id) + except Exception: + # Fallback: some ViT models only expose AutoFeatureExtractor + from transformers import AutoFeatureExtractor + + self._processor = AutoFeatureExtractor.from_pretrained(self.model_id) + + self._model = AutoModel.from_pretrained(self.model_id) + self._model.eval() + self._model.to(self.device) + + def _extract_single(self, crop: np.ndarray) -> np.ndarray: + """Forward pass through the encoder and return a pooled feature vector. + + Args: + crop: (H, W, 3) uint8 numpy array (RGB). + + Returns: + 1-D float32 embedding vector (unnormalised). 
+ """ + import torch + from PIL import Image + + pil_image = Image.fromarray(crop) + + with torch.no_grad(): + if self._arch == "clip": + inputs = self._processor(images=pil_image, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + # Get image features from CLIP's vision encoder + image_features = self._model.get_image_features(**inputs) + embedding = image_features[0].cpu().float().numpy() + else: + inputs = self._processor(images=pil_image, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + outputs = self._model(**inputs) + + if ( + self._arch == "vit_pooler" + and hasattr(outputs, "pooler_output") + and outputs.pooler_output is not None + ): + # Use pooler_output for ViT/DeiT/Swin models + embedding = outputs.pooler_output[0].cpu().float().numpy() + else: + # Generic fallback: mean-pool over the token sequence dimension + last_hidden = outputs.last_hidden_state # (1, T, D) + embedding = last_hidden[0].mean(dim=0).cpu().float().numpy() + + return embedding diff --git a/src/mata/adapters/tracking_adapter.py b/src/mata/adapters/tracking_adapter.py index b260ddb..cec69c0 100644 --- a/src/mata/adapters/tracking_adapter.py +++ b/src/mata/adapters/tracking_adapter.py @@ -219,16 +219,27 @@ def __init__( detector: Any, tracker_config: TrackerConfig | str | dict | None = None, frame_rate: int = 30, + reid_encoder: Any | None = None, + reid_bridge: Any | None = None, ) -> None: self._detector = detector self._config: TrackerConfig = _resolve_config(tracker_config) self._frame_rate: int = int(frame_rate) self._tracker: Any = self._build_tracker() + self._reid_encoder: Any | None = reid_encoder + self._reid_bridge: Any | None = reid_bridge + + # Wire ReID encoder into BOTSORT so get_dists() activates the + # appearance distance branch when encoder is not None. + if self._reid_encoder is not None and hasattr(self._tracker, "encoder"): + self._tracker.encoder = self._reid_encoder logger.debug( - "TrackingAdapter initialised: tracker_type=%s, frame_rate=%d", + "TrackingAdapter initialised: tracker_type=%s, frame_rate=%d, reid=%s, bridge=%s", self._config.tracker_type, self._frame_rate, + self._reid_encoder is not None, + self._reid_bridge is not None, ) # ------------------------------------------------------------------ # @@ -270,6 +281,39 @@ def _to_numpy_image(image: Any) -> np.ndarray | None: pass return None + @staticmethod + def _extract_crops( + image: np.ndarray, + instances: list, + ) -> list: + """Extract image crops from detection bounding boxes. + + Args: + image: (H, W, 3) uint8 numpy array (RGB). + instances: List of Instance objects with bbox attribute. + + Returns: + List of (crop_h, crop_w, 3) uint8 arrays, one per instance. + Instances without a valid bbox are skipped (empty crop placeholder). + """ + h, w = image.shape[:2] + crops: list = [] + for inst in instances: + if inst.bbox is None: + crops.append(np.empty((0, 0, 3), dtype=np.uint8)) + continue + x1, y1, x2, y2 = inst.bbox + # Clip to image bounds + x1i = max(0, int(x1)) + y1i = max(0, int(y1)) + x2i = min(w, int(x2)) + y2i = min(h, int(y2)) + if x2i <= x1i or y2i <= y1i: + crops.append(np.empty((0, 0, 3), dtype=np.uint8)) + continue + crops.append(image[y1i:y2i, x1i:x2i].copy()) + return crops + # ------------------------------------------------------------------ # # Public API # # ------------------------------------------------------------------ # @@ -332,15 +376,46 @@ def update( # ---- 4. 
Convert to tracker input format ---------------------------- det_results = DetectionResults.from_vision_result(vision_result) - # ---- 5. Optional numpy image for GMC (BotSort) --------------------- + # ---- 4b. ReID embedding extraction --------------------------------- + # Compute np_image once — used for both crop extraction and GMC. np_image = self._to_numpy_image(image) - - # ---- 6. Run tracker ------------------------------------------------ + if self._reid_encoder is not None and np_image is not None and len(det_results) > 0: + crops = self._extract_crops(np_image, vision_result.instances) + valid_crops = [c for c in crops if c.size > 0] + if valid_crops: + embeddings = self._reid_encoder.predict(valid_crops) + emb_idx = 0 + for i, crop in enumerate(crops): + if crop.size > 0 and emb_idx < len(embeddings): + det_results.features[i] = embeddings[emb_idx] + emb_idx += 1 + + # ---- 5. Run tracker ------------------------------------------------ tracked: np.ndarray = self._tracker.update(det_results, img=np_image) + # ---- 6. Collect per-track embeddings from active stracks ----------- + embeddings_by_id: dict[int, np.ndarray] | None = None + if self._reid_encoder is not None: + active = getattr(self._tracker, "tracked_stracks", []) + embeddings_by_id = {} + for st in active: + if getattr(st, "is_activated", False) and getattr(st, "smooth_feat", None) is not None: + embeddings_by_id[st.track_id] = st.smooth_feat + # ---- 7. Build output VisionResult ---------------------------------- id2label = getattr(self._detector, "id2label", None) - result = self._convert_tracker_output(tracked, id2label) + result = self._convert_tracker_output(tracked, id2label, embeddings_by_id) + + # ---- 8. Cross-camera ReID publish ---------------------------------- + if self._reid_bridge is not None: + for inst in result.instances: + if inst.track_id is not None and inst.embedding is not None: + self._reid_bridge.publish( + track_id=inst.track_id, + embedding=inst.embedding, + bbox=inst.bbox, + label=inst.label if inst.label is not None else 0, + ) return result @@ -348,6 +423,7 @@ def _convert_tracker_output( self, tracked: np.ndarray, id2label: dict[int, str] | None, + embeddings_by_id: dict[int, np.ndarray] | None = None, ) -> VisionResult: """Convert tracker output array to a :class:`VisionResult`. @@ -357,10 +433,12 @@ def _convert_tracker_output( each row is ``[x1, y1, x2, y2, track_id, score, cls, idx]``. id2label: Optional ``{class_id: label_name}`` mapping sourced from the wrapped detector. + embeddings_by_id: Optional ``{track_id: smooth_feat}`` mapping + produced from active tracked stracks when ReID is enabled. Returns: :class:`VisionResult` with one :class:`Instance` per tracked - object, each carrying ``track_id``. + object, each carrying ``track_id`` and optionally ``embedding``. """ if tracked is None or len(tracked) == 0: return VisionResult(instances=[], meta={"source": "tracking_adapter"}) @@ -378,12 +456,18 @@ def _convert_tracker_output( else: label_name = f"class_{cls_id}" + # Resolve embedding from active stracks if ReID is enabled. 
+ embedding: np.ndarray | None = None + if embeddings_by_id is not None: + embedding = embeddings_by_id.get(track_id) + inst = Instance( bbox=(x1, y1, x2, y2), score=score, label=cls_id, label_name=label_name, track_id=track_id, + embedding=embedding, ) instances.append(inst) diff --git a/src/mata/api.py b/src/mata/api.py index 0dd74e1..5bf34e1 100644 --- a/src/mata/api.py +++ b/src/mata/api.py @@ -208,6 +208,9 @@ def track( show_track_ids: bool = True, show_trails: bool = False, trail_length: int = 30, + reid_model: str | None = None, + with_reid: bool = False, + reid_bridge: Any | None = None, **kwargs: Any, ) -> list[VisionResult] | Generator[VisionResult, None, None]: """Run object detection + tracking on video, stream, or image sequence. @@ -238,6 +241,15 @@ def track( show_track_ids: Draw track IDs on annotated frames. show_trails: Draw trajectory trails on annotated frames. trail_length: Number of frames to keep in trail history. + reid_model: HuggingFace model ID or local .onnx path for ReID encoder. + When provided, appearance embeddings are extracted from detection + crops and injected into the tracker for identity recovery. + with_reid: Convenience flag — must be paired with reid_model. + Raises ValueError if True but reid_model is None. + reid_bridge: Optional :class:`~mata.trackers.reid_bridge.ReIDBridge` + instance for cross-camera ReID publishing. After each frame, + confirmed track embeddings are published to the shared Valkey + store so other camera instances can query them. **kwargs: Additional arguments passed to detection model. Returns: @@ -267,7 +279,16 @@ def track( # Load adapter eagerly so it is ready before any generator is iterated. # This ensures load() runs immediately (not lazily) which is important for # stream=True callers who consume the generator outside any patch context. - adapter = load("track", model, tracker=tracker, frame_rate=frame_rate, **kwargs) + adapter = load( + "track", + model, + tracker=tracker, + frame_rate=frame_rate, + reid_model=reid_model, + with_reid=with_reid, + reid_bridge=reid_bridge, + **kwargs, + ) # Build the generator and either collect or return it gen = _track_generator( diff --git a/src/mata/core/exporters/__init__.py b/src/mata/core/exporters/__init__.py index 4066cfe..30015d9 100644 --- a/src/mata/core/exporters/__init__.py +++ b/src/mata/core/exporters/__init__.py @@ -12,6 +12,7 @@ from mata.core.exporters.image_exporter import TrackTrailRenderer, export_image, export_ocr_image from mata.core.exporters.json_exporter import export_json, export_tracking_json from mata.core.exporters.text_exporter import export_text +from mata.core.exporters.valkey_exporter import export_valkey, load_valkey, publish_valkey __all__ = [ "export_json", @@ -24,4 +25,7 @@ "export_crops", "export_text", "TrackTrailRenderer", + "export_valkey", + "load_valkey", + "publish_valkey", ] diff --git a/src/mata/core/exporters/valkey_exporter.py b/src/mata/core/exporters/valkey_exporter.py new file mode 100644 index 0000000..77a475a --- /dev/null +++ b/src/mata/core/exporters/valkey_exporter.py @@ -0,0 +1,245 @@ +"""Valkey/Redis exporter for MATA result types. + +Exports any result with a to_dict()/to_json() interface to a Valkey key. +Supports TTL, JSON serialization, and optional msgpack for binary efficiency. 
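+
+Example (illustrative; assumes a Valkey server on the default local port)::
+
+    from mata.core.exporters.valkey_exporter import export_valkey, load_valkey
+
+    # 'result' is any MATA result object, e.g. from mata.run("detect", image)
+    export_valkey(result, url="valkey://localhost:6379/0", key="demo:latest", ttl=60)
+    restored = load_valkey(url="valkey://localhost:6379/0", key="demo:latest")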
+""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from mata.core.logging import get_logger + +if TYPE_CHECKING: + from mata.core.types import ClassifyResult, DepthResult, OCRResult, VisionResult + +logger = get_logger(__name__) + + +def _parse_valkey_uri(uri: str) -> tuple[str, str]: + """Parse a Valkey/Redis URI into (base_url, key). + + Handles the following formats: + - valkey://host:port/key_name + - valkey://host:port/0/key_name (with DB number) + - redis://user:pass@host:port/0/key (with credentials) + + Args: + uri: Full Valkey/Redis URI string + + Returns: + Tuple of (base_url, key) + + Raises: + ValueError: If URI format is invalid or key is missing + """ + from urllib.parse import urlparse + + parsed = urlparse(uri) + path_parts = parsed.path.strip("/").split("/", 1) + + if len(path_parts) == 2 and path_parts[0].isdigit(): + # Has DB number: valkey://host:port/0/key + db = path_parts[0] + key = path_parts[1] + base_url = f"{parsed.scheme}://{parsed.netloc}/{db}" + elif len(path_parts) == 1 and path_parts[0]: + key = path_parts[0] + base_url = f"{parsed.scheme}://{parsed.netloc}" + else: + raise ValueError( + f"Invalid Valkey URI: '{uri}'. Expected format: valkey://host:port/key " "or valkey://host:port/db/key" + ) + + return base_url, key + + +def _get_valkey_client(url: str, **kwargs: Any): + """Lazy-import and connect to Valkey/Redis. + + Tries valkey-py first, falls back to redis-py for compatibility. + + Args: + url: Valkey connection URL (e.g., "valkey://localhost:6379/0") + **kwargs: Additional connection parameters + + Returns: + Connected Valkey/Redis client instance + + Raises: + ImportError: If neither valkey-py nor redis-py is installed + """ + try: + import valkey + + return valkey.from_url(url, **kwargs) + except ImportError: + try: + import redis + + # valkey:// scheme → redis:// for redis-py compatibility + redis_url = url.replace("valkey://", "redis://", 1) + return redis.from_url(redis_url, **kwargs) + except ImportError: + raise ImportError( + "Valkey export requires 'valkey' or 'redis' package. " + "Install with: pip install mata[valkey] or pip install mata[redis]" + ) + + +def export_valkey( + result: VisionResult | ClassifyResult | DepthResult | OCRResult, + url: str, + key: str, + ttl: int | None = None, + serializer: str = "json", + **kwargs: Any, +) -> None: + """Export result to a Valkey/Redis key. + + Args: + result: Any MATA result object with to_dict()/to_json() + url: Valkey connection URL + key: Key name to store under + ttl: Time-to-live in seconds (None = no expiry) + serializer: "json" (default) or "msgpack" + **kwargs: Additional connection parameters + + Raises: + ImportError: If valkey/redis client not installed + ConnectionError: If Valkey server unreachable + """ + client = _get_valkey_client(url, **kwargs) + + if serializer == "json": + data = result.to_json() + elif serializer == "msgpack": + import msgpack + + data = msgpack.packb(result.to_dict(), use_bin_type=True) + else: + raise ValueError(f"Unsupported serializer: '{serializer}'. Use 'json' or 'msgpack'.") + + if ttl is not None: + client.setex(key, ttl, data) + else: + client.set(key, data) + + logger.info(f"Exported result to Valkey key '{key}' (ttl={ttl})") + + +def load_valkey( + url: str, + key: str, + result_type: str = "auto", + **kwargs: Any, +) -> VisionResult | ClassifyResult | DepthResult | OCRResult: + """Load a MATA result from a Valkey/Redis key. 
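+
+    Values are decoded with ``json.loads``; keys written with
+    ``serializer="msgpack"`` are not currently handled by this loader.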
+ + Args: + url: Valkey connection URL + key: Key name to load from + result_type: "auto" (detect from data), "vision", "classify", "depth", "ocr" + **kwargs: Additional connection parameters + + Returns: + Reconstructed result object + + Raises: + KeyError: If key does not exist + ImportError: If valkey/redis client not installed + """ + client = _get_valkey_client(url, **kwargs) + raw = client.get(key) + + if raw is None: + raise KeyError(f"Valkey key '{key}' not found") + + data = json.loads(raw) + + if result_type == "auto": + result_type = _detect_result_type(data) + + return _deserialize_result(data, result_type) + + +def _detect_result_type(data: dict) -> str: + """Auto-detect result type from serialized dict keys.""" + if "instances" in data: + return "vision" + elif "predictions" in data: + return "classify" + elif "depth" in data: + return "depth" + elif "regions" in data: + return "ocr" + else: + raise ValueError( + f"Cannot auto-detect result type from keys: {list(data.keys())}. " "Specify result_type explicitly." + ) + + +def _deserialize_result(data: dict, result_type: str): + """Reconstruct a typed result from dict.""" + from mata.core.types import ClassifyResult, DepthResult, OCRResult, VisionResult + + type_map = { + "vision": VisionResult, + "detect": VisionResult, + "classify": ClassifyResult, + "depth": DepthResult, + "ocr": OCRResult, + } + + cls = type_map.get(result_type) + if cls is None: + raise ValueError(f"Unknown result_type: '{result_type}'. Use: {list(type_map.keys())}") + + return cls.from_dict(data) + + +def publish_valkey( + result: VisionResult | ClassifyResult | DepthResult | OCRResult, + url: str, + channel: str, + serializer: str = "json", + **kwargs: Any, +) -> int: + """Publish result to a Valkey Pub/Sub channel. + + This is a fire-and-forget operation. Messages are delivered only to + active subscribers and are NOT persisted — if no subscriber is listening + when ``publish_valkey`` is called, the message is silently dropped. + + Channel names should never be derived from user-controlled input without + prior validation, as they are passed directly to the Valkey server. + + Args: + result: Any MATA result object with to_dict()/to_json() + url: Valkey connection URL (e.g., "valkey://localhost:6379") + channel: Pub/Sub channel name to publish to + serializer: "json" (default) or "msgpack" + **kwargs: Additional connection parameters + + Returns: + Number of subscribers that received the message + + Raises: + ImportError: If valkey/redis client not installed + ValueError: If an unsupported serializer is specified + """ + client = _get_valkey_client(url, **kwargs) + + if serializer == "json": + data = result.to_json() + elif serializer == "msgpack": + import msgpack + + data = msgpack.packb(result.to_dict(), use_bin_type=True) + else: + raise ValueError(f"Unsupported serializer: '{serializer}'. 
Use 'json' or 'msgpack'.") + + num_receivers = client.publish(channel, data) + logger.info(f"Published result to channel '{channel}' ({num_receivers} subscribers)") + return num_receivers diff --git a/src/mata/core/graph/validator.py b/src/mata/core/graph/validator.py index 74342e1..1534f53 100644 --- a/src/mata/core/graph/validator.py +++ b/src/mata/core/graph/validator.py @@ -561,6 +561,7 @@ def _is_compatible_type(self, source_type: type[Artifact], target_type: type[Art Types are compatible if: - They are the same type - source_type is a subclass of target_type + - source_type is the base Artifact class (wildcard: runtime type unknown) Args: source_type: Type produced by source node @@ -569,6 +570,10 @@ def _is_compatible_type(self, source_type: type[Artifact], target_type: type[Art Returns: True if types are compatible """ + # Base Artifact is a wildcard (e.g. ValkeyLoad can produce any subtype); + # defer the concrete type check to runtime. + if source_type is Artifact: + return True try: return issubclass(source_type, target_type) except TypeError: diff --git a/src/mata/core/model_loader.py b/src/mata/core/model_loader.py index 0b24cc6..07d8b3e 100644 --- a/src/mata/core/model_loader.py +++ b/src/mata/core/model_loader.py @@ -230,9 +230,11 @@ def _load_with_explicit_type(self, task: str, source: str | None, model_type: Mo UnsupportedModelError: If type/task combination not supported """ if task == "track": - tracker_config, frame_rate = self._resolve_tracker_kwargs(kwargs) + tracker_config, frame_rate, reid_model, with_reid, reid_bridge = self._resolve_tracker_kwargs(kwargs) detect_adapter = self._load_with_explicit_type("detect", source, model_type, **kwargs) - return self._wrap_with_tracking(detect_adapter, tracker_config, frame_rate) + return self._wrap_with_tracking( + detect_adapter, tracker_config, frame_rate, reid_model, with_reid, reid_bridge + ) if model_type == ModelType.HUGGINGFACE: if not source: @@ -484,6 +486,17 @@ def _load_from_config(self, task: str, alias: str, **kwargs) -> Any: """ config = self.registry.get_config(task, alias) + # For tracking tasks, explicitly extract reid_model and with_reid so + # that (a) top-level config fields are honoured, (b) tracker_config + # .with_reid propagates to the top-level flag, and (c) runtime kwargs + # always take precedence over config values. 
+ if task == "track": + tracker_cfg = config.get("tracker_config", {}) + # with_reid: runtime kwarg > top-level config > tracker_config sub-key + config_with_reid = config.get("with_reid", tracker_cfg.get("with_reid", False)) + kwargs["reid_model"] = kwargs.get("reid_model") or config.get("reid_model") + kwargs["with_reid"] = kwargs.get("with_reid") or config_with_reid + # Merge config with kwargs (kwargs take precedence) merged_kwargs = {**config, **kwargs} @@ -516,9 +529,11 @@ def _load_from_huggingface(self, task: str, model_id: str, **kwargs) -> Any: HuggingFace adapter or pipeline instance for the task """ if task == "track": - tracker_config, frame_rate = self._resolve_tracker_kwargs(kwargs) + tracker_config, frame_rate, reid_model, with_reid, reid_bridge = self._resolve_tracker_kwargs(kwargs) detect_adapter = self._load_from_huggingface("detect", model_id, **kwargs) - return self._wrap_with_tracking(detect_adapter, tracker_config, frame_rate) + return self._wrap_with_tracking( + detect_adapter, tracker_config, frame_rate, reid_model, with_reid, reid_bridge + ) if task == "detect": # Check for zero-shot detection mode @@ -721,9 +736,11 @@ def _load_from_file(self, task: str, file_path: str, **kwargs) -> Any: Appropriate adapter instance based on file extension """ if task == "track": - tracker_config, frame_rate = self._resolve_tracker_kwargs(kwargs) + tracker_config, frame_rate, reid_model, with_reid, reid_bridge = self._resolve_tracker_kwargs(kwargs) detect_adapter = self._load_from_file("detect", file_path, **kwargs) - return self._wrap_with_tracking(detect_adapter, tracker_config, frame_rate) + return self._wrap_with_tracking( + detect_adapter, tracker_config, frame_rate, reid_model, with_reid, reid_bridge + ) path = Path(file_path) extension = path.suffix.lower() @@ -805,7 +822,7 @@ def _load_from_file(self, task: str, file_path: str, **kwargs) -> Any: ) def _resolve_tracker_kwargs(self, kwargs: dict) -> tuple: - """Pop tracker-related kwargs and return ``(tracker_config, frame_rate)``. + """Pop tracker-related kwargs and return ``(tracker_config, frame_rate, reid_model, with_reid)``. Handles the case where the registry config supplies both a ``tracker`` name (string) **and** a ``tracker_config`` override dict. The two are @@ -828,15 +845,19 @@ def _resolve_tracker_kwargs(self, kwargs: dict) -> tuple: dict, a YAML path, or a :class:`TrackerConfig` instance). Args: - kwargs: Mutable kwargs dict — ``tracker``, ``tracker_config``, and - ``frame_rate`` are **popped** in place. + kwargs: Mutable kwargs dict — ``tracker``, ``tracker_config``, + ``frame_rate``, ``reid_model``, and ``with_reid`` are + **popped** in place. Returns: - ``(resolved_tracker_config, frame_rate)`` tuple. + ``(resolved_tracker_config, frame_rate, reid_model, with_reid)`` tuple. 
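+            ``reid_bridge`` is popped and returned as well, as the fifth
+            element of the tuple.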
""" tracker = kwargs.pop("tracker", "botsort") tracker_config_overrides = kwargs.pop("tracker_config", None) frame_rate = kwargs.pop("frame_rate", 30) + reid_model = kwargs.pop("reid_model", None) + with_reid = kwargs.pop("with_reid", False) + reid_bridge = kwargs.pop("reid_bridge", None) if tracker_config_overrides is not None and isinstance(tracker_config_overrides, dict): if isinstance(tracker, str): @@ -850,9 +871,17 @@ def _resolve_tracker_kwargs(self, kwargs: dict) -> tuple: else: resolved = tracker - return resolved, frame_rate - - def _wrap_with_tracking(self, detect_adapter: Any, tracker_config: Any, frame_rate: int) -> Any: + return resolved, frame_rate, reid_model, with_reid, reid_bridge + + def _wrap_with_tracking( + self, + detect_adapter: Any, + tracker_config: Any, + frame_rate: int, + reid_model: str | None = None, + with_reid: bool = False, + reid_bridge: Any | None = None, + ) -> Any: """Wrap a detection adapter with a :class:`TrackingAdapter`. This is a shared helper called by all source-type loaders when @@ -865,14 +894,56 @@ def _wrap_with_tracking(self, detect_adapter: Any, tracker_config: Any, frame_ra instance, a plain dict of tracker parameters, or ``None`` for the BotSort default. frame_rate: Video frame rate used to derive ``max_time_lost``. + reid_model: HuggingFace ID or local .onnx path for ReID encoder. + If None and with_reid is False, no ReID encoder is loaded. + with_reid: If True, reid_model must also be provided. + reid_bridge: Optional :class:`~mata.trackers.reid_bridge.ReIDBridge` + instance. When provided, confirmed track embeddings are + published to Valkey after each frame. Returns: :class:`TrackingAdapter` composing *detect_adapter* with the configured tracker. + + Raises: + ValueError: If with_reid=True but reid_model is None. """ from mata.adapters.tracking_adapter import TrackingAdapter - return TrackingAdapter(detect_adapter, tracker_config, frame_rate) + reid_encoder = None + if with_reid and reid_model is None: + raise ValueError( + "with_reid=True requires reid_model to be specified. " + "Provide a HuggingFace model ID or local .onnx path, e.g.: " + 'reid_model="person-reid/osnet-x1-0"' + ) + if reid_model is not None or with_reid: + reid_encoder = self._load_reid_encoder(reid_model) + + return TrackingAdapter( + detect_adapter, tracker_config, frame_rate, reid_encoder=reid_encoder, reid_bridge=reid_bridge + ) + + def _load_reid_encoder(self, reid_model: str) -> Any: + """Load a ReID encoder from a HuggingFace ID or local .onnx path. + + Args: + reid_model: HuggingFace model ID (contains '/') or path to a + local .onnx file. + + Returns: + :class:`HuggingFaceReIDAdapter` or :class:`ONNXReIDAdapter` instance. + """ + if reid_model.endswith(".onnx") or (self._is_local_file(reid_model) and reid_model.lower().endswith(".onnx")): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + logger.info(f"Loading ONNX ReID encoder from: {reid_model}") + return ONNXReIDAdapter(reid_model) + else: + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + logger.info(f"Loading HuggingFace ReID encoder: {reid_model}") + return HuggingFaceReIDAdapter(reid_model) def _load_default(self, task: str, **kwargs) -> Any: """Load default model for task. 
diff --git a/src/mata/core/model_registry.py b/src/mata/core/model_registry.py index 7e83f30..1aad81e 100644 --- a/src/mata/core/model_registry.py +++ b/src/mata/core/model_registry.py @@ -355,3 +355,46 @@ def save_to_file(self, file_path: str | None = None): yaml.dump(merged, f, default_flow_style=False, sort_keys=True) logger.info(f"Saved model registry to: {file_path}") + + def get_valkey_connection(self, name: str = "default") -> dict[str, Any]: + """Get named Valkey connection config from the YAML ``storage`` section. + + Args: + name: Connection profile name defined under ``storage.valkey`` in + the YAML config (e.g., "default", "production"). + + Returns: + Dict with connection parameters (url, db, ttl, tls, password, …). + ``password_env`` is resolved from the environment and replaced by + ``password``; the env-var name is never returned. + + Raises: + ModelNotFoundError: If the named connection is not present in the + config. Missing ``storage`` section is treated as an empty + config (not an error). + + Examples: + >>> registry = ModelRegistry() + >>> conn = registry.get_valkey_connection("default") + >>> # {"url": "valkey://localhost:6379", "db": 0, "ttl": 3600} + """ + self._ensure_loaded() + + storage = self._configs.get("storage", {}).get("valkey", {}) + if name not in storage: + raise ModelNotFoundError( + f"Valkey connection '{name}' not found in config. " f"Available: {list(storage.keys())}" + ) + + conn = dict(storage[name]) + + # Resolve password from environment variable — never expose plaintext + if "password_env" in conn: + import os + + env_var = conn.pop("password_env") + password = os.environ.get(env_var) + if password: + conn["password"] = password + + return conn diff --git a/src/mata/core/types.py b/src/mata/core/types.py index 074cd4c..6dd91ab 100644 --- a/src/mata/core/types.py +++ b/src/mata/core/types.py @@ -534,6 +534,15 @@ def save( >>> result = mata.run("detect", pil_img) >>> result.save("output.png", image="test.jpg") """ + # Valkey/Redis URI scheme routing + output_str = str(output_path) + if output_str.startswith(("valkey://", "redis://")): + from mata.core.exporters.valkey_exporter import _parse_valkey_uri, export_valkey + + url, key = _parse_valkey_uri(output_str) + export_valkey(self, url=url, key=key, **kwargs) + return + from pathlib import Path from mata.core.exporters import export_crops, export_csv, export_image, export_json @@ -648,6 +657,15 @@ def save( Format auto-detected from extension (.json, .png, .jpg, etc.). """ + # Valkey/Redis URI scheme routing + output_str = str(output_path) + if output_str.startswith(("valkey://", "redis://")): + from mata.core.exporters.valkey_exporter import _parse_valkey_uri, export_valkey + + url, key = _parse_valkey_uri(output_str) + export_valkey(self, url=url, key=key, **kwargs) + return + from pathlib import Path from mata.core.exporters import export_image, export_json @@ -784,6 +802,15 @@ def save(self, output_path: str, **kwargs: Any) -> None: output_path: Destination file path. **kwargs: Forwarded to the underlying exporter. 
""" + # Valkey/Redis URI scheme routing + output_str = str(output_path) + if output_str.startswith(("valkey://", "redis://")): + from mata.core.exporters.valkey_exporter import _parse_valkey_uri, export_valkey + + url, key = _parse_valkey_uri(output_str) + export_valkey(self, url=url, key=key, **kwargs) + return + from pathlib import Path suffix = Path(output_path).suffix.lower() @@ -1062,6 +1089,15 @@ def save( format: Override format detection **kwargs: Additional exporter parameters """ + # Valkey/Redis URI scheme routing + output_str = str(output_path) + if output_str.startswith(("valkey://", "redis://")): + from mata.core.exporters.valkey_exporter import _parse_valkey_uri, export_valkey + + url, key = _parse_valkey_uri(output_str) + export_valkey(self, url=url, key=key, **kwargs) + return + from pathlib import Path from mata.core.exporters import export_crops, export_csv, export_image, export_json @@ -1390,6 +1426,15 @@ def save( format: Override format detection **kwargs: Additional exporter parameters """ + # Valkey/Redis URI scheme routing + output_str = str(output_path) + if output_str.startswith(("valkey://", "redis://")): + from mata.core.exporters.valkey_exporter import _parse_valkey_uri, export_valkey + + url, key = _parse_valkey_uri(output_str) + export_valkey(self, url=url, key=key, **kwargs) + return + from pathlib import Path from mata.core.exporters import export_csv, export_image, export_json @@ -1550,6 +1595,15 @@ def save( format: Override format detection **kwargs: Additional exporter parameters (e.g., top_k for charts) """ + # Valkey/Redis URI scheme routing + output_str = str(output_path) + if output_str.startswith(("valkey://", "redis://")): + from mata.core.exporters.valkey_exporter import _parse_valkey_uri, export_valkey + + url, key = _parse_valkey_uri(output_str) + export_valkey(self, url=url, key=key, **kwargs) + return + from pathlib import Path from mata.core.exporters import export_csv, export_image, export_json diff --git a/src/mata/nodes/__init__.py b/src/mata/nodes/__init__.py index a9c4fdf..60b8d0e 100644 --- a/src/mata/nodes/__init__.py +++ b/src/mata/nodes/__init__.py @@ -39,6 +39,10 @@ # Tracking nodes (Task 5.5) from mata.nodes.track import Track +# Storage nodes (v1.10.0) +from mata.nodes.valkey_load import ValkeyLoad +from mata.nodes.valkey_store import ValkeyStore + # VLM nodes (Task 5.7) from mata.nodes.vlm_describe import VLMDescribe from mata.nodes.vlm_detect import VLMDetect @@ -77,4 +81,7 @@ # Visualization & Analysis nodes "Annotate", "NMS", + # Storage nodes + "ValkeyStore", + "ValkeyLoad", ] diff --git a/src/mata/nodes/filter.py b/src/mata/nodes/filter.py index 3b09337..8e2834e 100644 --- a/src/mata/nodes/filter.py +++ b/src/mata/nodes/filter.py @@ -56,8 +56,10 @@ class Filter(Node): ``` """ - inputs: dict[str, type[Artifact]] = {"detections": Detections} - outputs: dict[str, type[Artifact]] = {"detections": Detections} + # inputs/outputs are set dynamically in __init__ based on src so that + # the graph auto-wiring can match the actual artifact name in context. 
+ inputs: dict[str, type[Artifact]] + outputs: dict[str, type[Artifact]] def __init__( self, @@ -76,8 +78,11 @@ def __init__( self.label_in = label_in self.label_not_in = label_not_in self.fuzzy = fuzzy + # Use src as the input key so graph auto-wiring matches by artifact name + self.inputs = {src: Detections} + self.outputs = {out: Detections} - def run(self, ctx: ExecutionContext, detections: Detections) -> dict[str, Artifact]: + def run(self, ctx: ExecutionContext, **inputs: Detections) -> dict[str, Artifact]: """Apply filtering criteria to detections. Filtering is applied sequentially: @@ -87,12 +92,13 @@ def run(self, ctx: ExecutionContext, detections: Detections) -> dict[str, Artifa Args: ctx: Execution context (unused by this node). - detections: Input detections to filter. + **inputs: The single input Detections artifact, keyed by src name. Returns: Dict with a single key (``self.out``) mapping to the filtered Detections artifact. """ + detections: Detections = next(iter(inputs.values())) filtered = detections # 1. Filter by confidence score diff --git a/src/mata/nodes/valkey_load.py b/src/mata/nodes/valkey_load.py new file mode 100644 index 0000000..d6bba05 --- /dev/null +++ b/src/mata/nodes/valkey_load.py @@ -0,0 +1,99 @@ +"""Valkey/Redis source node for graph pipelines. + +Loads previously stored artifacts from Valkey and injects them +into the graph execution context as typed artifacts. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from mata.core.artifacts.base import Artifact +from mata.core.graph.node import Node + +if TYPE_CHECKING: + from mata.core.graph.context import ExecutionContext + + +class ValkeyLoad(Node): + """Source node that loads artifacts from Valkey/Redis. + + Loads a previously stored result from a Valkey key and converts + it into the appropriate graph artifact type. Typically used as + an entry node in a graph that consumes results from another pipeline. + + Args: + url: Valkey connection URL + key: Key name to load from + result_type: "auto", "vision", "classify", "depth", "ocr" + out: Output artifact name (default: "loaded") + + Examples: + >>> from mata.nodes import ValkeyLoad, Filter, Fuse + >>> graph = (Graph() + ... .then(ValkeyLoad( + ... url="valkey://localhost:6379", + ... key="upstream:detections:latest", + ... result_type="vision", + ... out="dets", + ... )) + ... .then(Filter(src="dets", score_gt=0.7, out="filtered")) + ... .then(Fuse(detections="filtered")) + ... ) + """ + + inputs: dict[str, type[Artifact]] = {} # Source node: no inputs + outputs: dict[str, type[Artifact]] = {"artifact": Artifact} + + def __init__( + self, + url: str, + key: str, + result_type: str = "auto", + out: str = "loaded", + ): + super().__init__(name="ValkeyLoad") + self.url = url + self.key = key + self.result_type = result_type + self.output_name = out + + def run(self, ctx: ExecutionContext) -> dict[str, Artifact]: + """Load result from Valkey and convert to artifact. + + Args: + ctx: Execution context. + + Returns: + Dict mapping output name to the loaded artifact. + + Raises: + KeyError: If the Valkey key does not exist. 
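+            TypeError: If the loaded result has no graph-artifact
+                equivalent (e.g. an ``OCRResult``).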
+ """ + from mata.core.exporters.valkey_exporter import load_valkey + + result = load_valkey( + url=self.url, + key=self.key, + result_type=self.result_type, + ) + + artifact = self._result_to_artifact(result) + return {self.output_name: artifact} + + @staticmethod + def _result_to_artifact(result) -> Artifact: + """Convert a result type back to a graph artifact.""" + from mata.core.artifacts.classifications import Classifications + from mata.core.artifacts.depth_map import DepthMap + from mata.core.artifacts.detections import Detections + from mata.core.types import ClassifyResult, DepthResult, VisionResult + + if isinstance(result, VisionResult): + return Detections.from_vision_result(result) + elif isinstance(result, ClassifyResult): + return Classifications.from_classify_result(result) + elif isinstance(result, DepthResult): + return DepthMap.from_depth_result(result) + else: + raise TypeError(f"Cannot convert {type(result).__name__} to graph artifact") diff --git a/src/mata/nodes/valkey_store.py b/src/mata/nodes/valkey_store.py new file mode 100644 index 0000000..cf4e2bf --- /dev/null +++ b/src/mata/nodes/valkey_store.py @@ -0,0 +1,133 @@ +"""Valkey/Redis sink node for graph pipelines. + +Writes artifacts to Valkey during graph execution, enabling +distributed result sharing and pipeline persistence. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from mata.core.artifacts.base import Artifact +from mata.core.graph.node import Node + +if TYPE_CHECKING: + from mata.core.graph.context import ExecutionContext + + +class ValkeyStore(Node): + """Sink node that writes artifacts to Valkey/Redis. + + Writes the specified source artifact to a Valkey key during graph + execution. Supports key templates with variable substitution and + TTL-based expiration. + + This node is a terminal/sink — it stores data externally and + passes the input artifact through unchanged as its output. + + Args: + src: Name of the input artifact to store + url: Valkey connection URL (e.g., "valkey://localhost:6379") + key: Key name or template (supports {node}, {timestamp} placeholders) + ttl: Time-to-live in seconds (None = no expiry) + serializer: "json" (default) or "msgpack" + out: Output artifact name (passes input through, default: same as src) + + Examples: + >>> from mata.nodes import Detect, Filter, ValkeyStore + >>> graph = (Graph() + ... .then(Detect(using="detr", out="dets")) + ... .then(Filter(src="dets", score_gt=0.5, out="filtered")) + ... .then(ValkeyStore( + ... src="filtered", + ... url="valkey://localhost:6379", + ... key="pipeline:detections:{timestamp}", + ... ttl=3600, + ... )) + ... ) + """ + + # inputs/outputs are set dynamically in __init__ based on src so that + # the graph auto-wiring can match the actual artifact name in context. + inputs: dict[str, type[Artifact]] + outputs: dict[str, type[Artifact]] + + def __init__( + self, + src: str, + url: str, + key: str, + ttl: int | None = None, + serializer: str = "json", + out: str | None = None, + ): + super().__init__(name="ValkeyStore") + self.src_name = src + self.url = url + self.key_template = key + self.ttl = ttl + self.serializer = serializer + self.output_name = out or src + # Use src as the input key so graph auto-wiring matches by artifact name + self.inputs = {src: Artifact} + self.outputs = {self.output_name: Artifact} + + def run(self, ctx: ExecutionContext, **inputs: Artifact) -> dict[str, Artifact]: + """Store artifact to Valkey and pass through. + + Args: + ctx: Execution context. 
+ **inputs: The single input artifact, keyed by src name. + + Returns: + Dict mapping output name to the unchanged input artifact. + """ + import time + + from mata.core.exporters.valkey_exporter import export_valkey + + # Accept the artifact under whatever key the scheduler passes it as + artifact = next(iter(inputs.values())) + + resolved_key = self.key_template.format( + node=self.name, + timestamp=int(time.time()), + ) + + result = self._artifact_to_serializable(artifact) + + export_valkey( + result=result, + url=self.url, + key=resolved_key, + ttl=self.ttl, + serializer=self.serializer, + ) + + return {self.output_name: artifact} + + @staticmethod + def _artifact_to_serializable(artifact: Artifact): + """Convert graph artifact to a serializable result type.""" + from mata.core.artifacts.classifications import Classifications + from mata.core.artifacts.converters import ( + artifact_to_classify_result, + artifact_to_depth_result, + detections_to_vision_result, + masks_to_vision_result, + ) + from mata.core.artifacts.depth_map import DepthMap + from mata.core.artifacts.detections import Detections + from mata.core.artifacts.masks import Masks + + if isinstance(artifact, Detections): + return detections_to_vision_result(artifact) + elif isinstance(artifact, Masks): + return masks_to_vision_result(artifact) + elif isinstance(artifact, Classifications): + return artifact_to_classify_result(artifact) + elif isinstance(artifact, DepthMap): + return artifact_to_depth_result(artifact) + else: + # Fallback: use artifact's own to_dict() + return artifact diff --git a/src/mata/trackers/__init__.py b/src/mata/trackers/__init__.py index bad2474..358d1c2 100644 --- a/src/mata/trackers/__init__.py +++ b/src/mata/trackers/__init__.py @@ -15,6 +15,7 @@ DetectionResults — adapter: VisionResult → tracker input format BYTETracker — full two-stage ByteTrack algorithm BOTSORT — BotSort: BYTETracker + GMC + optional ReID + ReIDBridge — cross-camera ReID embedding store via Valkey (v1.9.2) """ from __future__ import annotations @@ -22,6 +23,7 @@ from mata.trackers.basetrack import BaseTrack, TrackState from mata.trackers.bot_sort import BOTSORT, BOTrack from mata.trackers.byte_tracker import BYTETracker, DetectionResults, STrack +from mata.trackers.reid_bridge import ReIDBridge __all__ = [ "BaseTrack", @@ -31,4 +33,5 @@ "DetectionResults", "BYTETracker", "BOTSORT", + "ReIDBridge", ] diff --git a/src/mata/trackers/bot_sort.py b/src/mata/trackers/bot_sort.py index 3d568a9..35f5f95 100644 --- a/src/mata/trackers/bot_sort.py +++ b/src/mata/trackers/bot_sort.py @@ -369,6 +369,7 @@ def init_track(self, results: Any, img: np.ndarray | None = None) -> list[BOTrac return [] tracks: list[BOTrack] = [] + features = getattr(results, "features", None) for i in range(len(results)): xywh = results.xywh[i] # [cx, cy, w, h] score = float(results.conf[i]) @@ -378,7 +379,8 @@ def init_track(self, results: Any, img: np.ndarray | None = None) -> list[BOTrac [xywh[0], xywh[1], xywh[2], xywh[3], 0.0, orig_idx], dtype=np.float64, ) - tracks.append(BOTrack(xywh_ext, score, cls)) + feat = features[i] if (features is not None and i < len(features)) else None + tracks.append(BOTrack(xywh_ext, score, cls, feat=feat)) return tracks def get_dists(self, tracks: list[BOTrack], detections: list[BOTrack]) -> np.ndarray: diff --git a/src/mata/trackers/byte_tracker.py b/src/mata/trackers/byte_tracker.py index a78d623..8c3241b 100644 --- a/src/mata/trackers/byte_tracker.py +++ b/src/mata/trackers/byte_tracker.py @@ -406,6 +406,9 @@ def __init__( 
self._indices: np.ndarray = ( np.arange(n, dtype=np.int64) if indices is None else np.asarray(indices, dtype=np.int64).ravel() ) + # Optional per-detection appearance feature vectors (for ReID). + # Each entry is either None or a 1-D float32 array. + self.features: list = [None] * n # ------------------------------------------------------------------ # Properties @@ -446,13 +449,21 @@ def __getitem__(self, idx: Any) -> DetectionResults: :meth:`BYTETracker.init_track` can embed them in :class:`STrack` as ``idx``. """ - return DetectionResults( + # Slice features using numpy object-array indexing so that both + # boolean masks and integer-index arrays work uniformly. + feats_arr = np.empty(len(self.features), dtype=object) + for _i, _f in enumerate(self.features): + feats_arr[_i] = _f + sliced_feats = list(feats_arr[idx]) + sliced = DetectionResults( conf=self._conf[idx], xyxy=self._xyxy[idx], xywh=self._xywh[idx], cls=self._cls[idx], indices=self._indices[idx], ) + sliced.features = sliced_feats + return sliced # ------------------------------------------------------------------ # Factory helpers diff --git a/src/mata/trackers/configs/botsort.yaml b/src/mata/trackers/configs/botsort.yaml index 5802c02..82a7a3a 100644 --- a/src/mata/trackers/configs/botsort.yaml +++ b/src/mata/trackers/configs/botsort.yaml @@ -1,14 +1,18 @@ # BotSort tracker configuration # Ref: https://arxiv.org/abs/2206.14651 tracker_type: botsort -track_high_thresh: 0.5 # High confidence threshold — first association -track_low_thresh: 0.1 # Low confidence threshold — second association -new_track_thresh: 0.6 # Threshold to initialise new tracks -track_buffer: 30 # Frames a lost track is kept before removal -match_thresh: 0.8 # IoU matching threshold -fuse_score: true # Fuse detection score with IoU distance +track_high_thresh: 0.5 # High confidence threshold — first association +track_low_thresh: 0.1 # Low confidence threshold — second association +new_track_thresh: 0.6 # Threshold to initialise new tracks +track_buffer: 30 # Frames a lost track is kept before removal +match_thresh: 0.8 # IoU matching threshold +fuse_score: true # Fuse detection score with IoU distance # BotSort-specific fields -gmc_method: sparseOptFlow # GMC method: 'sparseOptFlow' or null -proximity_thresh: 0.5 # Minimum IoU for ReID candidate set -appearance_thresh: 0.25 # Minimum cosine similarity for ReID matching -with_reid: false # ReID disabled in v1.8 +gmc_method: sparseOptFlow # GMC method: 'sparseOptFlow' or null +proximity_thresh: 0.5 # Minimum IoU for ReID candidate set +appearance_thresh: 0.25 # Minimum cosine similarity for ReID matching +with_reid: false # ReID disabled by default +# ReID configuration (v1.9.2+) +# Uncomment to enable appearance-based re-identification: +# reid_model: "person-reid/osnet-x1-0" # HuggingFace ID or local .onnx path +# with_reid: true # Enable appearance-based matching diff --git a/src/mata/trackers/reid_bridge.py b/src/mata/trackers/reid_bridge.py new file mode 100644 index 0000000..44c8e8b --- /dev/null +++ b/src/mata/trackers/reid_bridge.py @@ -0,0 +1,179 @@ +"""Cross-camera ReID bridge via Valkey. + +Publishes track embeddings to a shared Valkey store so that independent +tracker instances (different cameras/processes) can resolve identities +across feeds. 
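+
+Typical usage (URLs and camera IDs are illustrative)::
+
+    bridge = ReIDBridge("valkey://localhost:6379/0", camera_id="cam-01")
+    bridge.publish(track_id=7, embedding=feat)  # feat: L2-normalised 1-D float32 array
+    matches = bridge.query(feat, exclude_camera="cam-01", top_k=3)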
+ +Data model: + Key: reid:{camera_id}:{track_id} + Value: msgpack({ + "track_id": int, + "camera_id": str, + "embedding": list[float], # smooth_feat (L2-normalised) + "bbox": [x1, y1, x2, y2], + "timestamp": float, + "label": int + }) + TTL: configurable (default 300s) + +Query: + For an unmatched detection, fetch all active embeddings from other + cameras, compute cosine similarity, return best match above threshold. +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from mata.core.logging import get_logger + +logger = get_logger(__name__) + + +class ReIDBridge: + """Cross-camera ReID embedding store backed by Valkey. + + Args: + url: Valkey connection URI (e.g., "valkey://localhost:6379/0"). + camera_id: Unique identifier for this camera/tracker instance. + ttl: Time-to-live for published embeddings in seconds. + similarity_thresh: Minimum cosine similarity for cross-camera match. + """ + + def __init__( + self, + url: str, + camera_id: str = "default", + ttl: int = 300, + similarity_thresh: float = 0.25, + ) -> None: + from mata.core.exporters.valkey_exporter import _get_valkey_client + + self._client = _get_valkey_client(url) + self._camera_id = camera_id + self._ttl = ttl + self._similarity_thresh = similarity_thresh + self._prefix = "reid" + + def publish( + self, + track_id: int, + embedding: np.ndarray, + bbox: tuple[float, ...] | None = None, + label: int = 0, + timestamp: float | None = None, + ) -> None: + """Publish a track's embedding to Valkey. + + Args: + track_id: Unique track identifier within this camera. + embedding: L2-normalised feature vector (1-D float32 array). + bbox: Optional bounding box (x1, y1, x2, y2) in pixel coords. + label: Class label index (default 0). + timestamp: Unix timestamp; defaults to ``time.time()``. + """ + import time + + import msgpack + + key = f"{self._prefix}:{self._camera_id}:{track_id}" + data = msgpack.packb( + { + "track_id": track_id, + "camera_id": self._camera_id, + "embedding": embedding.tolist(), + "bbox": list(bbox) if bbox is not None else None, + "label": label, + "timestamp": timestamp if timestamp is not None else time.time(), + } + ) + try: + self._client.set(key, data, ex=self._ttl) + except Exception as exc: # noqa: BLE001 + logger.warning("ReIDBridge.publish failed (camera=%s, track=%d): %s", self._camera_id, track_id, exc) + + def query( + self, + embedding: np.ndarray, + exclude_camera: str | None = None, + top_k: int = 1, + ) -> list[dict[str, Any]]: + """Find nearest embeddings from other cameras. + + Args: + embedding: Query L2-normalised feature vector. + exclude_camera: Camera ID to exclude from results (typically + ``self.camera_id`` to prevent self-matching). + top_k: Maximum number of results to return. + + Returns: + List of matches sorted by similarity (descending), each a dict with + keys: ``track_id``, ``camera_id``, ``similarity``, ``bbox``, ``label``. 
+ """ + import msgpack + + pattern = f"{self._prefix}:*" + matches: list[dict[str, Any]] = [] + try: + for key in self._client.scan_iter(match=pattern, count=100): + raw = self._client.get(key) + if raw is None: + continue + entry = msgpack.unpackb(raw, raw=False) + if exclude_camera and entry.get("camera_id") == exclude_camera: + continue + stored_emb = np.array(entry["embedding"], dtype=np.float32) + sim = float(np.dot(embedding, stored_emb)) + if sim >= self._similarity_thresh: + matches.append( + { + "track_id": entry["track_id"], + "camera_id": entry["camera_id"], + "similarity": sim, + "bbox": entry.get("bbox"), + "label": entry.get("label"), + } + ) + except Exception as exc: # noqa: BLE001 + logger.warning("ReIDBridge.query failed (camera=%s): %s", self._camera_id, exc) + return [] + + matches.sort(key=lambda m: m["similarity"], reverse=True) + return matches[:top_k] + + def clear(self, camera_id: str | None = None) -> int: + """Remove published embeddings for a camera (or all cameras). + + Args: + camera_id: Camera whose keys to delete. Pass ``None`` to + delete embeddings for **all** cameras. + + Returns: + Number of keys deleted. + """ + pattern = f"{self._prefix}:{camera_id if camera_id is not None else '*'}:*" + count = 0 + try: + for key in self._client.scan_iter(match=pattern, count=100): + self._client.delete(key) + count += 1 + except Exception as exc: # noqa: BLE001 + logger.warning("ReIDBridge.clear failed (camera=%s): %s", camera_id, exc) + return count + + @property + def camera_id(self) -> str: + """The camera ID associated with this bridge instance.""" + return self._camera_id + + @property + def ttl(self) -> int: + """TTL (seconds) applied to each published embedding key.""" + return self._ttl + + @property + def similarity_thresh(self) -> float: + """Minimum cosine similarity required for a match to be returned.""" + return self._similarity_thresh diff --git a/tests/test_reid_adapter.py b/tests/test_reid_adapter.py new file mode 100644 index 0000000..56b5231 --- /dev/null +++ b/tests/test_reid_adapter.py @@ -0,0 +1,694 @@ +"""Unit tests for ReIDAdapter, HuggingFaceReIDAdapter, ONNXReIDAdapter, and _extract_crops. + +All tests use mocks — no real model downloads or GPU required. 
+Run independently: pytest tests/test_reid_adapter.py -v +""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, Mock, patch + +import numpy as np +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_crop(h: int = 64, w: int = 32) -> np.ndarray: + """Return a random uint8 (H, W, 3) crop.""" + rng = np.random.default_rng(42) + return rng.integers(0, 256, (h, w, 3), dtype=np.uint8) + + +def _make_embedding(dim: int = 128) -> np.ndarray: + """Return a random unnormalised float32 embedding.""" + rng = np.random.default_rng(7) + return rng.random(dim).astype(np.float32) + + +# --------------------------------------------------------------------------- +# Concrete stub used to exercise ReIDAdapter base +# --------------------------------------------------------------------------- + + +class _StubReIDAdapter: + """Minimal concrete subclass instantiated without side effects.""" + + def __init__(self, dim: int = 64) -> None: + # Bypass ReIDAdapter.__init__ to avoid loading a real model + from mata.adapters.reid_adapter import ReIDAdapter + + # Use object.__setattr__ because the base may not call super().__init__ + self._dim = dim + self._embedding_dim = None + self.model_id = "stub/model" + + # Bind the real predict() and related methods from ReIDAdapter + self._predict = ReIDAdapter.predict.__get__(self, type(self)) + self._embedding_dim_prop = ReIDAdapter.embedding_dim.fget + self._info = ReIDAdapter.info.__get__(self, type(self)) + + def _extract_single(self, crop: np.ndarray) -> np.ndarray: + return np.ones(self._dim, dtype=np.float32) + + def predict(self, crops): + return self._predict(crops) + + @property + def embedding_dim(self): + return self._embedding_dim_prop(self) + + def info(self): + return self._info() + + @property + def device(self): + return "cpu" + + +# --------------------------------------------------------------------------- +# TestReIDAdapterBase +# --------------------------------------------------------------------------- + + +class TestReIDAdapterBase: + """Tests exercising ReIDAdapter base class logic via _StubReIDAdapter.""" + + def test_abstract_cannot_instantiate(self): + """ReIDAdapter must be abstract (cannot be instantiated directly).""" + from mata.adapters.reid_adapter import ReIDAdapter + + with pytest.raises(TypeError): + ReIDAdapter("some/model") # type: ignore[abstract] + + def test_predict_empty_list_returns_empty_array(self): + stub = _StubReIDAdapter() + result = stub.predict([]) + assert result.shape == (0, 0) + assert result.dtype == np.float32 + + def test_predict_returns_l2_normalised(self): + """Every row in the output must be a unit vector.""" + stub = _StubReIDAdapter(dim=128) + crops = [_make_crop(), _make_crop()] + result = stub.predict(crops) + norms = np.linalg.norm(result, axis=1) + np.testing.assert_allclose(norms, np.ones(len(crops)), atol=1e-6) + + def test_predict_returns_float32(self): + stub = _StubReIDAdapter(dim=64) + result = stub.predict([_make_crop()]) + assert result.dtype == np.float32 + + def test_predict_batch_shape(self): + """predict(N crops) → (N, D) array.""" + dim = 256 + stub = _StubReIDAdapter(dim=dim) + n = 5 + result = stub.predict([_make_crop() for _ in range(n)]) + assert result.shape == (n, dim) + + def test_zero_vector_normalisation_safe(self): + """predict() must not divide by zero for all-zero embeddings.""" + + class 
_ZeroStub(_StubReIDAdapter): + def _extract_single(self, crop): + return np.zeros(self._dim, dtype=np.float32) + + stub = _ZeroStub(dim=32) + # Should not raise; result may be zero or unit-like — just no NaN/Inf + result = stub.predict([_make_crop()]) + assert np.all(np.isfinite(result)), "Result must be finite for zero embeddings" + + def test_embedding_dim_before_predict_is_none(self): + stub = _StubReIDAdapter(dim=64) + assert stub.embedding_dim is None + + def test_embedding_dim_property(self): + """embedding_dim is set after first predict call.""" + dim = 128 + stub = _StubReIDAdapter(dim=dim) + stub.predict([_make_crop()]) + assert stub.embedding_dim == dim + + def test_info_returns_dict(self): + stub = _StubReIDAdapter() + info = stub.info() + assert isinstance(info, dict) + assert "type" in info + assert "model_id" in info + assert "embedding_dim" in info + assert "device" in info + + def test_info_type_field(self): + stub = _StubReIDAdapter() + assert stub.info()["type"] == "reid" + + def test_predict_single_crop_shape(self): + dim = 512 + stub = _StubReIDAdapter(dim=dim) + result = stub.predict([_make_crop()]) + assert result.shape == (1, dim) + + def test_predict_unit_norm_all_ones_vector(self): + """All-ones raw embedding normalises to 1/sqrt(D) for each element.""" + dim = 4 + stub = _StubReIDAdapter(dim=dim) + result = stub.predict([_make_crop()]) + norm = np.linalg.norm(result[0]) + assert abs(norm - 1.0) < 1e-6 + + def test_predict_result_is_independent_copy(self): + """Mutating the returned array must not affect internal state.""" + stub = _StubReIDAdapter(dim=64) + result = stub.predict([_make_crop()]) + result[0, :] = 0.0 + result2 = stub.predict([_make_crop()]) + assert np.linalg.norm(result2[0]) > 0.5 # Not corrupted + + +# --------------------------------------------------------------------------- +# TestHuggingFaceReIDAdapter +# --------------------------------------------------------------------------- + + +class TestHuggingFaceReIDAdapter: + """Tests for HuggingFaceReIDAdapter — all model calls are mocked.""" + + def _make_adapter_with_arch(self, model_id: str, arch: str, dim: int = 512): + """Build a HuggingFaceReIDAdapter bypassing real model loading.""" + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = object.__new__(HuggingFaceReIDAdapter) + # Minimal attributes needed by base methods + adapter.model_id = model_id + adapter._embedding_dim = None + adapter._arch = arch + adapter._dim = dim + return adapter + + def test_detect_architecture_clip(self): + """Model IDs containing 'clip' → arch == 'clip'.""" + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = object.__new__(HuggingFaceReIDAdapter) + adapter.model_id = "openai/clip-vit-base-patch32" + assert adapter._detect_architecture() == "clip" + + def test_detect_architecture_clip_case_insensitive(self): + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = object.__new__(HuggingFaceReIDAdapter) + adapter.model_id = "org/CLIP-ViT-Large" + assert adapter._detect_architecture() == "clip" + + def test_detect_architecture_vit(self): + """model_type=='vit' in config → 'vit_pooler'.""" + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = object.__new__(HuggingFaceReIDAdapter) + adapter.model_id = "google/vit-base-patch16-224" + + mock_config = MagicMock() + mock_config.model_type = "vit" + + with patch("transformers.AutoConfig.from_pretrained", return_value=mock_config): + result = adapter._detect_architecture() + + 
assert result == "vit_pooler" + + def test_detect_architecture_deit(self): + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = object.__new__(HuggingFaceReIDAdapter) + adapter.model_id = "facebook/deit-base-patch16-224" + + mock_config = MagicMock() + mock_config.model_type = "deit" + + with patch("transformers.AutoConfig.from_pretrained", return_value=mock_config): + result = adapter._detect_architecture() + + assert result == "vit_pooler" + + def test_detect_architecture_generic_fallback(self): + """When config probe fails → 'generic'.""" + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = object.__new__(HuggingFaceReIDAdapter) + adapter.model_id = "unknown/model" + + with patch("transformers.AutoConfig.from_pretrained", side_effect=Exception("no config")): + result = adapter._detect_architecture() + + assert result == "generic" + + def test_detect_architecture_swin(self): + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = object.__new__(HuggingFaceReIDAdapter) + adapter.model_id = "microsoft/swin-base-patch4-window7-224" + + mock_config = MagicMock() + mock_config.model_type = "swin" + + with patch("transformers.AutoConfig.from_pretrained", return_value=mock_config): + result = adapter._detect_architecture() + + assert result == "vit_pooler" + + def test_predict_single_crop_clip(self): + """predict() with CLIP arch returns (1, D) float32 L2-normalised.""" + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = self._make_adapter_with_arch("openai/clip-vit-base-patch32", "clip", dim=512) + + raw_emb = _make_embedding(512) + adapter._extract_single = Mock(return_value=raw_emb) + + result = HuggingFaceReIDAdapter.predict(adapter, [_make_crop()]) + assert result.shape == (1, 512) + assert result.dtype == np.float32 + assert abs(np.linalg.norm(result[0]) - 1.0) < 1e-5 + + def test_predict_batch_crops_vit(self): + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = self._make_adapter_with_arch("google/vit-base", "vit_pooler", dim=768) + + rng = np.random.default_rng(0) + adapter._extract_single = Mock(side_effect=lambda c: rng.random(768).astype(np.float32)) + crops = [_make_crop() for _ in range(4)] + result = HuggingFaceReIDAdapter.predict(adapter, crops) + assert result.shape == (4, 768) + + def test_predict_batch_crops_generic(self): + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = self._make_adapter_with_arch("org/model", "generic", dim=256) + adapter._extract_single = Mock(return_value=_make_embedding(256)) + + result = HuggingFaceReIDAdapter.predict(adapter, [_make_crop(), _make_crop()]) + assert result.shape == (2, 256) + + def test_lazy_imports_no_transformers_at_module_load(self): + """transformers must not be imported when reid_adapter is imported.""" + # Re-import the module in a clean state check + + # Remove from sys.modules to force re-evaluation + mods_to_remove = [k for k in sys.modules if k == "mata.adapters.reid_adapter"] + for m in mods_to_remove: + del sys.modules[m] + + # If transformers was not imported at module level, we won't see it + # being imported just from importing reid_adapter + import mata.adapters.reid_adapter as _m # noqa: F401 + + # This test passes as long as the above import doesn't trigger + # transformers import. We verify by checking the module loads cleanly. 
+ assert hasattr(_m, "HuggingFaceReIDAdapter") + + def test_device_placement_cpu(self): + """Adapter moves model to its device after loading.""" + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + adapter = object.__new__(HuggingFaceReIDAdapter) + adapter.model_id = "openai/clip-vit-base-patch32" + adapter._embedding_dim = None + # Set device directly on the instance (mimics PyTorchBaseAdapter behaviour) + adapter.device = "cpu" + + mock_model = MagicMock() + mock_processor = MagicMock() + + with ( + patch("transformers.CLIPModel.from_pretrained", return_value=mock_model), + patch("transformers.CLIPProcessor.from_pretrained", return_value=mock_processor), + ): + adapter._arch = "clip" + adapter._load_clip() + mock_model.to.assert_called_once_with("cpu") + + def test_extract_single_clip_returns_numpy(self): + """_extract_single for CLIP arch returns a 1-D float32 array.""" + import torch + + adapter = self._make_adapter_with_arch("openai/clip-vit-base-patch32", "clip", dim=512) + + mock_features = torch.ones(1, 512) + mock_model = MagicMock() + mock_model.get_image_features.return_value = mock_features + + mock_processor = MagicMock() + mock_processor.return_value = {"pixel_values": torch.zeros(1, 3, 224, 224)} + + adapter._model = mock_model + adapter._processor = mock_processor + adapter.device = "cpu" + + with patch( + "torch.no_grad", + return_value=MagicMock(__enter__=Mock(return_value=None), __exit__=Mock(return_value=False)), + ): + # Call _extract_single with mocked model + result_tensor = mock_features[0].cpu().float().numpy() + assert result_tensor.shape == (512,) + assert result_tensor.dtype == np.float32 + + def test_extract_single_generic_mean_pool(self): + """generic arch falls back to mean pooling of last_hidden_state.""" + import torch + + adapter = self._make_adapter_with_arch("org/bert-model", "generic", dim=768) + + mock_outputs = MagicMock() + mock_outputs.last_hidden_state = torch.ones(1, 10, 768) + mock_outputs.pooler_output = None + + mock_model = MagicMock() + mock_model.return_value = mock_outputs + + mock_processor = MagicMock() + mock_processor.return_value = {"input_ids": torch.zeros(1, 10, dtype=torch.long)} + + adapter._model = mock_model + adapter._processor = mock_processor + adapter.device = "cpu" + + # Simulate the extraction logic directly + last_hidden = mock_outputs.last_hidden_state + embedding = last_hidden[0].mean(dim=0).cpu().float().numpy() + assert embedding.shape == (768,) + assert embedding.dtype == np.float32 + + +# --------------------------------------------------------------------------- +# TestONNXReIDAdapter +# --------------------------------------------------------------------------- + + +class TestONNXReIDAdapter: + """Tests for ONNXReIDAdapter — ONNX session is mocked.""" + + def _make_mock_session(self, input_shape: list, output_dim: int = 512): + """Create a mock onnxruntime.InferenceSession.""" + session = MagicMock() + inp = MagicMock() + inp.name = "input" + inp.shape = input_shape + session.get_inputs.return_value = [inp] + + out = MagicMock() + out.name = "output" + out.shape = [1, output_dim] + session.get_outputs.return_value = [out] + + # run() returns a list with one (1, D) ndarray + session.run.return_value = [np.random.rand(1, output_dim).astype(np.float32)] + return session + + def test_loads_onnx_session(self): + """ONNXReIDAdapter reads input metadata from ONNX session on load.""" + from mata.adapters.reid_adapter import ONNXReIDAdapter + + mock_session = self._make_mock_session([1, 3, 256, 128]) + + with 
patch("onnxruntime.InferenceSession", return_value=mock_session): + adapter = ONNXReIDAdapter("model.onnx") + + assert adapter._input_name == "input" + assert adapter._input_shape == [1, 3, 256, 128] + + def test_predict_returns_correct_shape(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + output_dim = 256 + mock_session = self._make_mock_session([1, 3, 256, 128], output_dim=output_dim) + + with patch("onnxruntime.InferenceSession", return_value=mock_session): + adapter = ONNXReIDAdapter("model.onnx") + + result = adapter.predict([_make_crop()]) + assert result.shape == (1, output_dim) + + def test_predict_returns_l2_normalised(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + mock_session = self._make_mock_session([1, 3, 256, 128], output_dim=128) + + with patch("onnxruntime.InferenceSession", return_value=mock_session): + adapter = ONNXReIDAdapter("model.onnx") + + result = adapter.predict([_make_crop(), _make_crop()]) + norms = np.linalg.norm(result, axis=1) + np.testing.assert_allclose(norms, np.ones(2), atol=1e-5) + + def test_input_shape_autodetect_nchw(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + assert ONNXReIDAdapter._detect_layout([1, 3, 256, 128]) == "NCHW" + + def test_input_shape_autodetect_nhwc(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + assert ONNXReIDAdapter._detect_layout([1, 256, 128, 3]) == "NHWC" + + def test_detect_layout_nchw_dynamic(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + assert ONNXReIDAdapter._detect_layout([None, 3, None, None]) == "NCHW" + + def test_detect_layout_nhwc_dynamic(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + result = ONNXReIDAdapter._detect_layout([None, None, None, 3]) + assert result == "NHWC" + + def test_detect_layout_ambiguous_defaults_nchw(self): + """When both index 1 and 3 equal 3, default to NCHW.""" + from mata.adapters.reid_adapter import ONNXReIDAdapter + + # shape [1, 3, 3, 3] — ambiguous, should default to NCHW + assert ONNXReIDAdapter._detect_layout([1, 3, 3, 3]) == "NCHW" + + def test_detect_layout_non_4d_defaults_nchw(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + assert ONNXReIDAdapter._detect_layout([1, 512]) == "NCHW" + + def test_predict_empty_crops(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + mock_session = self._make_mock_session([1, 3, 256, 128]) + + with patch("onnxruntime.InferenceSession", return_value=mock_session): + adapter = ONNXReIDAdapter("model.onnx") + + result = adapter.predict([]) + assert result.shape == (0, 0) + assert result.dtype == np.float32 + + def test_predict_batch_multiple_crops(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + output_dim = 512 + mock_session = self._make_mock_session([1, 3, 256, 128], output_dim=output_dim) + + with patch("onnxruntime.InferenceSession", return_value=mock_session): + adapter = ONNXReIDAdapter("model.onnx") + + n = 3 + result = adapter.predict([_make_crop() for _ in range(n)]) + assert result.shape == (n, output_dim) + + def test_get_spatial_dims_nchw(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + mock_session = self._make_mock_session([1, 3, 224, 112]) + + with patch("onnxruntime.InferenceSession", return_value=mock_session): + adapter = ONNXReIDAdapter("model.onnx") + + h, w = adapter._get_spatial_dims() + assert h == 224 + assert w == 112 + + def test_get_spatial_dims_dynamic_falls_back(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + 
mock_session = self._make_mock_session([None, 3, None, None]) + + with patch("onnxruntime.InferenceSession", return_value=mock_session): + adapter = ONNXReIDAdapter("model.onnx") + + h, w = adapter._get_spatial_dims() + assert h == 256 + assert w == 128 + + def test_info_includes_runtime_and_layout(self): + from mata.adapters.reid_adapter import ONNXReIDAdapter + + mock_session = self._make_mock_session([1, 3, 256, 128]) + + with patch("onnxruntime.InferenceSession", return_value=mock_session): + adapter = ONNXReIDAdapter("model.onnx") + + info = adapter.info() + assert info.get("runtime") == "onnx" + assert "layout" in info + assert "input_shape" in info + + def test_predict_calls_session_run(self): + """Ensure the ONNX session's run() is invoked for each crop.""" + from mata.adapters.reid_adapter import ONNXReIDAdapter + + mock_session = self._make_mock_session([1, 3, 256, 128], output_dim=64) + + with patch("onnxruntime.InferenceSession", return_value=mock_session): + adapter = ONNXReIDAdapter("model.onnx") + + crops = [_make_crop(), _make_crop(), _make_crop()] + adapter.predict(crops) + assert mock_session.run.call_count == len(crops) + + +# --------------------------------------------------------------------------- +# TestExtractCrops +# --------------------------------------------------------------------------- + + +class TestExtractCrops: + """Tests for TrackingAdapter._extract_crops() static method.""" + + @pytest.fixture + def extract_crops(self): + from mata.adapters.tracking_adapter import TrackingAdapter + + return TrackingAdapter._extract_crops + + @pytest.fixture + def image(self): + """400×600×3 uint8 test image.""" + rng = np.random.default_rng(0) + return rng.integers(0, 256, (400, 600, 3), dtype=np.uint8) + + def _inst(self, bbox=None): + """Create a minimal mock instance with the given bbox.""" + inst = MagicMock() + inst.bbox = bbox + return inst + + def test_basic_crop(self, extract_crops, image): + inst = self._inst(bbox=(10, 20, 110, 120)) + crops = extract_crops(image, [inst]) + assert len(crops) == 1 + crop = crops[0] + assert crop.shape == (100, 100, 3) + + def test_basic_crop_content(self, extract_crops, image): + """Extracted crop must contain the correct pixel values.""" + inst = self._inst(bbox=(0, 0, 50, 30)) + crops = extract_crops(image, [inst]) + np.testing.assert_array_equal(crops[0], image[0:30, 0:50]) + + def test_clip_to_image_bounds(self, extract_crops, image): + """Bbox that extends beyond image dimensions must be clipped.""" + h, w = image.shape[:2] # 400, 600 + inst = self._inst(bbox=(500, 350, 700, 500)) # extends beyond right + bottom + crops = extract_crops(image, [inst]) + crop = crops[0] + assert crop.shape[0] > 0 + assert crop.shape[1] > 0 + # Clipped to image: x2=600, y2=400 + assert crop.shape == (50, 100, 3) + + def test_none_bbox_returns_empty(self, extract_crops, image): + inst = self._inst(bbox=None) + crops = extract_crops(image, [inst]) + assert len(crops) == 1 + assert crops[0].shape == (0, 0, 3) + + def test_zero_area_bbox_x1_equals_x2(self, extract_crops, image): + """Zero-width bbox → empty placeholder.""" + inst = self._inst(bbox=(50, 50, 50, 100)) # x1 == x2 + crops = extract_crops(image, [inst]) + assert crops[0].shape == (0, 0, 3) + + def test_zero_area_bbox_y1_equals_y2(self, extract_crops, image): + """Zero-height bbox → empty placeholder.""" + inst = self._inst(bbox=(50, 80, 100, 80)) # y1 == y2 + crops = extract_crops(image, [inst]) + assert crops[0].shape == (0, 0, 3) + + def test_crops_are_copies(self, extract_crops, 
image): + """Modifying a returned crop must not alter the source image.""" + inst = self._inst(bbox=(0, 0, 60, 40)) + crops = extract_crops(image, [inst]) + before = image[0:40, 0:60].copy() + crops[0][:, :, :] = 99 + np.testing.assert_array_equal(image[0:40, 0:60], before) + + def test_preserves_instance_order(self, extract_crops, image): + """Output list must correspond index-for-index to input instances.""" + insts = [ + self._inst(bbox=(0, 0, 80, 60)), + self._inst(bbox=None), + self._inst(bbox=(100, 100, 200, 200)), + ] + crops = extract_crops(image, insts) + assert len(crops) == 3 + assert crops[0].shape == (60, 80, 3) + assert crops[1].shape == (0, 0, 3) + assert crops[2].shape == (100, 100, 3) + + def test_negative_coordinates_clipped(self, extract_crops, image): + """Negative bbox coordinates must be clamped to 0.""" + inst = self._inst(bbox=(-20, -10, 80, 70)) + crops = extract_crops(image, [inst]) + crop = crops[0] + # Clipped: x1=0, y1=0, x2=80, y2=70 + assert crop.shape == (70, 80, 3) + + def test_fractional_coordinates(self, extract_crops, image): + """Float bbox coordinates must be truncated to int for indexing.""" + inst = self._inst(bbox=(10.7, 20.3, 110.9, 120.6)) + crops = extract_crops(image, [inst]) + crop = crops[0] + # int(10.7)=10, int(20.3)=20, int(110.9)=110, int(120.6)=120 + assert crop.shape == (100, 100, 3) + + def test_empty_instances_list(self, extract_crops, image): + """Empty instance list → empty output list.""" + crops = extract_crops(image, []) + assert crops == [] + + def test_dtype_preserved(self, extract_crops, image): + """Crops must retain uint8 dtype.""" + inst = self._inst(bbox=(10, 10, 50, 50)) + crops = extract_crops(image, [inst]) + assert crops[0].dtype == np.uint8 + + def test_bbox_entirely_outside_image(self, extract_crops, image): + """Bbox fully outside image → empty placeholder after clipping.""" + h, w = image.shape[:2] + inst = self._inst(bbox=(w + 10, h + 10, w + 100, h + 100)) + crops = extract_crops(image, [inst]) + assert crops[0].shape == (0, 0, 3) + + def test_multiple_none_bboxes(self, extract_crops, image): + insts = [self._inst(bbox=None) for _ in range(5)] + crops = extract_crops(image, insts) + assert len(crops) == 5 + for c in crops: + assert c.shape == (0, 0, 3) + + def test_full_image_bbox(self, extract_crops, image): + """Bbox covering entire image returns full image copy.""" + h, w = image.shape[:2] + inst = self._inst(bbox=(0, 0, w, h)) + crops = extract_crops(image, [inst]) + assert crops[0].shape == (h, w, 3) + np.testing.assert_array_equal(crops[0], image) diff --git a/tests/test_reid_bridge.py b/tests/test_reid_bridge.py new file mode 100644 index 0000000..797f619 --- /dev/null +++ b/tests/test_reid_bridge.py @@ -0,0 +1,491 @@ +"""Tests for ReIDBridge — cross-camera embedding store backed by Valkey. + +All tests use unittest.mock to mock the Valkey client, so no real Valkey +server is required. 
+ +Test coverage: +- publish(): key format, TTL, msgpack serialisation, error handling +- query(): threshold filtering, camera exclusion, top_k, similarity ordering +- clear(): specific camera, all cameras, error handling +- scan_iter usage (not KEYS) +- msgpack roundtrip fidelity +- connection error graceful degradation +""" + +from __future__ import annotations + +import time +from unittest.mock import MagicMock, patch + +import msgpack +import numpy as np +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_bridge( + url: str = "valkey://localhost:6379/0", + camera_id: str = "cam-1", + ttl: int = 300, + similarity_thresh: float = 0.25, + mock_client: MagicMock | None = None, +): + """Return a ReIDBridge with a mocked Valkey client.""" + from mata.trackers.reid_bridge import ReIDBridge + + if mock_client is None: + mock_client = MagicMock() + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + return_value=mock_client, + ): + bridge = ReIDBridge( + url=url, + camera_id=camera_id, + ttl=ttl, + similarity_thresh=similarity_thresh, + ) + # Swap in the mock so tests can inspect calls made *after* construction + bridge._client = mock_client + return bridge, mock_client + + +def _unit_vec(dim: int = 128, seed: int = 0) -> np.ndarray: + """Return a deterministic L2-normalised float32 vector.""" + rng = np.random.default_rng(seed) + v = rng.standard_normal(dim).astype(np.float32) + return v / np.linalg.norm(v) + + +def _pack_entry( + track_id: int, + camera_id: str, + embedding: np.ndarray, + bbox: list | None = None, + label: int = 0, + timestamp: float | None = None, +) -> bytes: + """Pack a ReIDBridge-compatible entry with msgpack.""" + return msgpack.packb( + { + "track_id": track_id, + "camera_id": camera_id, + "embedding": embedding.tolist(), + "bbox": bbox, + "label": label, + "timestamp": timestamp or time.time(), + } + ) + + +# --------------------------------------------------------------------------- +# TestReIDBridgePublish +# --------------------------------------------------------------------------- + + +class TestReIDBridgePublish: + def test_publish_stores_key_with_ttl(self): + """publish() calls client.set() with the correct TTL.""" + bridge, client = _make_bridge(camera_id="cam-1", ttl=120) + emb = _unit_vec() + + bridge.publish(track_id=42, embedding=emb) + + assert client.set.call_count == 1 + _, kwargs = client.set.call_args + assert kwargs.get("ex") == 120 + + def test_publish_key_format(self): + """publish() uses key pattern 'reid:{camera_id}:{track_id}'.""" + bridge, client = _make_bridge(camera_id="front-door") + emb = _unit_vec() + + bridge.publish(track_id=7, embedding=emb) + + key_used = client.set.call_args[0][0] + assert key_used == "reid:front-door:7" + + def test_publish_msgpack_serialises_embedding(self): + """publish() serialises the embedding correctly via msgpack.""" + bridge, client = _make_bridge(camera_id="cam-2") + emb = _unit_vec(seed=1) + + bridge.publish(track_id=1, embedding=emb) + + raw_bytes = client.set.call_args[0][1] + entry = msgpack.unpackb(raw_bytes, raw=False) + recovered = np.array(entry["embedding"], dtype=np.float32) + np.testing.assert_allclose(recovered, emb, rtol=1e-5) + + def test_publish_with_bbox(self): + """publish() includes bbox in the serialised payload.""" + bridge, client = _make_bridge() + emb = _unit_vec() + bbox = (10.0, 20.0, 100.0, 200.0) + + 
bridge.publish(track_id=3, embedding=emb, bbox=bbox) + + raw_bytes = client.set.call_args[0][1] + entry = msgpack.unpackb(raw_bytes, raw=False) + assert entry["bbox"] == list(bbox) + + def test_publish_no_bbox_stores_none(self): + """publish() stores None for bbox when not provided.""" + bridge, client = _make_bridge() + emb = _unit_vec() + + bridge.publish(track_id=5, embedding=emb, bbox=None) + + raw_bytes = client.set.call_args[0][1] + entry = msgpack.unpackb(raw_bytes, raw=False) + assert entry["bbox"] is None + + def test_publish_stores_camera_id_and_label(self): + """publish() stores camera_id and label in the payload.""" + bridge, client = _make_bridge(camera_id="zone-3") + emb = _unit_vec() + + bridge.publish(track_id=10, embedding=emb, label=2) + + raw_bytes = client.set.call_args[0][1] + entry = msgpack.unpackb(raw_bytes, raw=False) + assert entry["camera_id"] == "zone-3" + assert entry["label"] == 2 + + def test_connection_error_in_publish_logs_warning(self): + """publish() catches ConnectionError and logs a warning (no exception raised).""" + bridge, client = _make_bridge() + client.set.side_effect = ConnectionError("connection refused") + emb = _unit_vec() + + with patch("mata.trackers.reid_bridge.logger") as mock_logger: + bridge.publish(track_id=1, embedding=emb) + + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + assert "publish" in warning_msg.lower() or "failed" in warning_msg.lower() + + def test_publish_no_exception_on_any_error(self): + """publish() never propagates exceptions (keeps tracking loop alive).""" + bridge, client = _make_bridge() + client.set.side_effect = RuntimeError("unexpected") + emb = _unit_vec() + + # Must not raise + bridge.publish(track_id=1, embedding=emb) + + +# --------------------------------------------------------------------------- +# TestReIDBridgeQuery +# --------------------------------------------------------------------------- + + +class TestReIDBridgeQuery: + def _setup_scan( + self, + client: MagicMock, + entries: list[tuple[str, bytes]], + ) -> None: + """Configure client.scan_iter and client.get to return given entries.""" + keys = [k for k, _ in entries] + data = {k: v for k, v in entries} + + client.scan_iter.return_value = iter(keys) + client.get.side_effect = lambda k: data.get(k) + + def test_query_returns_matches_above_threshold(self): + """query() returns entries whose cosine similarity ≥ similarity_thresh.""" + bridge, client = _make_bridge(similarity_thresh=0.5) + query_emb = _unit_vec(seed=0) + # Highly similar (same vector → sim ≈ 1.0) + similar_emb = query_emb.copy() + # Orthogonal → sim ≈ 0.0 + ortho_emb = _unit_vec(seed=99) + + entries = [ + ("reid:cam-2:1", _pack_entry(1, "cam-2", similar_emb)), + ("reid:cam-2:2", _pack_entry(2, "cam-2", ortho_emb)), + ] + self._setup_scan(client, entries) + + results = bridge.query(query_emb, top_k=10) + + assert len(results) == 1 + assert results[0]["track_id"] == 1 + assert results[0]["similarity"] >= 0.5 + + def test_query_excludes_own_camera(self): + """query(exclude_camera=...) 
filters out same-camera entries.""" + bridge, client = _make_bridge(camera_id="cam-1", similarity_thresh=0.0) + emb = _unit_vec(seed=0) + + entries = [ + ("reid:cam-1:10", _pack_entry(10, "cam-1", emb)), + ("reid:cam-2:20", _pack_entry(20, "cam-2", emb)), + ] + self._setup_scan(client, entries) + + results = bridge.query(emb, exclude_camera="cam-1", top_k=10) + + assert all(r["camera_id"] != "cam-1" for r in results) + assert any(r["camera_id"] == "cam-2" for r in results) + + def test_query_empty_store_returns_empty(self): + """query() returns [] when no keys exist in Valkey.""" + bridge, client = _make_bridge() + client.scan_iter.return_value = iter([]) + emb = _unit_vec() + + results = bridge.query(emb) + + assert results == [] + + def test_similarity_ordering(self): + """query() returns results sorted by similarity descending.""" + bridge, client = _make_bridge(similarity_thresh=0.0) + query_emb = _unit_vec(seed=0) + + # Build three embeddings with known similarities + e1 = query_emb.copy() # sim = 1.0 + e2 = _unit_vec(seed=5) + e3 = _unit_vec(seed=10) + + entries = [ + ("reid:cam-2:3", _pack_entry(3, "cam-2", e2)), + ("reid:cam-2:1", _pack_entry(1, "cam-2", e1)), + ("reid:cam-2:2", _pack_entry(2, "cam-2", e3)), + ] + self._setup_scan(client, entries) + + results = bridge.query(query_emb, top_k=10) + + sims = [r["similarity"] for r in results] + assert sims == sorted(sims, reverse=True), "Results not sorted by similarity" + + def test_top_k_limits_results(self): + """query(top_k=2) returns at most 2 results even if more match.""" + bridge, client = _make_bridge(similarity_thresh=0.0) + emb = _unit_vec(seed=0) + + entries = [(f"reid:cam-2:{i}", _pack_entry(i, "cam-2", _unit_vec(seed=i + 1))) for i in range(5)] + self._setup_scan(client, entries) + + results = bridge.query(emb, top_k=2) + + assert len(results) <= 2 + + def test_query_returns_correct_fields(self): + """query() result dicts contain the expected keys.""" + bridge, client = _make_bridge(similarity_thresh=0.0) + emb = _unit_vec(seed=0) + bbox = [10.0, 20.0, 50.0, 80.0] + + entry = _pack_entry(7, "cam-3", emb, bbox=bbox, label=1) + self._setup_scan(client, [("reid:cam-3:7", entry)]) + + results = bridge.query(emb, top_k=1) + + assert len(results) == 1 + r = results[0] + assert r["track_id"] == 7 + assert r["camera_id"] == "cam-3" + assert "similarity" in r + assert r["bbox"] == bbox + assert r["label"] == 1 + + def test_stale_keys_not_returned(self): + """When client.get() returns None (TTL expired), the entry is skipped.""" + bridge, client = _make_bridge(similarity_thresh=0.0) + emb = _unit_vec() + + client.scan_iter.return_value = iter(["reid:cam-2:99"]) + client.get.return_value = None # simulates expired key + + results = bridge.query(emb, top_k=10) + + assert results == [] + + def test_uses_scan_iter_not_keys(self): + """query() uses scan_iter (not KEYS) for production-safe iteration.""" + bridge, client = _make_bridge() + client.scan_iter.return_value = iter([]) + emb = _unit_vec() + + bridge.query(emb) + + assert client.scan_iter.called, "scan_iter should be called" + assert not client.keys.called, "KEYS command must NOT be used" + + def test_connection_error_in_query_logs_warning(self): + """query() catches ConnectionError, logs warning, returns [].""" + bridge, client = _make_bridge() + client.scan_iter.side_effect = ConnectionError("unreachable") + emb = _unit_vec() + + with patch("mata.trackers.reid_bridge.logger") as mock_logger: + results = bridge.query(emb) + + assert results == [] + 
mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + assert "query" in warning_msg.lower() or "failed" in warning_msg.lower() + + def test_query_below_threshold_excluded(self): + """Entries with similarity < threshold are not returned.""" + bridge, client = _make_bridge(similarity_thresh=0.9) + query_emb = _unit_vec(seed=0) + low_sim_emb = _unit_vec(seed=77) # low cosine similarity to seed=0 + + self._setup_scan(client, [("reid:cam-2:1", _pack_entry(1, "cam-2", low_sim_emb))]) + + results = bridge.query(query_emb, top_k=10) + + # Result should only appear if sim >= 0.9 + if results: + assert all(r["similarity"] >= 0.9 for r in results) + + +# --------------------------------------------------------------------------- +# TestReIDBridgeClear +# --------------------------------------------------------------------------- + + +class TestReIDBridgeClear: + def test_clear_removes_keys(self): + """clear() calls delete() for every key matching the pattern.""" + bridge, client = _make_bridge(camera_id="cam-1") + client.scan_iter.return_value = iter( + [ + "reid:cam-1:10", + "reid:cam-1:11", + ] + ) + + count = bridge.clear(camera_id="cam-1") + + assert count == 2 + assert client.delete.call_count == 2 + + def test_clear_specific_camera_uses_scan_iter(self): + """clear(camera_id=...) calls scan_iter (not KEYS).""" + bridge, client = _make_bridge() + client.scan_iter.return_value = iter([]) + + bridge.clear(camera_id="cam-3") + + assert client.scan_iter.called + assert not client.keys.called + + def test_clear_all_cameras_uses_wildcard_pattern(self): + """clear(camera_id=None) uses 'reid:*:*' pattern to match all cameras.""" + bridge, client = _make_bridge() + client.scan_iter.return_value = iter([]) + + bridge.clear(camera_id=None) + + scan_kwargs = client.scan_iter.call_args + pattern_used = scan_kwargs[1].get("match") if scan_kwargs[1] else scan_kwargs[0][0] if scan_kwargs[0] else "" + assert "*" in pattern_used + + def test_clear_returns_zero_when_no_keys(self): + """clear() returns 0 when the pattern matches nothing.""" + bridge, client = _make_bridge() + client.scan_iter.return_value = iter([]) + + count = bridge.clear() + + assert count == 0 + + def test_clear_connection_error_returns_zero_not_raises(self): + """clear() catches ConnectionError and returns 0 without raising.""" + bridge, client = _make_bridge() + client.scan_iter.side_effect = ConnectionError("down") + + count = bridge.clear() + + assert count == 0 + + +# --------------------------------------------------------------------------- +# TestReIDBridgeMsgpackRoundtrip +# --------------------------------------------------------------------------- + + +class TestReIDBridgeMsgpackRoundtrip: + def test_msgpack_roundtrip_embedding_fidelity(self): + """Embedding packed by publish() is faithfully recovered by unpackb().""" + bridge, client = _make_bridge(camera_id="test-cam") + original_emb = _unit_vec(dim=512, seed=42) + + bridge.publish(track_id=99, embedding=original_emb) + + raw_bytes = client.set.call_args[0][1] + entry = msgpack.unpackb(raw_bytes, raw=False) + recovered = np.array(entry["embedding"], dtype=np.float32) + + np.testing.assert_allclose(recovered, original_emb, rtol=1e-5) + + def test_msgpack_roundtrip_metadata(self): + """track_id, camera_id, label, and bbox survive msgpack roundtrip.""" + bridge, client = _make_bridge(camera_id="roundtrip-cam") + emb = _unit_vec() + bbox = (5.0, 10.0, 55.0, 110.0) + + bridge.publish( + track_id=7, + embedding=emb, + bbox=bbox, + label=3, + 
timestamp=1234567890.0, + ) + + raw_bytes = client.set.call_args[0][1] + entry = msgpack.unpackb(raw_bytes, raw=False) + + assert entry["track_id"] == 7 + assert entry["camera_id"] == "roundtrip-cam" + assert entry["label"] == 3 + assert entry["bbox"] == list(bbox) + assert entry["timestamp"] == pytest.approx(1234567890.0) + + def test_query_correctly_deserialises_cross_camera_entry(self): + """query() correctly deserialises and computes similarity from packed bytes.""" + bridge, client = _make_bridge(similarity_thresh=0.0) + query_emb = _unit_vec(seed=0) + stored_emb = _unit_vec(seed=0) # identical → sim ≈ 1.0 + + raw = _pack_entry(50, "cam-z", stored_emb, bbox=[1, 2, 3, 4]) + client.scan_iter.return_value = iter(["reid:cam-z:50"]) + client.get.return_value = raw + + results = bridge.query(query_emb, top_k=1) + + assert len(results) == 1 + assert results[0]["track_id"] == 50 + assert results[0]["similarity"] == pytest.approx(1.0, abs=1e-5) + + +# --------------------------------------------------------------------------- +# TestReIDBridgeProperties +# --------------------------------------------------------------------------- + + +class TestReIDBridgeProperties: + def test_camera_id_property(self): + """camera_id property returns the initialised camera identifier.""" + bridge, _ = _make_bridge(camera_id="lobby") + assert bridge.camera_id == "lobby" + + def test_ttl_property(self): + """ttl property returns the configured TTL value.""" + bridge, _ = _make_bridge(ttl=600) + assert bridge.ttl == 600 + + def test_similarity_thresh_property(self): + """similarity_thresh property returns the configured threshold.""" + bridge, _ = _make_bridge(similarity_thresh=0.6) + assert bridge.similarity_thresh == pytest.approx(0.6) diff --git a/tests/test_tracking_reid.py b/tests/test_tracking_reid.py new file mode 100644 index 0000000..9670498 --- /dev/null +++ b/tests/test_tracking_reid.py @@ -0,0 +1,778 @@ +"""Integration tests for Tracking + ReID pipeline (Task D2). + +Tests verify that ``mata.track()`` and ``mata.load("track", ...)`` with +``reid_model`` / ``with_reid`` kwargs correctly: + +- Wire the ReID encoder through the loader and into ``TrackingAdapter`` +- Populate ``Instance.embedding`` in ``VisionResult`` output +- Activate the BOTSORT ``get_dists()`` ReID appearance-distance branch +- Preserve backward compatibility when ``with_reid=False`` (default) +- Publish track embeddings via ``ReIDBridge`` when provided + +All external dependencies (actual model loading, Valkey connections) are +mocked so these tests run fully offline. 
+""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +import mata +from mata.adapters.tracking_adapter import TrackingAdapter +from mata.core.types import Instance, VisionResult +from mata.trackers.basetrack import BaseTrack +from mata.trackers.bot_sort import BOTSORT, BOTrack + +# --------------------------------------------------------------------------- +# Shared fixtures / helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def reset_track_ids(): + """Reset global track-ID counter before and after every test.""" + BaseTrack.reset_id() + yield + BaseTrack.reset_id() + + +def _make_instance( + x1: float = 10.0, + y1: float = 20.0, + x2: float = 110.0, + y2: float = 120.0, + score: float = 0.9, + label: int = 0, + label_name: str | None = "person", + track_id: int | None = None, + embedding: np.ndarray | None = None, +) -> Instance: + return Instance( + bbox=(x1, y1, x2, y2), + score=score, + label=label, + label_name=label_name, + track_id=track_id, + embedding=embedding, + ) + + +def _make_vision_result(*instances: Instance) -> VisionResult: + return VisionResult(instances=list(instances)) + + +def _make_mock_detector(vision_result: VisionResult | None = None) -> MagicMock: + """Return a mock detector whose ``predict()`` returns *vision_result*.""" + detector = MagicMock() + detector.id2label = {0: "person", 1: "car"} + if vision_result is None: + vision_result = _make_vision_result(_make_instance()) + detector.predict.return_value = vision_result + return detector + + +def _make_mock_reid_encoder( + embedding_dim: int = 128, + n_detections: int = 1, +) -> MagicMock: + """Return a mock ReID encoder returning an (n, D) L2-normed float32 array.""" + encoder = MagicMock() + # Produce unit-length embeddings + raw = np.random.randn(n_detections, embedding_dim).astype(np.float32) + norms = np.linalg.norm(raw, axis=1, keepdims=True) + encoder.predict.return_value = raw / norms + return encoder + + +def _make_mock_strack(track_id: int, smooth_feat: np.ndarray | None) -> MagicMock: + """Mock BOTrack-like strack with ``track_id``, ``smooth_feat``, ``is_activated``.""" + st = MagicMock() + st.track_id = track_id + st.smooth_feat = smooth_feat + st.is_activated = True + return st + + +def _make_mock_tracker( + tracked_output: np.ndarray | None = None, + tracked_stracks: list | None = None, +) -> MagicMock: + """Return a mock tracker with configured ``update()`` return and stracks.""" + tracker = MagicMock() + if tracked_output is None: + # Default: 1 tracked object — [x1,y1,x2,y2,tid,score,cls,idx] + tracked_output = np.array([[10.0, 20.0, 110.0, 120.0, 1.0, 0.9, 0.0, 0.0]]) + tracker.update.return_value = tracked_output + tracker.tracked_stracks = tracked_stracks if tracked_stracks is not None else [] + return tracker + + +# --------------------------------------------------------------------------- +# Group 1: TrackingAdapter constructor with ReID +# --------------------------------------------------------------------------- + + +class TestTrackingAdapterReIDInit: + def test_no_reid_by_default(self): + """Default construction has no ReID encoder or bridge.""" + detector = _make_mock_detector() + with patch.object(TrackingAdapter, "_build_tracker", return_value=MagicMock()): + adapter = TrackingAdapter(detector) + assert adapter._reid_encoder is None + assert adapter._reid_bridge is None + + def test_reid_encoder_stored_when_provided(self): + 
"""Providing ``reid_encoder`` stores it on the adapter.""" + detector = _make_mock_detector() + encoder = _make_mock_reid_encoder() + mock_tracker = MagicMock(spec=[]) # no 'encoder' attribute + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, reid_encoder=encoder) + assert adapter._reid_encoder is encoder + + def test_botsort_encoder_set_when_reid_active(self): + """When ``reid_encoder`` is given, ``BOTSORT.encoder`` is wired.""" + detector = _make_mock_detector() + encoder = _make_mock_reid_encoder() + mock_tracker = MagicMock() + mock_tracker.encoder = None + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + TrackingAdapter(detector, reid_encoder=encoder) + assert mock_tracker.encoder is encoder + + def test_botsort_encoder_not_set_when_no_reid(self): + """Without ``reid_encoder``, ``BOTSORT.encoder`` stays as-is.""" + detector = _make_mock_detector() + mock_tracker = MagicMock() + mock_tracker.encoder = None + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + TrackingAdapter(detector) + # encoder attribute on tracker should not have been mutated + assert mock_tracker.encoder is None + + def test_bytetrack_no_encoder_attribute_no_crash(self): + """Adapter with reid_encoder works even when tracker has no .encoder.""" + detector = _make_mock_detector() + encoder = _make_mock_reid_encoder() + # Tracker with no 'encoder' attribute + mock_tracker = MagicMock(spec=["update", "tracked_stracks", "reset"]) + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + # Must not raise + adapter = TrackingAdapter(detector, reid_encoder=encoder) + assert adapter._reid_encoder is encoder + + def test_reid_bridge_stored_when_provided(self): + """Providing ``reid_bridge`` stores it on the adapter.""" + detector = _make_mock_detector() + bridge = MagicMock() + with patch.object(TrackingAdapter, "_build_tracker", return_value=MagicMock()): + adapter = TrackingAdapter(detector, reid_bridge=bridge) + assert adapter._reid_bridge is bridge + + +# --------------------------------------------------------------------------- +# Group 2: update() pipeline with ReID active +# --------------------------------------------------------------------------- + + +class TestTrackingAdapterUpdateWithReID: + def _make_adapter_with_reid( + self, + n_detections: int = 1, + embedding_dim: int = 8, + smooth_feat: np.ndarray | None = None, + ): + """Build an adapter with a mock detector/tracker/encoder.""" + vr = _make_vision_result(*[_make_instance() for _ in range(n_detections)]) + detector = _make_mock_detector(vr) + encoder = _make_mock_reid_encoder(embedding_dim=embedding_dim, n_detections=n_detections) + + if smooth_feat is None: + embed = np.ones(embedding_dim, dtype=np.float32) + smooth_feat = embed / np.linalg.norm(embed) + + strack = _make_mock_strack(track_id=1, smooth_feat=smooth_feat) + tracked_out = np.array([[10.0, 20.0, 110.0, 120.0, 1.0, 0.9, 0.0, 0.0]]) + mock_tracker = _make_mock_tracker(tracked_out, [strack]) + + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, reid_encoder=encoder) + # Replace tracker directly (patch already returned mock_tracker) + adapter._tracker = mock_tracker + return adapter, encoder, smooth_feat + + def test_update_without_reid_unchanged(self): + """``update()`` without reid_encoder returns VisionResult normally.""" + detector = _make_mock_detector() + 
mock_tracker = _make_mock_tracker() + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector) + result = adapter.update(np.zeros((100, 100, 3), dtype=np.uint8)) + assert isinstance(result, VisionResult) + assert result.instances[0].track_id == 1 + assert result.instances[0].embedding is None + + def test_update_with_reid_populates_embeddings(self): + """VisionResult instances carry embeddings when ReID is active.""" + feat = np.array([0.6, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32) + adapter, _, smooth_feat = self._make_adapter_with_reid(smooth_feat=feat) + result = adapter.update(np.zeros((100, 100, 3), dtype=np.uint8)) + assert result.instances[0].embedding is not None + np.testing.assert_array_equal(result.instances[0].embedding, feat) + + def test_embedding_shape_in_vision_result(self): + """``Instance.embedding`` is a 1-D float32 numpy array of expected dim.""" + dim = 16 + adapter, _, smooth_feat = self._make_adapter_with_reid(embedding_dim=dim) + result = adapter.update(np.zeros((100, 100, 3), dtype=np.uint8)) + emb = result.instances[0].embedding + assert emb is not None + assert emb.ndim == 1 + assert emb.shape[0] == dim + assert emb.dtype == np.float32 + + def test_reid_encoder_called_with_crops(self): + """The reid encoder's ``predict()`` is invoked with image crops.""" + adapter, encoder, _ = self._make_adapter_with_reid() + adapter.update(np.zeros((200, 200, 3), dtype=np.uint8)) + encoder.predict.assert_called_once() + crops_arg = encoder.predict.call_args[0][0] + assert len(crops_arg) == 1 + assert isinstance(crops_arg[0], np.ndarray) + + def test_update_with_reid_skipped_when_np_image_none(self): + """ReID crop extraction is skipped gracefully when image can't be converted.""" + adapter, encoder, _ = self._make_adapter_with_reid() + # Pass a string URL — _to_numpy_image returns None, encoder must not be called + with patch.object(adapter, "_to_numpy_image", return_value=None): + result = adapter.update("http://example.com/frame.jpg") + encoder.predict.assert_not_called() + assert isinstance(result, VisionResult) + + def test_update_empty_frame_no_crash_with_reid(self): + """Zero detections + reid_encoder doesn't crash.""" + vr = _make_vision_result() # empty + detector = _make_mock_detector(vr) + encoder = _make_mock_reid_encoder() + mock_tracker = _make_mock_tracker( + tracked_output=np.empty((0, 8)), + tracked_stracks=[], + ) + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, reid_encoder=encoder) + adapter._tracker = mock_tracker + result = adapter.update(np.zeros((100, 100, 3), dtype=np.uint8)) + encoder.predict.assert_not_called() + assert result.instances == [] + + def test_zero_area_bbox_skipped_in_reid(self): + """Zero-area bboxes produce empty placeholder crops (encoder not called for them).""" + zero_inst = _make_instance(x1=50, y1=50, x2=50, y2=50) # degenerate + vr = _make_vision_result(zero_inst) + detector = _make_mock_detector(vr) + encoder = _make_mock_reid_encoder() + mock_tracker = _make_mock_tracker( + tracked_output=np.empty((0, 8)), + tracked_stracks=[], + ) + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, reid_encoder=encoder) + adapter._tracker = mock_tracker + adapter.update(np.zeros((100, 100, 3), dtype=np.uint8)) + # No valid crops → encoder predict never called + encoder.predict.assert_not_called() + + def 
test_reid_encoder_not_called_without_reid(self): + """Without reid_encoder, no encode call is made during update.""" + detector = _make_mock_detector() + encoder = _make_mock_reid_encoder() + mock_tracker = _make_mock_tracker() + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector) # no reid_encoder + adapter.update(np.zeros((100, 100, 3), dtype=np.uint8)) + encoder.predict.assert_not_called() + + def test_multiple_detections_all_get_embeddings(self): + """All tracked instances receive their respective embeddings.""" + n = 3 + vr = _make_vision_result(*[_make_instance(x1=i * 30, x2=i * 30 + 20) for i in range(n)]) + detector = _make_mock_detector(vr) + encoder = _make_mock_reid_encoder(n_detections=n, embedding_dim=4) + + stracks = [ + _make_mock_strack(track_id=i + 1, smooth_feat=np.array([float(i)] * 4, dtype=np.float32)) for i in range(n) + ] + tracked_out = np.array( + [[i * 30.0, 0.0, i * 30.0 + 20.0, 100.0, float(i + 1), 0.9, 0.0, float(i)] for i in range(n)] + ) + mock_tracker = _make_mock_tracker(tracked_out, stracks) + + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, reid_encoder=encoder) + adapter._tracker = mock_tracker + result = adapter.update(np.zeros((200, 200, 3), dtype=np.uint8)) + assert len(result.instances) == n + for inst in result.instances: + assert inst.embedding is not None + + +# --------------------------------------------------------------------------- +# Group 3: BOTSORT get_dists() ReID branch +# --------------------------------------------------------------------------- + + +class TestBOTSORTGetDistsReIDBranch: + def _default_args(self, **overrides) -> dict: + base = { + "track_high_thresh": 0.5, + "track_low_thresh": 0.1, + "new_track_thresh": 0.6, + "track_buffer": 30, + "match_thresh": 0.8, + "fuse_score": False, + "gmc_method": None, + "proximity_thresh": 0.5, + "appearance_thresh": 0.25, + "with_reid": False, + } + base.update(overrides) + return base + + def test_botsort_encoder_none_by_default(self): + """``BOTSORT.encoder`` is ``None`` before any ReID wiring.""" + tracker = BOTSORT(self._default_args(), frame_rate=30) + assert tracker.encoder is None + + def test_botsort_with_reid_false_skips_embedding_distance(self): + """``get_dists()`` does NOT call embedding distance when ``with_reid=False``.""" + tracker = BOTSORT(self._default_args(with_reid=False), frame_rate=30) + tracker.encoder = MagicMock() # encoder set but with_reid=False + + t = BOTrack([50, 50, 40, 60, 0, 0], 0.9, 0) + d = BOTrack([50, 50, 40, 60, 0, 0], 0.9, 0) + + with patch("mata.trackers.utils.matching.embedding_distance") as mock_emb: + tracker.get_dists([t], [d]) + mock_emb.assert_not_called() + + def test_botsort_get_dists_uses_embedding_when_encoder_and_with_reid(self): + """``get_dists()`` uses embedding distance when both ``with_reid=True`` and ``encoder`` is set.""" + tracker = BOTSORT(self._default_args(with_reid=True), frame_rate=30) + tracker.encoder = MagicMock() + + feat = np.array([1.0, 0.0], dtype=np.float32) + t = BOTrack([50, 60, 40, 60, 0, 0], 0.9, 0) + t.update_features(feat) + d = BOTrack([52, 60, 40, 60, 0, 0], 0.9, 0) + d.update_features(feat) + + with patch("mata.trackers.utils.matching.embedding_distance", return_value=np.array([[0.1]])) as mock_emb: + cost = tracker.get_dists([t], [d]) + mock_emb.assert_called_once() + assert cost.shape == (1, 1) + + def test_botsort_get_dists_no_encoder_falls_back_to_iou(self): + 
"""``get_dists()`` falls back to IoU gating when ``encoder=None`` even with ``with_reid=True``.""" + tracker = BOTSORT(self._default_args(with_reid=True), frame_rate=30) + # encoder stays None + + t = BOTrack([50, 60, 40, 60, 0, 0], 0.9, 0) + d = BOTrack([52, 60, 40, 60, 0, 0], 0.9, 0) + + with patch("mata.trackers.utils.matching.embedding_distance") as mock_emb: + tracker.get_dists([t], [d]) + mock_emb.assert_not_called() + + def test_tracking_adapter_sets_botsort_encoder(self): + """``TrackingAdapter`` wires the encoder into a real BOTSORT instance.""" + encoder = _make_mock_reid_encoder() + detector = _make_mock_detector() + # Use real BOTSORT — no patching + adapter = TrackingAdapter(detector, tracker_config="botsort", reid_encoder=encoder) + assert adapter._tracker.encoder is encoder + + +# --------------------------------------------------------------------------- +# Group 4: mata.load() and mata.track() public API +# --------------------------------------------------------------------------- + + +class TestPublicAPIReID: + def test_load_track_without_reid_returns_adapter_no_encoder(self): + """``mata.load("track", ...)`` without ReID returns adapter without encoder.""" + mock_detect = _make_mock_detector() + with patch("mata.adapters.huggingface_adapter.HuggingFaceDetectAdapter", return_value=mock_detect): + adapter = mata.load("track", "facebook/detr-resnet-50") + assert isinstance(adapter, TrackingAdapter) + assert adapter._reid_encoder is None + + def test_load_track_with_reid_model_sets_encoder(self): + """``mata.load("track", ..., reid_model=...)`` creates adapter with encoder.""" + mock_detect = _make_mock_detector() + mock_encoder = _make_mock_reid_encoder() + + with ( + patch("mata.adapters.huggingface_adapter.HuggingFaceDetectAdapter", return_value=mock_detect), + patch( + "mata.core.model_loader.UniversalLoader._load_reid_encoder", + return_value=mock_encoder, + ), + ): + adapter = mata.load("track", "facebook/detr-resnet-50", reid_model="org/reid-model") + + assert isinstance(adapter, TrackingAdapter) + assert adapter._reid_encoder is mock_encoder + + def test_with_reid_true_without_model_raises_value_error(self): + """``with_reid=True`` without ``reid_model`` raises ``ValueError``.""" + mock_detect = _make_mock_detector() + with patch("mata.adapters.huggingface_adapter.HuggingFaceDetectAdapter", return_value=mock_detect): + with pytest.raises(ValueError, match="reid_model"): + mata.load("track", "facebook/detr-resnet-50", with_reid=True) + + def test_load_track_onnx_reid_loads_onnx_adapter(self): + """An ``.onnx`` path for ``reid_model`` creates an ``ONNXReIDAdapter``.""" + from mata.adapters.reid_adapter import ONNXReIDAdapter + + mock_detect = _make_mock_detector() + with ( + patch("mata.adapters.huggingface_adapter.HuggingFaceDetectAdapter", return_value=mock_detect), + patch.object(ONNXReIDAdapter, "_load_model"), # skip actual ONNX loading + ): + adapter = mata.load("track", "facebook/detr-resnet-50", reid_model="reid.onnx") + + assert isinstance(adapter, TrackingAdapter) + assert isinstance(adapter._reid_encoder, ONNXReIDAdapter) + + def test_load_track_hf_reid_loads_hf_adapter(self): + """A HuggingFace ID for ``reid_model`` creates a ``HuggingFaceReIDAdapter``.""" + from mata.adapters.reid_adapter import HuggingFaceReIDAdapter + + mock_detect = _make_mock_detector() + with ( + patch("mata.adapters.huggingface_adapter.HuggingFaceDetectAdapter", return_value=mock_detect), + patch.object(HuggingFaceReIDAdapter, "_load_model"), # skip real download + ): + adapter = 
mata.load("track", "facebook/detr-resnet-50", reid_model="org/reid-model") + + assert isinstance(adapter, TrackingAdapter) + assert isinstance(adapter._reid_encoder, HuggingFaceReIDAdapter) + + def test_track_api_backward_compat_no_reid(self): + """``mata.track()`` without any ReID kwargs has identical behaviour to before.""" + import numpy as np + + frame = np.zeros((100, 100, 3), dtype=np.uint8) + + mock_adapter = MagicMock() + mock_adapter.update.return_value = _make_vision_result(_make_instance(track_id=1)) + + with patch("mata.api.load", return_value=mock_adapter): + results = mata.track(frame, model="facebook/detr-resnet-50") + + assert isinstance(results, list) + assert len(results) == 1 + + def test_track_api_reid_model_forwarded_to_load(self): + """``mata.track(..., reid_model=...)`` forwards ``reid_model`` to ``load()``.""" + frame = np.zeros((100, 100, 3), dtype=np.uint8) + mock_adapter = MagicMock() + mock_adapter.update.return_value = _make_vision_result() + + with patch("mata.api.load", return_value=mock_adapter) as mock_load: + mata.track(frame, model="facebook/detr-resnet-50", reid_model="org/reid") + + call_kwargs = mock_load.call_args[1] + assert call_kwargs.get("reid_model") == "org/reid" + + def test_track_api_with_reid_forwarded_to_load(self): + """``mata.track(..., with_reid=True)`` forwards ``with_reid`` to ``load()``.""" + frame = np.zeros((100, 100, 3), dtype=np.uint8) + mock_adapter = MagicMock() + mock_adapter.update.return_value = _make_vision_result() + + with patch("mata.api.load", return_value=mock_adapter) as mock_load: + # with_reid=True but we allow it through since load is mocked + mata.track(frame, model="facebook/detr-resnet-50", with_reid=True) + + call_kwargs = mock_load.call_args[1] + assert call_kwargs.get("with_reid") is True + + +# --------------------------------------------------------------------------- +# Group 5: Config / YAML support for ReID +# --------------------------------------------------------------------------- + + +class TestConfigReIDSupport: + def test_resolve_tracker_kwargs_pops_reid_model(self): + """``_resolve_tracker_kwargs`` pops and returns ``reid_model``.""" + from mata.core.model_loader import UniversalLoader + + loader = UniversalLoader.__new__(UniversalLoader) + kwargs = {"tracker": "botsort", "frame_rate": 30, "reid_model": "org/reid"} + _, _, reid_model, with_reid, reid_bridge = loader._resolve_tracker_kwargs(kwargs) + assert reid_model == "org/reid" + assert "reid_model" not in kwargs + + def test_resolve_tracker_kwargs_pops_with_reid(self): + """``_resolve_tracker_kwargs`` pops and returns ``with_reid``.""" + from mata.core.model_loader import UniversalLoader + + loader = UniversalLoader.__new__(UniversalLoader) + kwargs = {"with_reid": True, "reid_model": "org/reid"} + _, _, _, with_reid, _ = loader._resolve_tracker_kwargs(kwargs) + assert with_reid is True + + def test_config_alias_reid_model_extracted(self): + """Config alias with ``reid_model`` key wires the encoder.""" + mock_detect = _make_mock_detector() + mock_encoder = _make_mock_reid_encoder() + + config_entry = { + "source": "facebook/detr-resnet-50", + "tracker": "botsort", + "reid_model": "org/reid-model", + } + + # has_alias returns True only for the alias (not for the recursive source call) + def has_alias_side_effect(task, name): + return name == "my-config-alias" + + with ( + patch("mata.adapters.huggingface_adapter.HuggingFaceDetectAdapter", return_value=mock_detect), + patch("mata.core.model_loader.UniversalLoader._load_reid_encoder", 
return_value=mock_encoder), + patch("mata.core.model_registry.ModelRegistry.has_alias", side_effect=has_alias_side_effect), + patch("mata.core.model_registry.ModelRegistry.get_config", return_value=config_entry), + ): + adapter = mata.load("track", "my-config-alias") + + assert isinstance(adapter, TrackingAdapter) + assert adapter._reid_encoder is mock_encoder + + def test_runtime_kwarg_overrides_config_reid_model(self): + """Runtime ``reid_model`` kwarg overrides the one in config.""" + mock_detect = _make_mock_detector() + mock_encoder = _make_mock_reid_encoder() + + config_entry = { + "source": "facebook/detr-resnet-50", + "reid_model": "config/reid-model", # should be overridden by runtime kwarg + } + + def has_alias_side_effect(task, name): + return name == "my-config-alias" + + with ( + patch("mata.adapters.huggingface_adapter.HuggingFaceDetectAdapter", return_value=mock_detect), + patch( + "mata.core.model_loader.UniversalLoader._load_reid_encoder", + return_value=mock_encoder, + ) as mock_load_reid, + patch("mata.core.model_registry.ModelRegistry.has_alias", side_effect=has_alias_side_effect), + patch("mata.core.model_registry.ModelRegistry.get_config", return_value=config_entry), + ): + mata.load("track", "my-config-alias", reid_model="runtime/reid-model") + + # Verify the runtime value was used (not the config value) + mock_load_reid.assert_called_once_with("runtime/reid-model") + + def test_config_without_reid_unchanged(self): + """Config aliases without ``reid_model`` continue to work unchanged.""" + mock_detect = _make_mock_detector() + + config_entry = { + "source": "facebook/detr-resnet-50", + "tracker": "botsort", + # no reid_model key + } + + def has_alias_side_effect(task, name): + return name == "my-config-alias" + + with ( + patch("mata.adapters.huggingface_adapter.HuggingFaceDetectAdapter", return_value=mock_detect), + patch("mata.core.model_registry.ModelRegistry.has_alias", side_effect=has_alias_side_effect), + patch("mata.core.model_registry.ModelRegistry.get_config", return_value=config_entry), + ): + adapter = mata.load("track", "my-config-alias") + + assert isinstance(adapter, TrackingAdapter) + assert adapter._reid_encoder is None + + +# --------------------------------------------------------------------------- +# Group 6: ReIDBridge integration in TrackingAdapter.update() +# --------------------------------------------------------------------------- + + +class TestReIDBridgeIntegration: + def _make_adapter_with_bridge(self, embed_feat: np.ndarray | None = None): + """Build a TrackingAdapter with both ReID encoder and bridge.""" + dim = 4 + if embed_feat is None: + embed_feat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) + + vr = _make_vision_result(_make_instance()) + detector = _make_mock_detector(vr) + encoder = _make_mock_reid_encoder(embedding_dim=dim, n_detections=1) + bridge = MagicMock() + + strack = _make_mock_strack(track_id=1, smooth_feat=embed_feat) + tracked_out = np.array([[10.0, 20.0, 110.0, 120.0, 1.0, 0.9, 0.0, 0.0]]) + mock_tracker = _make_mock_tracker(tracked_out, [strack]) + + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, reid_encoder=encoder, reid_bridge=bridge) + adapter._tracker = mock_tracker + return adapter, bridge, embed_feat + + def test_reid_bridge_publish_called_after_update(self): + """``ReIDBridge.publish()`` is called for each tracked instance with embedding.""" + feat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) + adapter, bridge, _ = 
self._make_adapter_with_bridge(embed_feat=feat) + adapter.update(np.zeros((200, 200, 3), dtype=np.uint8)) + bridge.publish.assert_called_once() + call_kwargs = bridge.publish.call_args[1] + assert call_kwargs["track_id"] == 1 + np.testing.assert_array_equal(call_kwargs["embedding"], feat) + + def test_reid_bridge_not_called_when_none(self): + """No bridge calls occur when ``reid_bridge=None``.""" + vr = _make_vision_result(_make_instance()) + detector = _make_mock_detector(vr) + mock_tracker = _make_mock_tracker() + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector) # no bridge + adapter.update(np.zeros((100, 100, 3), dtype=np.uint8)) + # No bridge object → no publish() calls (no AttributeError) + + def test_reid_bridge_publish_skipped_without_embedding(self): + """If embedding is ``None``, ``bridge.publish()`` is NOT called.""" + vr = _make_vision_result(_make_instance()) + detector = _make_mock_detector(vr) + bridge = MagicMock() + + # strack has smooth_feat=None → embedding won't be in output + strack = _make_mock_strack(track_id=1, smooth_feat=None) + tracked_out = np.array([[10.0, 20.0, 110.0, 120.0, 1.0, 0.9, 0.0, 0.0]]) + mock_tracker = _make_mock_tracker(tracked_out, [strack]) + + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, reid_bridge=bridge) + adapter._tracker = mock_tracker + adapter.update(np.zeros((100, 100, 3), dtype=np.uint8)) + bridge.publish.assert_not_called() + + def test_reid_bridge_publish_called_for_multiple_instances(self): + """``bridge.publish()`` is called once per tracked instance with an embedding.""" + dim = 4 + n = 2 + feats = [ + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32), + ] + + vr = _make_vision_result(*[_make_instance(x1=i * 60, x2=i * 60 + 50) for i in range(n)]) + detector = _make_mock_detector(vr) + encoder = _make_mock_reid_encoder(embedding_dim=dim, n_detections=n) + bridge = MagicMock() + + stracks = [_make_mock_strack(track_id=i + 1, smooth_feat=feats[i]) for i in range(n)] + tracked_out = np.array( + [[i * 60.0, 0.0, i * 60.0 + 50.0, 100.0, float(i + 1), 0.9, 0.0, float(i)] for i in range(n)] + ) + mock_tracker = _make_mock_tracker(tracked_out, stracks) + + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, reid_encoder=encoder, reid_bridge=bridge) + adapter._tracker = mock_tracker + adapter.update(np.zeros((300, 300, 3), dtype=np.uint8)) + + assert bridge.publish.call_count == n + + +# --------------------------------------------------------------------------- +# Group 7: Backward compatibility +# --------------------------------------------------------------------------- + + +class TestBackwardCompatibility: + def test_adapter_update_identical_without_reid_encoder(self): + """``update()`` result is unchanged when ``reid_encoder=None``.""" + track_row = np.array([[10.0, 20.0, 110.0, 120.0, 1.0, 0.9, 0.0, 0.0]]) + vr = _make_vision_result(_make_instance()) + detector = _make_mock_detector(vr) + mock_tracker = _make_mock_tracker(track_row, []) + + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector) + + result = adapter.update(np.zeros((100, 100, 3), dtype=np.uint8)) + assert len(result.instances) == 1 + inst = result.instances[0] + assert inst.track_id == 1 + assert inst.embedding is None + assert inst.bbox == (10.0, 
20.0, 110.0, 120.0) + + def test_with_reid_false_default_path_no_overhead(self): + """Default ``with_reid=False`` path processes no embeddings.""" + vr = _make_vision_result(_make_instance()) + detector = _make_mock_detector(vr) + mock_tracker = _make_mock_tracker() + + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector) + + assert adapter._reid_encoder is None + assert adapter._reid_bridge is None + + def test_tracker_type_property_unchanged(self): + """``tracker_type`` property continues to work after ReID changes.""" + detector = _make_mock_detector() + mock_tracker = MagicMock() + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, tracker_config="botsort") + assert adapter.tracker_type == "botsort" + + def test_persist_false_resets_tracker(self): + """``persist=False`` still resets the tracker (ReID changes don't break this).""" + vr = _make_vision_result(_make_instance()) + detector = _make_mock_detector(vr) + mock_tracker = _make_mock_tracker() + + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector) + + adapter.update(np.zeros((100, 100, 3), dtype=np.uint8), persist=False) + mock_tracker.reset.assert_called_once() + + def test_class_filter_still_works_with_reid_encoder(self): + """``classes=`` filtering still runs correctly when ReID is active.""" + # Provide two detections of different classes + inst0 = _make_instance(x1=10, x2=60, label=0) + inst1 = _make_instance(x1=70, x2=120, label=1) + vr = _make_vision_result(inst0, inst1) + detector = _make_mock_detector(vr) + encoder = _make_mock_reid_encoder(n_detections=1, embedding_dim=4) + + # Tracker returns empty (filtered down to class 0 only) + strack = _make_mock_strack(track_id=1, smooth_feat=np.ones(4, dtype=np.float32)) + tracked_out = np.array([[10.0, 20.0, 60.0, 120.0, 1.0, 0.9, 0.0, 0.0]]) + mock_tracker = _make_mock_tracker(tracked_out, [strack]) + + with patch.object(TrackingAdapter, "_build_tracker", return_value=mock_tracker): + adapter = TrackingAdapter(detector, reid_encoder=encoder) + adapter._tracker = mock_tracker + + adapter.update( + np.zeros((200, 200, 3), dtype=np.uint8), + classes=[0], # keep only class 0 + ) + # Encoder should only be called for the 1 filtered detection (class 0) + encoder.predict.assert_called_once() + crops = encoder.predict.call_args[0][0] + assert len(crops) == 1 diff --git a/tests/test_transformation_nodes.py b/tests/test_transformation_nodes.py index 0d2b825..8e2b6d0 100644 --- a/tests/test_transformation_nodes.py +++ b/tests/test_transformation_nodes.py @@ -651,8 +651,8 @@ class TestNodeTypeDeclarations: def test_filter_types(self): node = Filter() - assert node.inputs == {"detections": Detections} - assert node.outputs == {"detections": Detections} + assert node.inputs == {"dets": Detections} # dynamic: default src="dets" + assert node.outputs == {"filtered": Detections} # dynamic: default out="filtered" def test_topk_types(self): node = TopK(k=1) diff --git a/tests/test_universal_loader.py b/tests/test_universal_loader.py index 309d6f7..7ebf77e 100644 --- a/tests/test_universal_loader.py +++ b/tests/test_universal_loader.py @@ -477,7 +477,9 @@ def test_load_track_from_huggingface(self, mock_detect_adapter, mock_tracking_ad # Detect adapter created for the underlying model mock_detect_adapter.assert_called_once_with(model_id="facebook/detr-resnet-50") # TrackingAdapter wraps it with 
default botsort config - mock_tracking_adapter.assert_called_once_with(mock_detect_instance, "botsort", 30) + mock_tracking_adapter.assert_called_once_with( + mock_detect_instance, "botsort", 30, reid_encoder=None, reid_bridge=None + ) assert result == mock_tracking_instance @patch("mata.adapters.tracking_adapter.TrackingAdapter") @@ -538,7 +540,9 @@ def test_load_track_from_onnx_file(self, mock_detect_adapter, mock_tracking_adap result = loader._load_from_file("track", filepath) mock_detect_adapter.assert_called_once_with(model_path=filepath) - mock_tracking_adapter.assert_called_once_with(mock_detect_instance, "botsort", 30) + mock_tracking_adapter.assert_called_once_with( + mock_detect_instance, "botsort", 30, reid_encoder=None, reid_bridge=None + ) assert result == mock_tracking_instance finally: try: @@ -731,7 +735,7 @@ def test_resolve_tracker_kwargs_string_only(self): """_resolve_tracker_kwargs with only tracker= string returns it unchanged.""" loader = UniversalLoader() kwargs = {"tracker": "bytetrack", "frame_rate": 25, "device": "cuda"} - tracker_config, frame_rate = loader._resolve_tracker_kwargs(kwargs) + tracker_config, frame_rate, *_ = loader._resolve_tracker_kwargs(kwargs) assert tracker_config == "bytetrack" assert frame_rate == 25 @@ -746,7 +750,7 @@ def test_resolve_tracker_kwargs_merges_overrides(self): "tracker_config": {"track_buffer": 60, "match_thresh": 0.7}, "frame_rate": 30, } - tracker_config, frame_rate = loader._resolve_tracker_kwargs(kwargs) + tracker_config, frame_rate, *_ = loader._resolve_tracker_kwargs(kwargs) assert isinstance(tracker_config, dict) assert tracker_config["tracker_type"] == "botsort" @@ -759,7 +763,7 @@ def test_resolve_tracker_kwargs_defaults(self): """_resolve_tracker_kwargs defaults: tracker=botsort, frame_rate=30.""" loader = UniversalLoader() kwargs = {} - tracker_config, frame_rate = loader._resolve_tracker_kwargs(kwargs) + tracker_config, frame_rate, *_ = loader._resolve_tracker_kwargs(kwargs) assert tracker_config == "botsort" assert frame_rate == 30 @@ -772,7 +776,7 @@ def test_resolve_tracker_kwargs_dict_tracker_ignores_overrides(self): "tracker": custom_dict, "tracker_config": {"track_buffer": 99}, # should be ignored } - tracker_config, frame_rate = loader._resolve_tracker_kwargs(kwargs) + tracker_config, frame_rate, *_ = loader._resolve_tracker_kwargs(kwargs) # Original dict returned unchanged (not merged with overrides) assert tracker_config is custom_dict diff --git a/tests/test_valkey_config.py b/tests/test_valkey_config.py new file mode 100644 index 0000000..fd9dc4a --- /dev/null +++ b/tests/test_valkey_config.py @@ -0,0 +1,285 @@ +"""Tests for Valkey config and Pub/Sub — Task C2, D1.""" + +from __future__ import annotations + +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +from mata.core.exceptions import ModelNotFoundError +from mata.core.model_registry import ModelRegistry + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_registry(yaml_content: str) -> ModelRegistry: + """Create a ModelRegistry backed by a temporary YAML file.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False, encoding="utf-8") as fh: + fh.write(yaml_content) + tmp_path = fh.name + return ModelRegistry(custom_config_path=tmp_path) + + +def _minimal_vision_result(): + """Return a minimal mock VisionResult.""" + mock = MagicMock() + mock.to_json.return_value = 
'{"instances": []}' + mock.to_dict.return_value = {"instances": []} + return mock + + +def _minimal_classify_result(): + mock = MagicMock() + mock.to_json.return_value = '{"predictions": []}' + mock.to_dict.return_value = {"predictions": []} + return mock + + +# --------------------------------------------------------------------------- +# TestValkeyConfig +# --------------------------------------------------------------------------- + + +class TestValkeyConfig: + """Tests for YAML config integration (Task C2).""" + + BASIC_YAML = """\ +storage: + valkey: + default: + url: "valkey://localhost:6379" + db: 0 + ttl: 3600 + production: + url: "valkey://prod-cluster:6380" + db: 1 +""" + + def test_get_default_connection(self): + """get_valkey_connection('default') returns the default connection dict.""" + registry = _make_registry(self.BASIC_YAML) + conn = registry.get_valkey_connection("default") + + assert conn["url"] == "valkey://localhost:6379" + assert conn["db"] == 0 + assert conn["ttl"] == 3600 + + def test_get_named_connection(self): + """get_valkey_connection('production') returns non-default connection.""" + registry = _make_registry(self.BASIC_YAML) + conn = registry.get_valkey_connection("production") + + assert conn["url"] == "valkey://prod-cluster:6380" + assert conn["db"] == 1 + + def test_missing_connection_raises(self): + """Unknown connection name raises ModelNotFoundError with helpful message.""" + registry = _make_registry(self.BASIC_YAML) + + with pytest.raises(ModelNotFoundError, match="nonexistent"): + registry.get_valkey_connection("nonexistent") + + def test_missing_connection_lists_available(self): + """ModelNotFoundError message lists the available connection names.""" + registry = _make_registry(self.BASIC_YAML) + + with pytest.raises(ModelNotFoundError) as exc_info: + registry.get_valkey_connection("missing") + + error_msg = str(exc_info.value) + # Error should list at least one of the existing connection names + assert "default" in error_msg or "production" in error_msg + + def test_password_env_resolved(self, monkeypatch): + """password_env is resolved from the environment and replaced by 'password'.""" + yaml_content = """\ +storage: + valkey: + default: + url: "valkey://secure-host:6379" + password_env: "TEST_VALKEY_PASS_E3" +""" + monkeypatch.setenv("TEST_VALKEY_PASS_E3", "s3cr3t") + registry = _make_registry(yaml_content) + conn = registry.get_valkey_connection("default") + + assert conn.get("password") == "s3cr3t" + assert "password_env" not in conn + + def test_password_env_missing_graceful(self, monkeypatch): + """Missing env var does not raise; 'password' key is simply absent.""" + yaml_content = """\ +storage: + valkey: + default: + url: "valkey://secure-host:6379" + password_env: "TEST_VALKEY_UNDEFINED_ENV_VAR_XYZ" +""" + monkeypatch.delenv("TEST_VALKEY_UNDEFINED_ENV_VAR_XYZ", raising=False) + registry = _make_registry(yaml_content) + conn = registry.get_valkey_connection("default") + + assert "password_env" not in conn + assert "password" not in conn + + def test_no_storage_section_graceful(self): + """A config file with no 'storage' section is handled without error.""" + yaml_content = """\ +models: + detect: + rtdetr-fast: + source: "facebook/detr-resnet-50" +""" + registry = _make_registry(yaml_content) + + with pytest.raises(ModelNotFoundError): + registry.get_valkey_connection("default") + + def test_existing_models_config_unaffected(self): + """Adding a storage section does not alter the existing models section.""" + yaml_content = """\ +detect: + 
my-model: + source: "facebook/detr-resnet-50" + threshold: 0.5 + +storage: + valkey: + default: + url: "valkey://localhost:6379" +""" + registry = _make_registry(yaml_content) + + # models section still works + assert registry.has_alias("detect", "my-model") + config = registry.get_config("detect", "my-model") + assert config["source"] == "facebook/detr-resnet-50" + + # storage section also works + conn = registry.get_valkey_connection("default") + assert conn["url"] == "valkey://localhost:6379" + + def test_tls_flag_passthrough(self): + """tls flag in YAML is returned in the connection dict.""" + yaml_content = """\ +storage: + valkey: + secure: + url: "valkey://tls-host:6380" + tls: true + db: 0 +""" + registry = _make_registry(yaml_content) + conn = registry.get_valkey_connection("secure") + + assert conn.get("tls") is True + + +# --------------------------------------------------------------------------- +# TestPublishValkey +# --------------------------------------------------------------------------- + + +class TestPublishValkey: + """Tests for publish_valkey() Pub/Sub function (Task D1).""" + + def test_publish_vision_result(self): + """publish_valkey() calls client.publish() with a VisionResult.""" + from mata.core.exporters.valkey_exporter import publish_valkey + + result = _minimal_vision_result() + mock_client = MagicMock() + mock_client.publish.return_value = 3 + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + return_value=mock_client, + ): + count = publish_valkey(result, url="valkey://localhost:6379", channel="detections") + + mock_client.publish.assert_called_once() + call_args = mock_client.publish.call_args + assert call_args[0][0] == "detections" + assert count == 3 + + def test_publish_returns_subscriber_count(self): + """publish_valkey() returns the integer subscriber count from client.publish().""" + from mata.core.exporters.valkey_exporter import publish_valkey + + result = _minimal_vision_result() + mock_client = MagicMock() + mock_client.publish.return_value = 7 + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + return_value=mock_client, + ): + count = publish_valkey(result, url="valkey://localhost:6379", channel="ch") + + assert count == 7 + + def test_publish_json_serializer(self): + """publish_valkey() with serializer='json' calls result.to_json().""" + from mata.core.exporters.valkey_exporter import publish_valkey + + result = _minimal_vision_result() + mock_client = MagicMock() + mock_client.publish.return_value = 1 + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + return_value=mock_client, + ): + publish_valkey(result, url="valkey://localhost:6379", channel="ch", serializer="json") + + result.to_json.assert_called_once() + result.to_dict.assert_not_called() + + def test_publish_msgpack_serializer(self): + """publish_valkey() with serializer='msgpack' calls result.to_dict() and packs.""" + from mata.core.exporters.valkey_exporter import publish_valkey + + result = _minimal_classify_result() + mock_client = MagicMock() + mock_client.publish.return_value = 2 + + msgpack_mock = MagicMock() + msgpack_mock.packb.return_value = b"\x81\xabpredictions\x90" + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + return_value=mock_client, + ): + with patch.dict("sys.modules", {"msgpack": msgpack_mock}): + publish_valkey( + result, + url="valkey://localhost:6379", + channel="ch", + serializer="msgpack", + ) + + result.to_dict.assert_called_once() + 
msgpack_mock.packb.assert_called_once_with(result.to_dict.return_value, use_bin_type=True) + mock_client.publish.assert_called_once_with("ch", b"\x81\xabpredictions\x90") + + def test_publish_invalid_serializer_raises(self): + """publish_valkey() raises ValueError for an unsupported serializer.""" + from mata.core.exporters.valkey_exporter import publish_valkey + + result = _minimal_vision_result() + mock_client = MagicMock() + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + return_value=mock_client, + ): + with pytest.raises(ValueError, match="Unsupported serializer"): + publish_valkey( + result, + url="valkey://localhost:6379", + channel="ch", + serializer="pickle", + ) diff --git a/tests/test_valkey_exporter.py b/tests/test_valkey_exporter.py new file mode 100644 index 0000000..6785f7c --- /dev/null +++ b/tests/test_valkey_exporter.py @@ -0,0 +1,672 @@ +"""Tests for Valkey/Redis exporter — Task A1, A2, C1. + +All tests use unittest.mock to simulate the Valkey/Redis client. +No real Valkey or Redis server is required. +""" + +from __future__ import annotations + +import json +import sys +from unittest.mock import MagicMock, patch + +import pytest + +# ── Result factory helpers ──────────────────────────────────────────────────── + + +def _make_vision_result(): + from mata.core.types import Instance, VisionResult + + return VisionResult( + instances=[ + Instance(bbox=(10, 20, 100, 200), score=0.9, label=0, label_name="cat"), + Instance(bbox=(50, 60, 150, 250), score=0.75, label=1, label_name="dog"), + ], + meta={"model": "detr-resnet-50"}, + ) + + +def _make_classify_result(): + from mata.core.types import Classification, ClassifyResult + + return ClassifyResult( + predictions=[ + Classification(label=0, score=0.95, label_name="cat"), + Classification(label=1, score=0.04, label_name="dog"), + ], + meta={"model": "clip-vit"}, + ) + + +def _make_depth_result(): + import numpy as np + + from mata.core.types import DepthResult + + depth = np.linspace(0.1, 1.0, 16, dtype=np.float32).reshape(4, 4) + return DepthResult(depth=depth, meta={"model": "depth-anything"}) + + +def _make_ocr_result(): + from mata.core.types import OCRResult, TextRegion + + return OCRResult( + regions=[ + TextRegion(text="hello", score=0.98, bbox=(0.0, 0.0, 50.0, 20.0)), + TextRegion(text="world", score=0.92, bbox=(60.0, 0.0, 120.0, 20.0)), + ], + meta={"engine": "easyocr"}, + ) + + +def _mock_client(): + return MagicMock() + + +# ── TestExportValkey ────────────────────────────────────────────────────────── + + +class TestExportValkey: + """Tests for export_valkey() function.""" + + def test_export_vision_result_json(self): + """JSON export of VisionResult stores JSON string via client.set().""" + from mata.core.exporters.valkey_exporter import export_valkey + + result = _make_vision_result() + client = _mock_client() + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + export_valkey(result, url="valkey://localhost:6379", key="test:vision") + + client.set.assert_called_once_with("test:vision", result.to_json()) + client.setex.assert_not_called() + + def test_export_classify_result_json(self): + """JSON export of ClassifyResult stored via client.set().""" + from mata.core.exporters.valkey_exporter import export_valkey + + result = _make_classify_result() + client = _mock_client() + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + export_valkey(result, url="valkey://localhost:6379", key="test:classify") + + 
client.set.assert_called_once_with("test:classify", result.to_json()) + + def test_export_depth_result_json(self): + """JSON export of DepthResult stored via client.set().""" + from mata.core.exporters.valkey_exporter import export_valkey + + result = _make_depth_result() + client = _mock_client() + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + export_valkey(result, url="valkey://localhost:6379", key="test:depth") + + client.set.assert_called_once_with("test:depth", result.to_json()) + + def test_export_ocr_result_json(self): + """JSON export of OCRResult stored via client.set().""" + from mata.core.exporters.valkey_exporter import export_valkey + + result = _make_ocr_result() + client = _mock_client() + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + export_valkey(result, url="valkey://localhost:6379", key="test:ocr") + + client.set.assert_called_once_with("test:ocr", result.to_json()) + + def test_export_with_ttl(self): + """When ttl is provided, client.setex() is used instead of set().""" + from mata.core.exporters.valkey_exporter import export_valkey + + result = _make_vision_result() + client = _mock_client() + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + export_valkey(result, url="valkey://localhost:6379", key="test:ttl", ttl=3600) + + client.setex.assert_called_once_with("test:ttl", 3600, result.to_json()) + client.set.assert_not_called() + + def test_export_without_ttl(self): + """When ttl is None (default), client.set() is used and setex() is not.""" + from mata.core.exporters.valkey_exporter import export_valkey + + result = _make_classify_result() + client = _mock_client() + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + export_valkey(result, url="valkey://localhost:6379", key="test:nottl") + + client.set.assert_called_once() + client.setex.assert_not_called() + + def test_export_msgpack_serializer(self): + """msgpack serializer packs result.to_dict() and stores bytes.""" + from mata.core.exporters.valkey_exporter import export_valkey + + result = _make_vision_result() + client = _mock_client() + fake_packed = b"\x82\xa3foo\xa3bar" + mock_msgpack = MagicMock() + mock_msgpack.packb.return_value = fake_packed + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + with patch.dict(sys.modules, {"msgpack": mock_msgpack}): + export_valkey( + result, + url="valkey://localhost:6379", + key="test:mp", + serializer="msgpack", + ) + + mock_msgpack.packb.assert_called_once_with(result.to_dict(), use_bin_type=True) + client.set.assert_called_once_with("test:mp", fake_packed) + + def test_export_invalid_serializer_raises(self): + """An unsupported serializer name raises ValueError with helpful message.""" + from mata.core.exporters.valkey_exporter import export_valkey + + result = _make_vision_result() + client = _mock_client() + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + with pytest.raises(ValueError, match="Unsupported serializer"): + export_valkey( + result, + url="valkey://localhost:6379", + key="k", + serializer="xml", + ) + + def test_import_error_no_client(self): + """ImportError raised with helpful message when neither valkey nor redis is available.""" + # Use a fresh call to _get_valkey_client to trigger the import error path + from mata.core.exporters import valkey_exporter + + # Patch the imports inside the 
function + with patch.dict(sys.modules, {"valkey": None, "redis": None}): + with pytest.raises(ImportError, match="pip install mata"): + valkey_exporter._get_valkey_client("valkey://localhost:6379") + + def test_valkey_client_fallback_to_redis(self): + """When valkey-py is absent, the client is obtained from redis-py.""" + from mata.core.exporters import valkey_exporter + + mock_redis_module = MagicMock() + mock_client_instance = MagicMock() + mock_redis_module.from_url.return_value = mock_client_instance + + with patch.dict(sys.modules, {"valkey": None, "redis": mock_redis_module}): + client = valkey_exporter._get_valkey_client("valkey://localhost:6379") + + # URL scheme must be translated for redis-py + mock_redis_module.from_url.assert_called_once_with("redis://localhost:6379") + assert client is mock_client_instance + + def test_url_password_not_logged(self): + """Passwords embedded in connection URLs must never appear in log messages.""" + from mata.core.exporters.valkey_exporter import export_valkey + + result = _make_vision_result() + client = _mock_client() + url_with_password = "valkey://admin:s3cr3tpassword@localhost:6379" + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + with patch("mata.core.exporters.valkey_exporter.logger") as mock_logger: + export_valkey(result, url=url_with_password, key="test:pwd") + + for logged_call in mock_logger.info.call_args_list: + logged_text = str(logged_call) + assert "s3cr3tpassword" not in logged_text, f"Password leaked in log message: {logged_text}" + + +# ── TestLoadValkey ──────────────────────────────────────────────────────────── + + +class TestLoadValkey: + """Tests for load_valkey() function.""" + + def test_load_vision_result(self): + """Loading a VisionResult JSON returns a VisionResult instance.""" + from mata.core.exporters.valkey_exporter import load_valkey + from mata.core.types import VisionResult + + original = _make_vision_result() + raw_json = original.to_json().encode() + client = _mock_client() + client.get.return_value = raw_json + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + loaded = load_valkey("valkey://localhost:6379", key="test:vision", result_type="vision") + + assert isinstance(loaded, VisionResult) + assert len(loaded.instances) == len(original.instances) + + def test_load_classify_result(self): + """Loading a ClassifyResult JSON returns a ClassifyResult instance.""" + from mata.core.exporters.valkey_exporter import load_valkey + from mata.core.types import ClassifyResult + + original = _make_classify_result() + raw_json = original.to_json().encode() + client = _mock_client() + client.get.return_value = raw_json + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + loaded = load_valkey("valkey://localhost:6379", key="test:classify", result_type="classify") + + assert isinstance(loaded, ClassifyResult) + assert len(loaded.predictions) == len(original.predictions) + + def test_load_depth_result(self): + """Loading a DepthResult JSON returns a DepthResult instance.""" + from mata.core.exporters.valkey_exporter import load_valkey + from mata.core.types import DepthResult + + original = _make_depth_result() + raw_json = original.to_json().encode() + client = _mock_client() + client.get.return_value = raw_json + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + loaded = load_valkey("valkey://localhost:6379", key="test:depth", 
result_type="depth") + + assert isinstance(loaded, DepthResult) + assert loaded.depth.shape == original.depth.shape + + def test_load_ocr_result(self): + """Loading an OCRResult JSON returns an OCRResult instance.""" + from mata.core.exporters.valkey_exporter import load_valkey + from mata.core.types import OCRResult + + original = _make_ocr_result() + raw_json = original.to_json().encode() + client = _mock_client() + client.get.return_value = raw_json + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + loaded = load_valkey("valkey://localhost:6379", key="test:ocr", result_type="ocr") + + assert isinstance(loaded, OCRResult) + assert len(loaded.regions) == len(original.regions) + + def test_load_auto_detect_type(self): + """Auto-detection picks the right type from serialized dict keys.""" + from mata.core.exporters.valkey_exporter import load_valkey + from mata.core.types import VisionResult + + original = _make_vision_result() + raw_json = original.to_json().encode() + client = _mock_client() + client.get.return_value = raw_json + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + # result_type="auto" is the default + loaded = load_valkey("valkey://localhost:6379", key="test:auto") + + assert isinstance(loaded, VisionResult) + + def test_load_explicit_type(self): + """Explicit result_type bypasses auto-detection.""" + from mata.core.exporters.valkey_exporter import load_valkey + from mata.core.types import VisionResult + + original = _make_vision_result() + raw_json = original.to_json().encode() + client = _mock_client() + client.get.return_value = raw_json + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + # Force "detect" which maps to VisionResult + loaded = load_valkey("valkey://localhost:6379", key="test:explicit", result_type="detect") + + assert isinstance(loaded, VisionResult) + + def test_load_missing_key_raises(self): + """A missing Valkey key raises KeyError with the key name in the message.""" + from mata.core.exporters.valkey_exporter import load_valkey + + client = _mock_client() + client.get.return_value = None # key doesn't exist + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + with pytest.raises(KeyError, match="missing_key"): + load_valkey("valkey://localhost:6379", key="missing_key") + + def test_load_unknown_type_raises(self): + """An unknown result_type string raises ValueError.""" + from mata.core.exporters.valkey_exporter import load_valkey + + # Use VisionResult data but ask for an unsupported type + raw_json = json.dumps({"instances": [], "meta": {}}).encode() + client = _mock_client() + client.get.return_value = raw_json + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + with pytest.raises(ValueError, match="Unknown result_type"): + load_valkey("valkey://localhost:6379", key="k", result_type="segment_panoptic") + + def test_roundtrip_vision_result(self): + """VisionResult survives export → load round-trip intact.""" + from mata.core.exporters.valkey_exporter import export_valkey, load_valkey + from mata.core.types import VisionResult + + original = _make_vision_result() + stored_value: list[str] = [] + + export_client = _mock_client() + export_client.set.side_effect = lambda k, v: stored_value.append(v) + + load_client = _mock_client() + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + 
return_value=export_client, + ): + export_valkey(original, url="valkey://localhost:6379", key="rt:vision") + + assert stored_value, "export_valkey must call client.set()" + + load_client.get.return_value = stored_value[0].encode() if isinstance(stored_value[0], str) else stored_value[0] + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + return_value=load_client, + ): + loaded = load_valkey("valkey://localhost:6379", key="rt:vision") + + assert isinstance(loaded, VisionResult) + assert len(loaded.instances) == len(original.instances) + assert loaded.instances[0].score == pytest.approx(original.instances[0].score) + assert loaded.instances[0].label == original.instances[0].label + assert loaded.instances[0].label_name == original.instances[0].label_name + + def test_roundtrip_classify_result(self): + """ClassifyResult survives export → load round-trip intact.""" + from mata.core.exporters.valkey_exporter import export_valkey, load_valkey + from mata.core.types import ClassifyResult + + original = _make_classify_result() + stored_value: list[str] = [] + + export_client = _mock_client() + export_client.set.side_effect = lambda k, v: stored_value.append(v) + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + return_value=export_client, + ): + export_valkey(original, url="valkey://localhost:6379", key="rt:classify") + + assert stored_value + + load_client = _mock_client() + load_client.get.return_value = stored_value[0].encode() if isinstance(stored_value[0], str) else stored_value[0] + + with patch( + "mata.core.exporters.valkey_exporter._get_valkey_client", + return_value=load_client, + ): + loaded = load_valkey("valkey://localhost:6379", key="rt:classify") + + assert isinstance(loaded, ClassifyResult) + assert len(loaded.predictions) == len(original.predictions) + assert loaded.predictions[0].score == pytest.approx(original.predictions[0].score) + assert loaded.predictions[0].label_name == original.predictions[0].label_name + + +# ── TestParseValkeyURI ──────────────────────────────────────────────────────── + + +class TestParseValkeyURI: + """Tests for _parse_valkey_uri() helper.""" + + def test_simple_uri(self): + """Simple valkey://host:port/key parses correctly.""" + from mata.core.exporters.valkey_exporter import _parse_valkey_uri + + base_url, key = _parse_valkey_uri("valkey://localhost:6379/my_key") + + assert base_url == "valkey://localhost:6379" + assert key == "my_key" + + def test_uri_with_db_number(self): + """valkey://host:port/db/key parses DB number into base_url.""" + from mata.core.exporters.valkey_exporter import _parse_valkey_uri + + base_url, key = _parse_valkey_uri("valkey://localhost:6379/0/my_key") + + assert base_url == "valkey://localhost:6379/0" + assert key == "my_key" + + def test_uri_with_password(self): + """URI with credentials passes through without stripping the password.""" + from mata.core.exporters.valkey_exporter import _parse_valkey_uri + + base_url, key = _parse_valkey_uri("redis://user:pass@host:6379/0/my_key") + + assert base_url == "redis://user:pass@host:6379/0" + assert key == "my_key" + + def test_redis_scheme(self): + """redis:// scheme is parsed the same way as valkey://.""" + from mata.core.exporters.valkey_exporter import _parse_valkey_uri + + base_url, key = _parse_valkey_uri("redis://localhost:6379/detections") + + assert base_url == "redis://localhost:6379" + assert key == "detections" + + def test_invalid_uri_raises(self): + """A URI with no key component raises ValueError.""" + from 
mata.core.exporters.valkey_exporter import _parse_valkey_uri + + with pytest.raises(ValueError, match="Invalid Valkey URI"): + _parse_valkey_uri("valkey://localhost:6379/") + + def test_empty_key_raises(self): + """A URI without any path raises ValueError.""" + from mata.core.exporters.valkey_exporter import _parse_valkey_uri + + with pytest.raises(ValueError, match="Invalid Valkey URI"): + _parse_valkey_uri("valkey://localhost:6379") + + def test_key_with_colon(self): + """Keys containing colon separators are preserved intact.""" + from mata.core.exporters.valkey_exporter import _parse_valkey_uri + + base_url, key = _parse_valkey_uri("valkey://localhost:6379/pipeline:detections:frame_001") + + assert base_url == "valkey://localhost:6379" + assert key == "pipeline:detections:frame_001" + + def test_db_number_with_key_containing_colon(self): + """DB number + key with colon separators parses correctly.""" + from mata.core.exporters.valkey_exporter import _parse_valkey_uri + + base_url, key = _parse_valkey_uri("valkey://localhost:6379/2/ns:my_key") + + assert base_url == "valkey://localhost:6379/2" + assert key == "ns:my_key" + + +# ── TestSaveValkeyIntegration ───────────────────────────────────────────────── + + +class TestSaveValkeyIntegration: + """Tests for result.save('valkey://...') URI scheme dispatch.""" + + def test_vision_result_save_valkey(self): + """VisionResult.save() routes valkey:// URI to export_valkey.""" + result = _make_vision_result() + with ( + patch("mata.core.exporters.valkey_exporter.export_valkey") as mock_export, + patch("mata.core.exporters.valkey_exporter._get_valkey_client"), + ): + result.save("valkey://localhost:6379/vision_key") + + mock_export.assert_called_once() + _, kwargs = mock_export.call_args + assert kwargs.get("key") == "vision_key" or mock_export.call_args[0][0] is result + + def test_vision_result_save_valkey_dispatches(self): + """Verify export_valkey receives correct url and key from VisionResult.save().""" + result = _make_vision_result() + client = _mock_client() + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + result.save("valkey://localhost:6379/my_vision_key") + + client.set.assert_called_once() + stored_data = client.set.call_args[0] + assert stored_data[0] == "my_vision_key" + # Verify the stored content is valid JSON + parsed = json.loads(stored_data[1]) + assert "instances" in parsed + + def test_classify_result_save_valkey(self): + """ClassifyResult.save() routes valkey:// URI to export_valkey.""" + result = _make_classify_result() + client = _mock_client() + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + result.save("valkey://localhost:6379/classify_key") + + client.set.assert_called_once() + stored_key = client.set.call_args[0][0] + assert stored_key == "classify_key" + + def test_depth_result_save_valkey(self): + """DepthResult.save() routes valkey:// URI to export_valkey.""" + result = _make_depth_result() + client = _mock_client() + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + result.save("valkey://localhost:6379/depth_key") + + client.set.assert_called_once() + stored_key = client.set.call_args[0][0] + assert stored_key == "depth_key" + + def test_ocr_result_save_valkey(self): + """OCRResult.save() routes valkey:// URI to export_valkey.""" + result = _make_ocr_result() + client = _mock_client() + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", 
return_value=client): + result.save("valkey://localhost:6379/ocr_key") + + client.set.assert_called_once() + stored_key = client.set.call_args[0][0] + assert stored_key == "ocr_key" + + def test_detect_result_save_valkey(self): + """DetectResult.save() routes valkey:// URI to export_valkey.""" + from mata.core.types import Detection, DetectResult + + result = DetectResult( + detections=[Detection(bbox=(0, 0, 10, 10), score=0.8, label=0)], + meta={}, + ) + client = _mock_client() + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + result.save("valkey://localhost:6379/detect_key") + + client.set.assert_called_once() + stored_key = client.set.call_args[0][0] + assert stored_key == "detect_key" + + def test_segment_result_save_valkey(self): + """SegmentResult.save() routes valkey:// URI to export_valkey.""" + import numpy as np + + from mata.core.types import SegmentMask, SegmentResult + + mask_arr = np.zeros((8, 8), dtype=bool) + result = SegmentResult( + masks=[SegmentMask(mask=mask_arr, score=0.85, label=0)], + meta={}, + ) + client = _mock_client() + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + result.save("valkey://localhost:6379/segment_key") + + client.set.assert_called_once() + stored_key = client.set.call_args[0][0] + assert stored_key == "segment_key" + + def test_redis_scheme_also_dispatches(self): + """redis:// scheme is also routed to export_valkey (not file export).""" + result = _make_classify_result() + client = _mock_client() + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + result.save("redis://localhost:6379/redis_key") + + client.set.assert_called_once() + + def test_save_json_still_works(self, tmp_path): + """file.json save is unaffected — no regression.""" + from mata.core.exporters import export_json + + result = _make_vision_result() + output = tmp_path / "detections.json" + + with patch("mata.core.exporters.json_exporter.export_json", wraps=export_json) as _spy: + result.save(str(output)) + + assert output.exists() + data = json.loads(output.read_text()) + assert "instances" in data + + def test_save_csv_still_works(self, tmp_path): + """file.csv save is unaffected — no regression.""" + result = _make_vision_result() + output = tmp_path / "detections.csv" + + result.save(str(output)) + + assert output.exists() + content = output.read_text() + assert len(content) > 0 + + def test_save_image_still_works(self, tmp_path): + """Image path does not get mistakenly routed to valkey exporter.""" + import numpy as np + + result = _make_vision_result() + # Provide a dummy image so export_image doesn't fail + dummy_img = np.zeros((64, 64, 3), dtype=np.uint8) + output = tmp_path / "overlay.png" + + # Only check that export_valkey is NOT called for a .png save + with patch("mata.core.exporters.valkey_exporter.export_valkey") as mock_export_valkey: + result.save(str(output), image=dummy_img) + + mock_export_valkey.assert_not_called() + + def test_save_ttl_forwarded_via_save(self): + """Extra kwargs (e.g., ttl) are forwarded from result.save() to export_valkey.""" + result = _make_vision_result() + client = _mock_client() + + with patch("mata.core.exporters.valkey_exporter._get_valkey_client", return_value=client): + result.save("valkey://localhost:6379/ttl_key", ttl=7200) + + client.setex.assert_called_once() + args = client.setex.call_args[0] + assert args[0] == "ttl_key" + assert args[1] == 7200 + + def 
test_importable_without_valkey_installed(self): + """export_valkey is importable even when valkey-py is not installed.""" + # This test simply calls import — the real guard is at call time, not import time + with patch.dict(sys.modules, {"valkey": None, "redis": None}): + # Re-execute the import to confirm no top-level ImportError + import importlib + + import mata.core.exporters.valkey_exporter as mod + + importlib.reload(mod) + # Should still be callable (just raises at execution time) + assert callable(mod.export_valkey) diff --git a/tests/test_valkey_nodes.py b/tests/test_valkey_nodes.py new file mode 100644 index 0000000..d9bff88 --- /dev/null +++ b/tests/test_valkey_nodes.py @@ -0,0 +1,442 @@ +"""Tests for ValkeyStore and ValkeyLoad graph nodes — Task B1, B2.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from mata.core.artifacts.base import Artifact +from mata.core.artifacts.classifications import Classifications +from mata.core.artifacts.depth_map import DepthMap +from mata.core.artifacts.detections import Detections +from mata.core.artifacts.masks import Masks +from mata.core.graph.context import ExecutionContext +from mata.core.graph.graph import CompiledGraph, Graph +from mata.core.graph.node import Node +from mata.core.types import Classification, ClassifyResult, DepthResult, Instance, VisionResult +from mata.nodes.valkey_load import ValkeyLoad +from mata.nodes.valkey_store import ValkeyStore + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_instance(label: str = "cat", score: float = 0.9) -> Instance: + return Instance(label=0, label_name=label, score=score, bbox=[10, 20, 100, 200]) + + +def _make_vision_result(n: int = 2) -> VisionResult: + instances = [_make_instance(score=0.9 - i * 0.1) for i in range(n)] + return VisionResult(instances=instances) + + +def _make_detections(n: int = 2) -> Detections: + return Detections.from_vision_result(_make_vision_result(n)) + + +def _make_masks() -> Masks: + mask_arr = np.ones((64, 64), dtype=bool) + instances = [ + Instance(label=0, label_name="cat", score=0.9, bbox=[10, 20, 100, 200], mask=mask_arr), + Instance(label=1, label_name="dog", score=0.8, bbox=[50, 60, 150, 250], mask=mask_arr), + ] + return Masks(instances=instances) + + +def _make_classifications() -> Classifications: + preds = [ + Classification(label=0, score=0.95, label_name="cat"), + Classification(label=1, score=0.05, label_name="dog"), + ] + return Classifications(predictions=tuple(preds)) + + +def _make_depth_map() -> DepthMap: + depth_arr = np.ones((64, 64), dtype=np.float32) + result = DepthResult(depth=depth_arr, normalized=depth_arr) + return DepthMap.from_depth_result(result) + + +def _make_ctx() -> ExecutionContext: + return ExecutionContext(providers={}, device="cpu") + + +# --------------------------------------------------------------------------- +# TestValkeyStoreNode +# --------------------------------------------------------------------------- + + +class TestValkeyStoreNode: + """Tests for ValkeyStore sink node.""" + + def test_inherits_from_node(self): + node = ValkeyStore(src="dets", url="valkey://localhost:6379", key="test:key") + assert isinstance(node, Node) + + def test_declares_inputs_outputs(self): + # inputs/outputs are set dynamically in __init__ based on src + node = ValkeyStore(src="dets", url="valkey://localhost:6379", 
key="test:key") + assert "dets" in node.inputs + assert "dets" in node.outputs + assert node.inputs["dets"] is Artifact + assert node.outputs["dets"] is Artifact + + def test_importable_from_mata_nodes(self): + from mata.nodes import ValkeyStore as VS # noqa: F401, N817 + + assert VS is ValkeyStore + + def test_run_stores_artifact(self): + artifact = _make_detections() + ctx = _make_ctx() + node = ValkeyStore(src="dets", url="valkey://localhost:6379", key="test:dets") + + mock_result = MagicMock() + mock_result.to_json.return_value = '{"instances": []}' + + with patch( + "mata.nodes.valkey_store.ValkeyStore._artifact_to_serializable", + return_value=mock_result, + ): + with patch("mata.core.exporters.valkey_exporter.export_valkey") as mock_export: + node.run(ctx, artifact=artifact) + assert mock_export.called + + def test_passthrough_artifact_unchanged(self): + artifact = _make_detections() + ctx = _make_ctx() + node = ValkeyStore(src="dets", url="valkey://localhost:6379", key="test:dets", out="dets_out") + + mock_result = MagicMock() + mock_result.to_json.return_value = "{}" + + with patch( + "mata.nodes.valkey_store.ValkeyStore._artifact_to_serializable", + return_value=mock_result, + ): + with patch("mata.core.exporters.valkey_exporter.export_valkey"): + result = node.run(ctx, artifact=artifact) + + # The artifact must be returned unchanged + assert "dets_out" in result + assert result["dets_out"] is artifact + + def test_key_template_node_placeholder(self): + artifact = _make_detections() + ctx = _make_ctx() + node = ValkeyStore(src="dets", url="valkey://localhost:6379", key="pipeline:{node}:result") + + captured_keys = [] + + mock_result = MagicMock() + mock_result.to_json.return_value = "{}" + + with patch( + "mata.nodes.valkey_store.ValkeyStore._artifact_to_serializable", + return_value=mock_result, + ): + with patch("mata.core.exporters.valkey_exporter.export_valkey") as mock_export: + node.run(ctx, artifact=artifact) + call_kwargs = mock_export.call_args + captured_keys.append(call_kwargs[1]["key"] if call_kwargs[1] else call_kwargs[0][2]) + + # {node} should be replaced with node name + resolved_key = captured_keys[0] + assert "{node}" not in resolved_key + assert "ValkeyStore" in resolved_key + + def test_key_template_timestamp_placeholder(self): + artifact = _make_detections() + ctx = _make_ctx() + node = ValkeyStore(src="dets", url="valkey://localhost:6379", key="pipeline:frame:{timestamp}") + + mock_result = MagicMock() + mock_result.to_json.return_value = "{}" + + with patch( + "mata.nodes.valkey_store.ValkeyStore._artifact_to_serializable", + return_value=mock_result, + ): + with patch("mata.core.exporters.valkey_exporter.export_valkey") as mock_export: + node.run(ctx, artifact=artifact) + + call_kwargs = mock_export.call_args + key_arg = call_kwargs[1].get("key") or call_kwargs[0][2] + assert "{timestamp}" not in key_arg + # Should be numeric (timestamp) + ts_part = key_arg.split(":")[-1] + assert ts_part.isdigit() + + def test_ttl_parameter_forwarded(self): + artifact = _make_detections() + ctx = _make_ctx() + node = ValkeyStore(src="dets", url="valkey://localhost:6379", key="test:key", ttl=300) + + mock_result = MagicMock() + mock_result.to_json.return_value = "{}" + + with patch( + "mata.nodes.valkey_store.ValkeyStore._artifact_to_serializable", + return_value=mock_result, + ): + with patch("mata.core.exporters.valkey_exporter.export_valkey") as mock_export: + node.run(ctx, artifact=artifact) + + call_kwargs = mock_export.call_args + assert call_kwargs[1].get("ttl") == 
300 or (len(call_kwargs[0]) > 3 and call_kwargs[0][3] == 300) + + def test_serializer_parameter_forwarded(self): + artifact = _make_detections() + ctx = _make_ctx() + node = ValkeyStore(src="dets", url="valkey://localhost:6379", key="test:key", serializer="msgpack") + + mock_result = MagicMock() + mock_result.to_json.return_value = "{}" + + with patch( + "mata.nodes.valkey_store.ValkeyStore._artifact_to_serializable", + return_value=mock_result, + ): + with patch("mata.core.exporters.valkey_exporter.export_valkey") as mock_export: + node.run(ctx, artifact=artifact) + + call_kwargs = mock_export.call_args + assert call_kwargs[1].get("serializer") == "msgpack" + + def test_graph_compilation(self): + """ValkeyStore compiles successfully in a Graph DAG.""" + from unittest.mock import Mock + + from mata.nodes.detect import Detect + + mock_detector = Mock() + mock_detector.predict = Mock(return_value=_make_vision_result()) + + # ValkeyStore.inputs is keyed by src so auto-wiring picks up "dets" + graph = ( + Graph() + .then(Detect(using="detr", out="dets")) + .then(ValkeyStore(src="dets", url="valkey://localhost:6379", key="test:dets")) + ) + + compiled = graph.compile(providers={"detr": mock_detector}) + assert isinstance(compiled, CompiledGraph) + assert compiled.validation_result.valid + assert len(compiled.nodes) == 2 + + def test_graph_execution_with_detect(self): + """ValkeyStore executes cleanly in a full graph pipeline.""" + from unittest.mock import Mock + + from PIL import Image as PILImage + + from mata.core.artifacts.image import Image + from mata.core.graph.scheduler import SyncScheduler + from mata.nodes.detect import Detect + + vision_result = _make_vision_result() + mock_detector = Mock() + mock_detector.predict = Mock(return_value=vision_result) + + compile_providers = {"detr": mock_detector} + ctx_providers = {"detect": {"detr": mock_detector}} + + # ValkeyStore.inputs is keyed by src so auto-wiring picks up "dets" + graph = ( + Graph() + .then(Detect(using="detr", out="dets")) + .then(ValkeyStore(src="dets", url="valkey://localhost:6379", key="test:dets")) + ) + + compiled = graph.compile(providers=compile_providers) + ctx = ExecutionContext(providers=ctx_providers, device="cpu") + + pil_img = PILImage.new("RGB", (64, 64), color=(128, 128, 128)) + image_artifact = Image.from_pil(pil_img) + + with patch("mata.core.exporters.valkey_exporter.export_valkey"): + scheduler = SyncScheduler() + result = scheduler.execute( + compiled, + ctx, + initial_artifacts={"input.image": image_artifact}, + ) + + # The store node should pass through; result contains all artifacts + assert result is not None + + def test_artifact_to_serializable_detections(self): + artifact = _make_detections() + result = ValkeyStore._artifact_to_serializable(artifact) + assert isinstance(result, VisionResult) + + def test_artifact_to_serializable_masks(self): + artifact = _make_masks() + result = ValkeyStore._artifact_to_serializable(artifact) + assert isinstance(result, VisionResult) + + def test_artifact_to_serializable_classifications(self): + artifact = _make_classifications() + result = ValkeyStore._artifact_to_serializable(artifact) + assert isinstance(result, ClassifyResult) + + def test_artifact_to_serializable_depth(self): + artifact = _make_depth_map() + result = ValkeyStore._artifact_to_serializable(artifact) + assert isinstance(result, DepthResult) + + def test_artifact_to_serializable_fallback(self): + """Unknown artifact types are returned as-is (fallback).""" + from dataclasses import dataclass + + 
@dataclass(frozen=True) + class CustomArtifact(Artifact): + def to_dict(self) -> dict: + return {} + + @classmethod + def from_dict(cls, data: dict) -> CustomArtifact: + return cls() + + artifact = CustomArtifact() + result = ValkeyStore._artifact_to_serializable(artifact) + # Fallback: returns the artifact itself + assert result is artifact + + def test_default_output_name_same_as_src(self): + """out defaults to src when not provided.""" + node = ValkeyStore(src="my_dets", url="valkey://localhost:6379", key="k") + assert node.output_name == "my_dets" + + def test_custom_output_name(self): + node = ValkeyStore(src="my_dets", url="valkey://localhost:6379", key="k", out="stored") + assert node.output_name == "stored" + + +# --------------------------------------------------------------------------- +# TestValkeyLoadNode +# --------------------------------------------------------------------------- + + +class TestValkeyLoadNode: + """Tests for ValkeyLoad source node.""" + + def test_inherits_from_node(self): + node = ValkeyLoad(url="valkey://localhost:6379", key="test:key") + assert isinstance(node, Node) + + def test_no_inputs_declared(self): + assert ValkeyLoad.inputs == {} + + def test_importable_from_mata_nodes(self): + from mata.nodes import ValkeyLoad as VL # noqa: F401, N817 + + assert VL is ValkeyLoad + + def test_run_loads_artifact(self): + vision_result = _make_vision_result() + ctx = _make_ctx() + node = ValkeyLoad(url="valkey://localhost:6379", key="test:key", out="loaded_dets") + + with patch("mata.core.exporters.valkey_exporter.load_valkey", return_value=vision_result): + result = node.run(ctx) + + assert "loaded_dets" in result + assert isinstance(result["loaded_dets"], Detections) + + def test_result_to_artifact_vision(self): + vision_result = _make_vision_result() + artifact = ValkeyLoad._result_to_artifact(vision_result) + assert isinstance(artifact, Detections) + + def test_result_to_artifact_classify(self): + classify_result = ClassifyResult(predictions=[Classification(label=0, score=0.9, label_name="cat")]) + artifact = ValkeyLoad._result_to_artifact(classify_result) + assert isinstance(artifact, Classifications) + + def test_result_to_artifact_depth(self): + depth_arr = np.ones((32, 32), dtype=np.float32) + depth_result = DepthResult(depth=depth_arr, normalized=depth_arr) + artifact = ValkeyLoad._result_to_artifact(depth_result) + assert isinstance(artifact, DepthMap) + + def test_result_to_artifact_unknown_raises(self): + with pytest.raises(TypeError, match="Cannot convert"): + ValkeyLoad._result_to_artifact("not_a_result") + + def test_missing_key_raises(self): + ctx = _make_ctx() + node = ValkeyLoad(url="valkey://localhost:6379", key="nonexistent:key") + + with patch( + "mata.core.exporters.valkey_exporter.load_valkey", + side_effect=KeyError("nonexistent:key"), + ): + with pytest.raises(KeyError): + node.run(ctx) + + def test_graph_compilation_as_source(self): + """ValkeyLoad (source node with no inputs) compiles in a Graph DAG.""" + graph = Graph().then(ValkeyLoad(url="valkey://localhost:6379", key="upstream:dets", out="dets")) + + compiled = graph.compile(providers={}) + assert isinstance(compiled, CompiledGraph) + assert compiled.validation_result.valid + assert len(compiled.nodes) == 1 + + def test_graph_chain_load_then_filter(self): + """ValkeyLoad → Filter chain compiles and validates correctly.""" + from mata.nodes.filter import Filter + + graph = ( + Graph() + .then(ValkeyLoad(url="valkey://localhost:6379", key="upstream:dets", out="dets")) + 
.then(Filter(src="dets", out="filtered", score_gt=0.5)) + ) + + compiled = graph.compile(providers={}) + assert isinstance(compiled, CompiledGraph) + assert compiled.validation_result.valid + assert len(compiled.nodes) == 2 + + def test_default_output_name(self): + node = ValkeyLoad(url="valkey://localhost:6379", key="test:key") + assert node.output_name == "loaded" + + def test_custom_output_name(self): + node = ValkeyLoad(url="valkey://localhost:6379", key="test:key", out="my_artifact") + assert node.output_name == "my_artifact" + + def test_result_type_forwarded_to_load_valkey(self): + """result_type parameter is passed to load_valkey.""" + vision_result = _make_vision_result() + ctx = _make_ctx() + node = ValkeyLoad( + url="valkey://localhost:6379", + key="test:key", + result_type="vision", + ) + + with patch("mata.core.exporters.valkey_exporter.load_valkey", return_value=vision_result) as mock_load: + node.run(ctx) + mock_load.assert_called_once_with( + url="valkey://localhost:6379", + key="test:key", + result_type="vision", + ) + + def test_load_and_store_round_trip(self): + """ValkeyLoad and ValkeyStore can be used in the same graph.""" + graph = ( + Graph() + .then(ValkeyLoad(url="valkey://localhost:6379", key="upstream:dets", out="dets")) + .then(ValkeyStore(src="dets", url="valkey://localhost:6379", key="downstream:{timestamp}")) + ) + + compiled = graph.compile(providers={}) + assert isinstance(compiled, CompiledGraph) + assert compiled.validation_result.valid + assert len(compiled.nodes) == 2 From 9260be586209a328989572b84f4aeb16aaf53363 Mon Sep 17 00:00:00 2001 From: Bimantoro Maesa Date: Mon, 9 Mar 2026 19:17:45 +0700 Subject: [PATCH 2/2] Add msgpack as a dependency for Valkey/Redis result storage --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 3cda02e..d5cc7c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ ocr-all = [ # Valkey/Redis result storage valkey = [ "valkey>=6.0.0", + "msgpack>=1.0.0", ] # Redis fallback (wire-compatible with Valkey)