diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index df4c7a2..cdc6106 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -122,7 +122,7 @@ Built-in Tools (zoom, crop) + Provider Tools (detect, classify, etc.) - **Error modes**: retry (default), skip, fail - **Observability**: Full tracing and metrics integration -**Testing:** 336 comprehensive tests in `test_tool_schema.py`, `test_tool_registry.py`, `test_agent_loop.py`, `test_tool_prompts.py`, `test_tool_call_parser.py`, etc. All passing with zero regressions. +**Testing:** 342 comprehensive tests in `test_tool_schema.py`, `test_tool_registry.py`, `test_agent_loop.py`, `test_tool_prompts.py`, `test_tool_call_parser.py`, etc. All passing with zero regressions. **Documentation:** See `docs/VLM_TOOL_CALLING_SUMMARY.md` for complete architecture details, design decisions, limitations, and future roadmap. @@ -152,7 +152,7 @@ pytest tests/test_track_node.py -v # Track graph node (39 tests) # VLM tool-calling test suites (v1.7.0) pytest tests/test_tool_schema.py -v # Tool schema (33 tests) -pytest tests/test_tool_registry.py -v # Tool registry (44 tests) +pytest tests/test_tool_registry.py -v # Tool registry (49 tests) pytest tests/test_agent_loop.py -v # Agent loop (51 tests) pytest tests/test_tool_prompts.py -v # Tool prompts (18 tests) pytest tests/test_tool_call_parser.py -v # Tool call parser (51 tests) diff --git a/.gitignore b/.gitignore index ed76607..3416d29 100644 --- a/.gitignore +++ b/.gitignore @@ -167,6 +167,8 @@ runs/ # Dev Test Artifacts examples/inference/outputs/ bm_test/ +example_test.sh +example_test.ps1 # Auto Claude data directory .auto-claude/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ca966e..475ab5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,26 @@ Versions follow [Semantic Versioning](https://semver.org/). --- +## [1.9.1] - 2026-03-08 + +### Changed + +- Refactored graph flow notation from `β†’` to `>` in all examples, scripts, and documentation for consistency with the DSL operator syntax +- Updated expected output structure descriptions in examples and docs to match the new `>` notation + +### Added + +- `ToolRegistry` now requires `text_prompts` for zero-shot providers (GroundingDINO, OWL-ViT, CLIP) and raises `ValueError` when they are missing +- Improved tool schema generation: zero-shot providers automatically include a `text_prompts` parameter in their generated `ToolSchema` +- Tests for zero-shot provider detection and `text_prompts` schema requirement in `test_tool_registry.py` + +### Fixed + +- SAM adapter: minor issue where prompt-less calls could silently produce empty masks instead of raising a clear error +- Video tracking examples: corrected frame iteration and output path handling in `examples/track/` + +--- + ## [1.9.0] - 2026-03-02 ### Added diff --git a/README.md b/README.md index b4bc566..9780d3e 100644 --- a/README.md +++ b/README.md @@ -1171,72 +1171,7 @@ export MATA_CONFIG=/path/to/config.json ## πŸ›£οΈ Roadmap -### βœ… Completed (v1.9.0 - Current) - -#### **OCR / Text Extraction** β€” Five backends, graph nodes, evaluation pipeline - -- βœ… **Five OCR backends**: EasyOCR (80+ languages), PaddleOCR (multilingual), Tesseract (classic), GOT-OCR2 (HuggingFace end-to-end), TrOCR (HuggingFace line-level) -- βœ… **`mata.run("ocr", ...)` API**: Unified entry point β€” `model=` selects backend or HuggingFace ID -- βœ… **`mata.load("ocr", ...)` API**: Returns persistent adapter for repeated inference -- βœ… **`OCRResult` type**: `.full_text`, `.regions` (list of `TextRegion` with bbox + score + text) -- βœ… **Multi-format export**: `.save("out.txt")`, `.save("out.csv")`, `.save("out.json")`, `.save("overlay.png")` -- βœ… **`filter_by_score()`**: Confidence-threshold filtering on OCR results -- βœ… **`OCRText` graph artifact**: Strongly-typed artifact for graph pipelines -- βœ… **`OCR` graph node**: Accepts `Image` or `ROIs` input, aggregates per-crop results with `instance_id` correlation -- βœ… **`ExtractROIs` graph node**: Crops detection regions for downstream OCR -- βœ… **`OCRWrapper`**: Protocol-based capability wrapper enabling OCR as a graph provider -- βœ… **VLM tool integration**: `"ocr"` registered in `ToolRegistry` and `TASK_SCHEMA_DEFAULTS` for agent mode -- βœ… **UniversalLoader routing**: Bare engine names (`"easyocr"`, `"paddleocr"`, `"tesseract"`) routed via `_EXTERNAL_OCR_ENGINES`; HuggingFace OCR IDs routed through `_load_from_huggingface()` -- βœ… **Optional dependencies**: EasyOCR, PaddleOCR, Tesseract declared as optional extras in `pyproject.toml` -- βœ… **71 evaluation tests**: `test_eval_ocr.py` β€” all passing, zero regressions against 4307+ total -- ⏳ **`mata.val("ocr", ...)` evaluation**: `OCRMetrics` (word accuracy, character accuracy, precision, recall, F1) with COCO-Text JSON dataset loader | PENDING for v1.9.1 release due to dataset licensing review. - -### βœ… Completed (v1.8) - -#### **Object Tracking** β€” ByteTrack + BotSort - -- βœ… **Vendored ByteTrack**: Zero-dependency implementation in `src/mata/trackers/` (no yolox/ultralytics) -- βœ… **Vendored BotSort**: IoU + Global Motion Compensation (GMC via sparse optical flow) -- βœ… **`mata.track()` API**: One-liner video/stream/webcam/image-dir tracking -- βœ… **`mata.load("track", ...)` API**: Returns `TrackingAdapter` for persistent per-frame tracking -- βœ… **Multiple source types**: Video files, RTSP streams, webcams, image directories, single images -- βœ… **Track ID rendering**: `show_track_ids=True` with deterministic per-track colors -- βœ… **Trajectory trails**: `show_trails=True` β€” PIL-native polyline history rendering -- βœ… **CSV/JSON export**: MOT-compatible CSV export, multi-frame JSON with metadata -- βœ… **Graph node upgrade**: `Track` node uses vendored trackers, `BotSortWrapper` added -- βœ… **Graph presets**: BotSort variants added to surveillance/driving presets -- βœ… **YAML config**: Tracker settings in `~/.mata/models.yaml` under `track:` task -- βœ… **687 tests**: All passing, zero regressions against 4047+ total - -### βœ… Completed (v1.6) - -#### **Graph System Architecture** - Multi-task workflows with parallel execution - -- βœ… **Artifact Type System**: Strongly-typed vision primitives (Image, Detections, Masks, Keypoints, Tracks, ROIs) -- βœ… **Task Graph Builder**: Fluent API for composing multi-task pipelines (Detect β†’ Segment β†’ Pose) -- βœ… **Parallel Execution**: Automatic parallelization of independent tasks (1.5-3x speedup, 41x in benchmarks) -- βœ… **Conditional Branching**: Result-driven workflow control with If/else, HasLabel, CountAbove, ScoreAbove -- βœ… **Temporal Processing**: Video inference with BYTETrack/IoU tracking and frame policies -- βœ… **Capability Providers**: Protocol-based model registry with lazy loading -- βœ… **VLM Graph Nodes**: VLMDescribe, VLMDetect, VLMQuery, PromoteEntities for Entityβ†’Instance workflows -- βœ… **Visualization Nodes**: Native Annotate and NMS nodes reusing existing PIL/matplotlib backends -- βœ… **Pre-built Presets**: 8 graph presets (detection+segmentation, scene analysis, VLM workflows, tracking) -- βœ… **Observability**: Metrics collection, execution tracing, and provenance tracking -- βœ… **`mata.infer()` API**: New public API for graph execution with flat provider dicts -- βœ… **Backward Compatibility**: 100% compatible with existing `mata.load()`/`mata.run()` APIs -- βœ… **Comprehensive Testing**: 2185 tests, >80% coverage - -### βœ… Completed (v1.5.3) - -- βœ… **Multi-Task Support**: Detection, classification, segmentation, depth estimation, vision-language models (VLM) -- βœ… **Zero-Shot Capabilities**: CLIP (classify), GroundingDINO/OWL-ViT (detect), SAM/SAM3 (segment) -- βœ… **Vision-Language Models**: Image captioning, VQA, visual understanding with Qwen3-VL - February 2026 -- βœ… **Universal Loader**: llama.cpp-style loading with 5-strategy auto-detection -- βœ… **Multi-Format Runtime**: PyTorch, ONNX Runtime, TorchScript, Torchvision support -- βœ… **Torchvision CNN Detection**: Apache 2.0 licensed models (RetinaNet, Faster R-CNN, FCOS, SSD) - February 2026 -- βœ… **Export & Visualization**: JSON/CSV/image/crops with dual backends (PIL/matplotlib) -- βœ… **Plugin Removal**: Simplified architecture, -1,268 lines of legacy code -- βœ… **Comprehensive Testing**: 405 tests (exceeded 202+ target), 60-85% coverage +> **For a full history of completed features, see [CHANGELOG.md](CHANGELOG.md).** ### πŸ”„ In Progress @@ -1245,7 +1180,7 @@ export MATA_CONFIG=/path/to/config.json - ⏳ **ReID model integration**: Feature embeddings via HuggingFace ReID models - ⏳ **Cross-camera tracking**: Match track IDs across camera feeds - ⏳ **BotSort ReID mode**: Enable `with_reid=true` in botsort config -- **Status**: Planned for v1.9 +- **Status**: Planned for v1.9.x #### **2. KACA Integration** - MIT-licensed CNN detection with PyTorch and ONNX support @@ -1261,6 +1196,7 @@ export MATA_CONFIG=/path/to/config.json - πŸ”„ **Model Recommendations**: Suggest best models based on task and hardware constraints - πŸ”„ **Batch Model Download**: Pre-download common models for air-gapped environments - πŸ”„ **Enhanced Search**: Filter by task, license, performance metrics +- **Status**: Planned for v2.x ### ⏳ Planned (v2.0 - Q2 2026) diff --git a/docs/VLM_TOOL_CALLING_SUMMARY.md b/docs/VLM_TOOL_CALLING_SUMMARY.md index 2cb1726..3f1a7a1 100644 --- a/docs/VLM_TOOL_CALLING_SUMMARY.md +++ b/docs/VLM_TOOL_CALLING_SUMMARY.md @@ -1,9 +1,10 @@ # VLM Tool-Calling Agent System β€” Architecture Summary -**Version**: 1.7.0 +**Version**: 1.7.1 **Implementation Date**: February 16, 2026 +**Last Updated**: March 8, 2026 **Status**: βœ… Production Ready -**Test Coverage**: 336 comprehensive tests, all passing +**Test Coverage**: 342 comprehensive tests, all passing --- @@ -270,6 +271,31 @@ mata.infer( --- +### 8. **Provider-Aware Schema Generation for Zero-Shot Models** _(v1.7.1)_ + +**Decision**: `ToolRegistry` introspects the actual provider at registration time and upgrades `text_prompts` to `required=True` for zero-shot adapters. + +**Rationale**: + +- **VLM must know `text_prompts` is required** β€” The default `TASK_SCHEMA_DEFAULTS["detect"]` marks `text_prompts` as optional (correct for supervised detectors like RT-DETR, YOLO). But zero-shot models (GroundingDINO, OWL-ViT) **cannot run without class names**. If the schema shows the parameter as optional, the VLM's system prompt will say _"optional"_ and the agent will omit it, causing a `TypeError` or `InvalidInputError` at execution time. +- **Zero-shot contract is enforced at the adapter level** β€” `HuggingFaceZeroShotDetectAdapter.predict()` keeps `text_prompts` as a required positional argument. The fix is upstream: make the _schema_ match the adapter's actual contract. +- **Clean detection via class name** β€” All MATA zero-shot adapters have `"ZeroShot"` in their class name. `_is_zero_shot_provider()` unwraps one layer of wrapper (e.g., `DetectorWrapper.adapter`) and checks the underlying class name β€” no new class attributes or protocol changes needed. +- **`TASK_SCHEMA_DEFAULTS` stays generic** β€” The shared default schema is not modified; customization happens per-provider at `ToolRegistry` construction time. + +**Agentic chain this enables**: + +``` +VLM: "I see an unknown object. Let me classify it." + β†’ classifier(region=[80,120,220,300]) β†’ "cat (0.92)" +VLM: "It's a cat. Let me find all cats using the detector." + β†’ detector(text_prompts="cat") β†’ 2 cats detected +VLM: "Found 2 cats at [80,120,220,300] and [300,130,440,280]. Summary..." +``` + +**Implementation**: `_is_zero_shot_provider()` + upgraded `_schema_for_capability(capability, tool_name, provider)` in `src/mata/core/tool_registry.py` (v1.7.1). + +--- + ### 7. **Multi-Format Tool Call Parsing** **Decision**: Support fenced blocks (` ```tool_call `), XML (``), and raw JSON. @@ -560,6 +586,11 @@ result = AgentResult( - **Impact**: VLMs may output `"0.5"` instead of `0.5` for floats - **Solution**: Comprehensive type coercion in `validate_tool_call()` +#### ~~4. Zero-Shot Detector Omits `text_prompts`~~ _(Fixed β€” v1.7.1)_ + +- **Was**: `TASK_SCHEMA_DEFAULTS["detect"]` marked `text_prompts` as optional, causing the VLM to omit it. Zero-shot adapters require it, so the call failed with `TypeError`. +- **Fix**: `ToolRegistry._schema_for_capability()` now introspects the actual provider via `_is_zero_shot_provider()` and upgrades `text_prompts` to `required=True` for zero-shot adapters. The VLM's system prompt now correctly says the parameter is required, so the agent always populates it from its own reasoning. + --- ## Future Directions diff --git a/examples/classify/basic_classification.py b/examples/classify/basic_classification.py index 0a52beb..5137a63 100644 --- a/examples/classify/basic_classification.py +++ b/examples/classify/basic_classification.py @@ -1,6 +1,6 @@ """Basic Classification Examples β€” MATA Framework -Progressive examples: one-shot β†’ load/reuse β†’ model comparison β†’ filtering. +Progressive examples: one-shot > load/reuse > model comparison > filtering. Run: python examples/classify/basic_classification.py """ @@ -31,7 +31,7 @@ def load_and_reuse(): for _ in range(2): result = classifier.predict(get_image()) top1 = result.get_top1() - print(f" β†’ {top1.label_name}: {top1.score * 100:.2f}%") + print(f" to {top1.label_name}: {top1.score * 100:.2f}%") # === Section 3: Access Results (.get_top1, top-5 predictions) === diff --git a/examples/classify/clip_zeroshot.py b/examples/classify/clip_zeroshot.py index 0b0a4e8..3d99118 100644 --- a/examples/classify/clip_zeroshot.py +++ b/examples/classify/clip_zeroshot.py @@ -150,7 +150,7 @@ def example4_batch_classification(): for name, img in images: result = classifier.predict(img, text_prompts=text_prompts, top_k=2) top2 = [(p.label_name, f"{p.score:.4f}") for p in result.predictions] - print(f" {name:15s} β†’ {top2}") + print(f" {name:15s} to {top2}") def main(): diff --git a/examples/depth/basic_depth.py b/examples/depth/basic_depth.py index 02c5ae2..0a4c7a1 100644 --- a/examples/depth/basic_depth.py +++ b/examples/depth/basic_depth.py @@ -38,7 +38,7 @@ def example_depth_v1(): OUTPUT_DIR.mkdir(parents=True, exist_ok=True) result.save(OUTPUT_DIR / "depth_v1.png", colormap="magma") result.save(OUTPUT_DIR / "depth_v1.json") - print(f"Saved β†’ {OUTPUT_DIR}/depth_v1.png and depth_v1.json") + print(f"Saved to {OUTPUT_DIR}/depth_v1.png and depth_v1.json") # === Section 2: One-Shot Depth (Depth Anything V2) === @@ -61,7 +61,7 @@ def example_depth_v2(): OUTPUT_DIR.mkdir(parents=True, exist_ok=True) result.save(OUTPUT_DIR / "depth_v2.png", colormap="magma") result.save(OUTPUT_DIR / "depth_v2.json") - print(f"Saved β†’ {OUTPUT_DIR}/depth_v2.png and depth_v2.json") + print(f"Saved to {OUTPUT_DIR}/depth_v2.png and depth_v2.json") # === Section 3: Load Once, Predict Many === @@ -105,6 +105,8 @@ def main(): except Exception as exc: print(f" [error] load-once: {exc}") + print("\nDone.") + if __name__ == "__main__": main() diff --git a/examples/detect/basic_detection.py b/examples/detect/basic_detection.py index 1e932fe..d7d7248 100644 --- a/examples/detect/basic_detection.py +++ b/examples/detect/basic_detection.py @@ -87,12 +87,12 @@ def section_export(output_dir: Path): # Save to .json file json_path = output_dir / "detections.json" result.save(str(json_path)) - print(f"[export] Saved JSON β†’ {json_path}") + print(f"[export] Saved JSON to {json_path}") # Save annotated image (overlay bboxes on the source image) img_path = output_dir / "detections_overlay.jpg" result.save(str(img_path)) - print(f"[export] Saved image β†’ {img_path}") + print(f"[export] Saved image to {img_path}") # === Section 6: Config Aliases === @@ -104,7 +104,7 @@ def section_config_aliases(): mata.register_model("detect", "my-rtdetr", "PekingU/rtdetr_r50vd", threshold=0.6) detector = mata.load("detect", "my-detr") - print(f"[alias] Loaded 'my-detr' β†’ {detector.__class__.__name__}") + print(f"[alias] Loaded 'my-detr' to {detector.__class__.__name__}") # Config-file aliases work the same way β€” set them in .mata/models.yaml # and load by name without calling register_model() in code. diff --git a/examples/detect/zeroshot_detection.py b/examples/detect/zeroshot_detection.py index 33fa17e..cf05736 100644 --- a/examples/detect/zeroshot_detection.py +++ b/examples/detect/zeroshot_detection.py @@ -88,7 +88,7 @@ def example_grounding_dino(): output_image = draw_detections(image.copy(), result, text_prompts) output_path = "examples/images/output_grounding_dino.jpg" output_image.save(output_path) - print(f"\nβœ“ Saved visualization to: {output_path}") + print(f"\n Saved visualization to: {output_path}") return result @@ -119,7 +119,7 @@ def example_owlvit_v2(): output_image = draw_detections(image.copy(), result, text_prompts) output_path = "examples/images/output_owlvit_v2.jpg" output_image.save(output_path) - print(f"\nβœ“ Saved visualization to: {output_path}") + print(f"\n Saved visualization to: {output_path}") return result @@ -153,7 +153,7 @@ def example_batch_processing(): for instance in result.instances: print(f" - {instance.label_name}: {instance.score:.3f}") - print(f"\nβœ“ Processed {len(images)} images in batch") + print(f"\n Processed {len(images)} images in batch") return results @@ -214,9 +214,9 @@ def example_model_comparison(): print(f" OWL-ViT v2: {len(result_owlv2.instances)} objects") print("\n[Results] Model comparison:") - print(f" β”œβ”€ GroundingDINO: {len(result_gdino.instances)} detections") - print(f" β”œβ”€ OWL-ViT v1: {len(result_owlv1.instances)} detections") - print(f" └─ OWL-ViT v2: {len(result_owlv2.instances)} detections") + print(f" - GroundingDINO: {len(result_gdino.instances)} detections") + print(f" - OWL-ViT v1: {len(result_owlv1.instances)} detections") + print(f" - OWL-ViT v2: {len(result_owlv2.instances)} detections") return result_gdino, result_owlv1, result_owlv2 @@ -238,17 +238,18 @@ def main(): example_model_comparison() print("\n" + "=" * 70) - print("βœ“ All examples completed successfully!") + print(" All examples completed successfully!") print("=" * 70) print("\nNext steps:") print(" 1. Check the output images in examples/images/") print(" 2. Try with your own images") print(" 3. Experiment with different text prompts") - print(" 4. Explore the GroundingDINOβ†’SAM pipeline: examples/segment/grounding_sam_pipeline.py") + print(" 4. Explore the GroundingDINO then SAM pipeline: examples/segment/grounding_sam_pipeline.py") print() + print("Done") except Exception as e: - print(f"\nβœ— Error: {e}", file=sys.stderr) + print(f"\n Error: {e}", file=sys.stderr) import traceback traceback.print_exc() sys.exit(1) diff --git a/examples/graph/README.md b/examples/graph/README.md index 691a599..cf98aa0 100644 --- a/examples/graph/README.md +++ b/examples/graph/README.md @@ -18,7 +18,7 @@ These examples demonstrate the fundamental graph system capabilities: | Example | Description | Key Features | | ------------------------------------------- | -------------------------------------------- | ------------------------------------------------ | -| βœ… [simple_pipeline.py](simple_pipeline.py) | Detection β†’ Filter β†’ Segmentation β†’ Fuse | `mata.infer()`, `Graph.then()`, basic pipeline | +| βœ… [simple_pipeline.py](simple_pipeline.py) | Detection > Filter > Segmentation > Fuse | `mata.infer()`, `Graph.then()`, basic pipeline | | [parallel_tasks.py](parallel_tasks.py) | Parallel detection + classification + depth | `Graph.parallel()`, `ParallelScheduler`, speedup | | [video_tracking.py](video_tracking.py) | Video processing with object tracking | `VideoProcessor`, `Track`, frame policies | | [vlm_workflows.py](vlm_workflows.py) | VLM grounded detection & scene understanding | `VLMDetect`, `PromoteEntities`, VLM presets | diff --git a/examples/graph/presets_demo.py b/examples/graph/presets_demo.py index 91cc225..579539c 100644 --- a/examples/graph/presets_demo.py +++ b/examples/graph/presets_demo.py @@ -53,19 +53,21 @@ def create_mock_providers(): # Classifier mock_classifier = Mock() mock_classifier.predict = Mock(return_value=ClassifyResult( - classifications=[ - Classification(label="outdoor", score=0.82), - Classification(label="indoor", score=0.18), + predictions=[ + Classification(label=0, score=0.82, label_name="outdoor"), + Classification(label=1, score=0.18, label_name="indoor"), ], meta={"model": "mock-clip"}, )) + mock_classifier.classify = mock_classifier.predict # Depth estimator mock_depth = Mock() mock_depth.predict = Mock(return_value=DepthResult( - depth_map=np.random.rand(480, 640).astype(np.float32), + depth=np.random.rand(480, 640).astype(np.float32), meta={"model": "mock-depth"}, )) + mock_depth.estimate = mock_depth.predict # Tracker mock_tracker = Mock() @@ -83,7 +85,7 @@ def create_mock_providers(): "detector": mock_detector, "segmenter": mock_segmenter, "classifier": mock_classifier, - "depth_estimator": mock_depth, + "depth": mock_depth, "tracker": mock_tracker, "vlm": mock_vlm, } @@ -122,7 +124,7 @@ def main(): ) print(f" Graph: {g1.name}, Nodes: {len(g1._nodes)}") print(" Usage: detector + segmenter providers") - print(" Flow: Detect β†’ Filter β†’ [NMS] β†’ PromptBoxes β†’ RefineMask β†’ Fuse") + print(" Flow: Detect > Filter > [NMS] > PromptBoxes > RefineMask > Fuse") # ----------------------------------------------------------------------- # Preset 2: Segment and Refine (Segment Everything) @@ -131,7 +133,7 @@ def main(): g2 = segment_and_refine() print(f" Graph: {g2.name}, Nodes: {len(g2._nodes)}") print(" Usage: segmenter provider") - print(" Flow: SegmentEverything β†’ RefineMask β†’ Fuse") + print(" Flow: SegmentEverything > RefineMask > Fuse") # ----------------------------------------------------------------------- # Preset 3: Detection + Pose @@ -144,7 +146,7 @@ def main(): ) print(f" Graph: {g3.name}, Nodes: {len(g3._nodes)}") print(" Usage: detector provider") - print(" Flow: Detect β†’ Filter β†’ [NMS] β†’ [TopK] β†’ Fuse") + print(" Flow: Detect > Filter > [NMS] > [TopK] > Fuse") # ----------------------------------------------------------------------- # Preset 4: Full Scene Analysis (Parallel) @@ -155,7 +157,7 @@ def main(): ) print(f" Graph: {g4.name}, Nodes: {len(g4._nodes)}") print(" Usage: detector + classifier + depth_estimator providers") - print(" Flow: parallel(Detect, Classify, EstimateDepth) β†’ Filter β†’ Fuse") + print(" Flow: parallel(Detect, Classify, EstimateDepth) > Filter > Fuse") # ----------------------------------------------------------------------- # Preset 5: Detection + Tracking @@ -163,13 +165,13 @@ def main(): print("\n=== Preset 5: detect_and_track ===") g5 = detect_and_track( detection_threshold=0.5, - track_thresh=0.4, + track_threshold=0.4, track_buffer=30, - match_thresh=0.8, + match_threshold=0.8, ) print(f" Graph: {g5.name}, Nodes: {len(g5._nodes)}") print(" Usage: detector + tracker providers") - print(" Flow: Detect β†’ Filter β†’ Track β†’ Fuse") + print(" Flow: Detect > Filter > Track > Fuse") # ----------------------------------------------------------------------- # Preset 6: VLM Grounded Detection @@ -178,7 +180,7 @@ def main(): g6 = vlm_grounded_detection() print(f" Graph: {g6.name}, Nodes: {len(g6._nodes)}") print(" Usage: vlm + detector providers") - print(" Flow: parallel(VLMDetect, Detect) β†’ Filter β†’ PromoteEntities β†’ Fuse") + print(" Flow: parallel(VLMDetect, Detect) > Filter > PromoteEntities > Fuse") # ----------------------------------------------------------------------- # Preset 7: VLM Scene Understanding @@ -187,7 +189,7 @@ def main(): g7 = vlm_scene_understanding() print(f" Graph: {g7.name}, Nodes: {len(g7._nodes)}") print(" Usage: vlm + detector + depth_estimator providers") - print(" Flow: parallel(VLMDescribe, Detect, EstimateDepth) β†’ Fuse") + print(" Flow: parallel(VLMDescribe, Detect, EstimateDepth) > Fuse") # ----------------------------------------------------------------------- # Preset 8: VLM Multi-Image Comparison @@ -196,29 +198,43 @@ def main(): g8 = vlm_multi_image_comparison() print(f" Graph: {g8.name}, Nodes: {len(g8._nodes)}") print(" Usage: vlm provider") - print(" Flow: VLMQuery β†’ Fuse") + print(" Flow: VLMQuery > Fuse") # ----------------------------------------------------------------------- # Execute a preset with mata.infer() # ----------------------------------------------------------------------- print("\n=== Executing a Preset ===") - if "--real" not in sys.argv: - import mata + import mata + if "--real" in sys.argv: + print(" Loading real models (this may take a moment)...") + detector = mata.load("detect", "facebook/detr-resnet-50") + classifier = mata.load("classify", "openai/clip-vit-base-patch32") + depth = mata.load("depth", "depth-anything/Depth-Anything-V2-Small-hf") + real_providers = { + "detector": detector, + "classifier": classifier, + "depth": depth, + } + result = mata.infer( + image="examples/images/000000039769.jpg", + graph=g4, # full_scene_analysis preset + providers=real_providers, + ) + else: + print(" (Running with mock providers β€” use --real for actual models)") result = mata.infer( image="examples/images/000000039769.jpg", graph=g4, # full_scene_analysis preset providers=providers, ) - print("full_scene_analysis result:") - print(f" Channels: {list(result.channels.keys())}") - if result.has_channel("detections"): - dets = result.get_channel("detections") - print(f" Detections: {len(dets.instances)} objects") - else: - print(" (Use --real flag with actual model providers)") + print("full_scene_analysis result:") + print(f" Channels: {list(result.channels.keys())}") + if result.has_channel("dets"): + dets = result.get_channel("dets") + print(f" Detections: {len(dets.instances)} objects") # ----------------------------------------------------------------------- # Quick reference diff --git a/examples/graph/scenarios/README.md b/examples/graph/scenarios/README.md index 6b82429..567a411 100644 --- a/examples/graph/scenarios/README.md +++ b/examples/graph/scenarios/README.md @@ -27,53 +27,53 @@ python retail_shelf_analysis.py --real shelf_image.jpg | Script | Problem | Models | Graph Pattern | | ------------------------------------ | ----------------------------------------- | -------------------- | ------------------------- | -| `manufacturing_defect_classify.py` | Surface defect detection & classification | GroundingDINO + CLIP | Detect β†’ ROI β†’ Classify | -| `manufacturing_defect_segment.py` | Defect segmentation & measurement | GroundingDINO + SAM | Detect β†’ Segment β†’ Refine | -| `manufacturing_assembly_verify.py` | Assembly verification with VLM | Qwen3-VL + DETR | VLM β€– Detect β†’ Fuse | -| `manufacturing_component_inspect.py` | Per-component VLM inspection | DETR + Qwen3-VL | Detect β†’ ROI β†’ VLM | +| `manufacturing_defect_classify.py` | Surface defect detection & classification | GroundingDINO + CLIP | Detect > ROI > Classify | +| `manufacturing_defect_segment.py` | Defect segmentation & measurement | GroundingDINO + SAM | Detect > Segment > Refine | +| `manufacturing_assembly_verify.py` | Assembly verification with VLM | Qwen3-VL + DETR | VLM β€– Detect > Fuse | +| `manufacturing_component_inspect.py` | Per-component VLM inspection | DETR + Qwen3-VL | Detect > ROI > VLM | ### πŸ›’ Retail (3 scenarios) | Script | Problem | Models | Graph Pattern | | -------------------------- | ---------------------------------------- | ---------------------- | ----------------------------- | -| `retail_shelf_analysis.py` | Product detection + brand classification | Faster R-CNN + CLIP | Detect β†’ NMS β†’ ROI β†’ Classify | -| `retail_product_search.py` | Zero-shot product search & segmentation | GroundingDINO + SAM | Detect β†’ Segment | +| `retail_shelf_analysis.py` | Product detection + brand classification | Faster R-CNN + CLIP | Detect > NMS > ROI > Classify | +| `retail_product_search.py` | Zero-shot product search & segmentation | GroundingDINO + SAM | Detect > Segment | | `retail_stock_level.py` | Multi-modal stock assessment | Qwen3-VL + DETR + CLIP | VLM β€– Detect β€– Classify | ### πŸš— Autonomous Driving (4 scenarios) | Script | Problem | Models | Graph Pattern | | -------------------------------- | ------------------------------------- | ---------------------------------------- | ------------------------- | -| `driving_distance_estimation.py` | Vehicle distance estimation | DETR + Depth Anything | Detect β€– Depth β†’ Fuse | -| `driving_road_scene.py` | Complete road scene analysis | 4 models (detect/segment/depth/classify) | 4-way parallel β†’ Fuse | -| `driving_traffic_tracking.py` | Traffic object tracking | RT-DETR + BYTETrack | Detect β†’ Track β†’ Annotate | +| `driving_distance_estimation.py` | Vehicle distance estimation | DETR + Depth Anything | Detect β€– Depth > Fuse | +| `driving_road_scene.py` | Complete road scene analysis | 4 models (detect/segment/depth/classify) | 4-way parallel > Fuse | +| `driving_traffic_tracking.py` | Traffic object tracking | RT-DETR + BYTETrack | Detect > Track > Annotate | | `driving_obstacle_vlm.py` | Obstacle detection with VLM reasoning | Qwen3-VL + GroundingDINO + Depth | VLM β€– Detect β€– Depth | ### πŸ”’ Security/Surveillance (3 scenarios) -| Script | Problem | Models | Graph Pattern | -| -------------------------------------- | ------------------------------------------ | ------------------------------ | ------------------------ | -| `security_crowd_monitoring.py` | Person detection + tracking | DETR + BYTETrack | Detect β†’ Filter β†’ Track | -| `security_suspicious_object.py` | Suspicious object detection + VLM analysis | GroundingDINO + SAM + Qwen3-VL | Detect β†’ Segment β†’ VLM | -| `security_situational_awareness.py` | Situational awareness monitoring | Qwen3-VL + GroundingDINO | VLM β†’ PromoteEntities | +| Script | Problem | Models | Graph Pattern | +| ----------------------------------- | ------------------------------------------ | ------------------------------ | ----------------------- | +| `security_crowd_monitoring.py` | Person detection + tracking | DETR + BYTETrack | Detect > Filter > Track | +| `security_suspicious_object.py` | Suspicious object detection + VLM analysis | GroundingDINO + SAM + Qwen3-VL | Detect > Segment > VLM | +| `security_situational_awareness.py` | Situational awareness monitoring | Qwen3-VL + GroundingDINO | VLM > PromoteEntities | ### 🌾 Agriculture (3 scenarios) -| Script | Problem | Models | Graph Pattern | -| ---------------------------------- | --------------------------------- | ------------------------- | -------------------------- | -| `agriculture_disease_classify.py` | Crop disease classification | GroundingDINO + CLIP | Detect β†’ ROI β†’ Classify | -| `agriculture_aerial_crop.py` | Aerial crop segmentation | Mask2Former + Depth | Segment β€– Depth β†’ Fuse | -| `agriculture_pest_detection.py` | Pest detection & area mapping | GroundingDINO + SAM | Detect β†’ Segment | +| Script | Problem | Models | Graph Pattern | +| --------------------------------- | ----------------------------- | -------------------- | ----------------------- | +| `agriculture_disease_classify.py` | Crop disease classification | GroundingDINO + CLIP | Detect > ROI > Classify | +| `agriculture_aerial_crop.py` | Aerial crop segmentation | Mask2Former + Depth | Segment β€– Depth > Fuse | +| `agriculture_pest_detection.py` | Pest detection & area mapping | GroundingDINO + SAM | Detect > Segment | ### πŸ₯ Healthcare (3 scenarios) ⚠️ **DISCLAIMER**: Research and demonstration purposes only. NOT for clinical use. -| Script | Problem | Models | Graph Pattern | -| --------------------------------- | ------------------------------------ | ----------------------------- | ------------------------------- | -| `medical_roi_segmentation.py` | ROI segmentation & measurement | GroundingDINO + SAM | Detect β†’ Segment | -| `medical_report_generation.py` | Medical image report generation | Qwen3-VL | VLM β†’ Fuse | -| `medical_pathology_triage.py` | Pathology triage workflow | DETR + CLIP + Qwen3-VL | Detect β†’ ROI β†’ Classify β†’ VLM | +| Script | Problem | Models | Graph Pattern | +| ------------------------------ | ------------------------------- | ---------------------- | ----------------------------- | +| `medical_roi_segmentation.py` | ROI segmentation & measurement | GroundingDINO + SAM | Detect > Segment | +| `medical_report_generation.py` | Medical image report generation | Qwen3-VL | VLM > Fuse | +| `medical_pathology_triage.py` | Pathology triage workflow | DETR + CLIP + Qwen3-VL | Detect > ROI > Classify > VLM | ## Usage Patterns @@ -94,7 +94,7 @@ Real-World Problem: Retailers need automated shelf monitoring... Graph Flow: - Detect β†’ Filter β†’ NMS β†’ ExtractROIs β†’ Classify β†’ Fuse + Detect > Filter > NMS > ExtractROIs > Classify > Fuse βœ“ Graph 'shelf_product_analysis' constructed with 6 nodes ``` @@ -140,7 +140,7 @@ Examples combine multiple vision tasks for comprehensive insights: ```python # VLM + Detection + Classification in parallel -stock_level_analysis() # β†’ semantic + quantitative + categorical +stock_level_analysis() # > semantic + quantitative + categorical ``` ### 3. Preset Reuse @@ -186,7 +186,7 @@ scenarios/ **Beginners**: -1. Start with `retail_shelf_analysis.py` β€” simple detect β†’ classify pattern +1. Start with `retail_shelf_analysis.py` β€” simple detect > classify pattern 2. Try `retail_product_search.py` β€” learn zero-shot detection 3. Explore `manufacturing_defect_classify.py` β€” see domain transfer diff --git a/examples/graph/scenarios/agriculture_aerial_crop.py b/examples/graph/scenarios/agriculture_aerial_crop.py index 1da5b68..7faad68 100644 --- a/examples/graph/scenarios/agriculture_aerial_crop.py +++ b/examples/graph/scenarios/agriculture_aerial_crop.py @@ -18,7 +18,7 @@ - Depth: depth-anything/Depth-Anything-V2-Small-hf (monocular depth estimation) Graph Flow: - Parallel(SegmentImage, EstimateDepth) β†’ Fuse + Parallel(SegmentImage, EstimateDepth) > Fuse Usage: python agriculture_aerial_crop.py # Mock mode @@ -59,12 +59,12 @@ def main(): print(" Precision agriculture needs crop distribution + terrain topology") print(" from aerial imagery for optimized resource planning.") print() - print("Graph: Parallel(SegmentImage, EstimateDepth) β†’ Fuse") + print("Graph: Parallel(SegmentImage, EstimateDepth) > Fuse") print("Models: Mask2Former (segmenter) + Depth Anything (depth)") print() print("Expected output structure:") - print(" result['final'].instances β†’ segmented crop regions with masks") - print(" result['final'].depth β†’ terrain depth map (H, W) array") + print(" result['final'].instances > segmented crop regions with masks") + print(" result['final'].depth > terrain depth map (H, W) array") print(" Enables: crop coverage area, terrain-aware irrigation/spraying") print() diff --git a/examples/graph/scenarios/agriculture_disease_classify.py b/examples/graph/scenarios/agriculture_disease_classify.py index 3631c29..e01c1b6 100644 --- a/examples/graph/scenarios/agriculture_disease_classify.py +++ b/examples/graph/scenarios/agriculture_disease_classify.py @@ -19,7 +19,7 @@ - Classifier: openai/clip-vit-base-patch32 (zero-shot classification) Graph Flow: - Detect("diseased leaf . pest damage . healthy leaf") β†’ Filter β†’ ExtractROIs β†’ Classify β†’ Fuse + Detect("diseased leaf . pest damage . healthy leaf") > Filter > ExtractROIs > Classify > Fuse Usage: python agriculture_disease_classify.py # Mock mode @@ -60,13 +60,13 @@ def main(): print(" Farmers need automated disease detection to prevent crop loss.") print(" Manual inspection is slow and inconsistent across large fields.") print() - print("Graph: Detect β†’ Filter β†’ ExtractROIs β†’ Classify β†’ Fuse") + print("Graph: Detect > Filter > ExtractROIs > Classify > Fuse") print("Models: GroundingDINO (detector) + CLIP (classifier)") print() print("Expected output structure:") - print(" result['final'].instances β†’ list of leaf regions with disease classifications") - print(" result['final'].rois β†’ cropped images of each detected leaf") - print(" result['final'].classifications β†’ disease type per crop") + print(" result['final'].instances > list of leaf regions with disease classifications") + print(" result['final'].rois > cropped images of each detected leaf") + print(" result['final'].classifications > disease type per crop") print() # Verify preset construction diff --git a/examples/graph/scenarios/agriculture_pest_detection.py b/examples/graph/scenarios/agriculture_pest_detection.py index 990206a..800ae60 100644 --- a/examples/graph/scenarios/agriculture_pest_detection.py +++ b/examples/graph/scenarios/agriculture_pest_detection.py @@ -18,7 +18,7 @@ - Segmenter: facebook/sam-vit-base (prompt-based segmentation) Graph Flow: - Detect("insect . pest . caterpillar . aphid . beetle") β†’ Filter β†’ PromptBoxes(SAM) β†’ RefineMask β†’ Fuse + Detect("insect . pest . caterpillar . aphid . beetle") > Filter > PromptBoxes(SAM) > RefineMask > Fuse Usage: python agriculture_pest_detection.py # Mock mode @@ -58,11 +58,11 @@ def main(): print(" Early pest detection is critical for crop protection.") print(" Need to identify locations, measure affected areas, and prioritize treatment.") print() - print("Graph: Detect β†’ Filter β†’ PromptBoxes(SAM) β†’ RefineMask β†’ Fuse") + print("Graph: Detect > Filter > PromptBoxes(SAM) > RefineMask > Fuse") print("Models: GroundingDINO (detector) + SAM (segmenter)") print() print("Expected output structure:") - print(" result['final'].instances β†’ pest detections with precise segmentation masks") + print(" result['final'].instances > pest detections with precise segmentation masks") print(" Each instance has bbox + mask for area measurement") print(" Enables: targeted pesticide application, infestation monitoring") print() diff --git a/examples/graph/scenarios/driving_distance_estimation.py b/examples/graph/scenarios/driving_distance_estimation.py index 1c28469..2da2f5a 100644 --- a/examples/graph/scenarios/driving_distance_estimation.py +++ b/examples/graph/scenarios/driving_distance_estimation.py @@ -17,7 +17,7 @@ - Depth: depth-anything/Depth-Anything-V2-Small-hf (monocular depth) Graph Flow: - Parallel(Detect, EstimateDepth) β†’ Filter(vehicle classes) β†’ Fuse + Parallel(Detect, EstimateDepth) > Filter(vehicle classes) > Fuse Usage: python driving_distance_estimation.py # Mock mode @@ -61,12 +61,12 @@ def main(): else: print("=== Autonomous Driving: Vehicle Distance Estimation (Mock) ===") print() - print("Graph: Parallel(Detect, EstimateDepth) β†’ Filter β†’ Fuse") + print("Graph: Parallel(Detect, EstimateDepth) > Filter > Fuse") print("Models: DETR (detector) + Depth Anything (depth)") print() print("Expected output structure:") - print(" result['final'].instances β†’ detected vehicles/pedestrians with bboxes") - print(" result['final'].depth β†’ depth map (H, W) numpy array") + print(" result['final'].instances > detected vehicles/pedestrians with bboxes") + print(" result['final'].depth > depth map (H, W) numpy array") print(" Distance correlation: sample depth values at bbox centers") print() diff --git a/examples/graph/scenarios/driving_obstacle_vlm.py b/examples/graph/scenarios/driving_obstacle_vlm.py index 21b53bc..97f74a2 100644 --- a/examples/graph/scenarios/driving_obstacle_vlm.py +++ b/examples/graph/scenarios/driving_obstacle_vlm.py @@ -23,7 +23,7 @@ - Depth: depth-anything/Depth-Anything-V2-Small-hf Graph Flow: - Parallel(VLMDescribe, Detect, EstimateDepth) β†’ Fuse + Parallel(VLMDescribe, Detect, EstimateDepth) > Fuse (Reuses vlm_scene_understanding preset with custom prompt) Usage: @@ -80,7 +80,7 @@ def main(): else: print("=== Autonomous Driving: Obstacle Detection with VLM Reasoning (Mock) ===") print() - print("Graph: Parallel(VLMDescribe, Detect, EstimateDepth) β†’ Fuse") + print("Graph: Parallel(VLMDescribe, Detect, EstimateDepth) > Fuse") print("Models: Qwen3-VL (reasoning) + DETR (detection) + Depth Anything") print() print("Key Design Pattern: REUSING EXISTING PRESET") @@ -90,9 +90,9 @@ def main(): print(" only the prompt differs.") print() print("Expected output structure:") - print(" result['scene'].description β†’ VLM's road hazard analysis") - print(" result['scene'].instances β†’ detected objects") - print(" result['scene'].depth β†’ depth map for spatial awareness") + print(" result['scene'].description > VLM's road hazard analysis") + print(" result['scene'].instances > detected objects") + print(" result['scene'].depth > depth map for spatial awareness") print() print("VLM Capabilities for Driving:") print(" β€’ Identify unusual obstacles (debris, animals, construction)") diff --git a/examples/graph/scenarios/driving_road_scene.py b/examples/graph/scenarios/driving_road_scene.py index 633b64a..9b9804b 100644 --- a/examples/graph/scenarios/driving_road_scene.py +++ b/examples/graph/scenarios/driving_road_scene.py @@ -24,7 +24,7 @@ - Classifier: openai/clip-vit-base-patch32 (zero-shot) Graph Flow: - Parallel(Detect, SegmentImage, EstimateDepth, Classify) β†’ Filter β†’ Fuse + Parallel(Detect, SegmentImage, EstimateDepth, Classify) > Filter > Fuse Usage: python driving_road_scene.py # Mock mode @@ -77,14 +77,14 @@ def main(): else: print("=== Autonomous Driving: Comprehensive Road Scene Analysis (Mock) ===") print() - print("Graph: Parallel(Detect, SegmentImage, EstimateDepth, Classify) β†’ Filter β†’ Fuse") + print("Graph: Parallel(Detect, SegmentImage, EstimateDepth, Classify) > Filter > Fuse") print("Models: DETR + Mask2Former + Depth Anything + CLIP") print() print("Expected output structure:") - print(" result['final'].instances β†’ detected objects") - print(" result['final'].masks β†’ panoptic segmentation (road, sidewalk, sky)") - print(" result['final'].depth β†’ depth map for spatial context") - print(" result['final'].classifications β†’ scene type classification") + print(" result['final'].instances > detected objects") + print(" result['final'].masks > panoptic segmentation (road, sidewalk, sky)") + print(" result['final'].depth > depth map for spatial context") + print(" result['final'].classifications > scene type classification") print() print("This is the most comprehensive driving preset with 4 parallel tasks.") print() diff --git a/examples/graph/scenarios/driving_traffic_tracking.py b/examples/graph/scenarios/driving_traffic_tracking.py index 8dc8595..24421c7 100644 --- a/examples/graph/scenarios/driving_traffic_tracking.py +++ b/examples/graph/scenarios/driving_traffic_tracking.py @@ -19,7 +19,7 @@ - Tracker: ByteTrackWrapper (ByteTrack) or BotSortWrapper (BotSort + GMC) Graph Flow: - Detect β†’ Filter(vehicle classes) β†’ Track β†’ Annotate β†’ Fuse + Detect > Filter(vehicle classes) > Track > Annotate > Fuse Usage: python driving_traffic_tracking.py # Mock mode (simulates 30 frames) @@ -121,7 +121,7 @@ def main(): else: print("=== Autonomous Driving: Multi-Object Traffic Tracking (Mock) ===") print() - print("Graph: Detect β†’ Filter β†’ Track β†’ Annotate β†’ Fuse") + print("Graph: Detect > Filter > Track > Annotate > Fuse") print("Models: RT-DETR (detector) + ByteTrackWrapper or BotSortWrapper (tracker)") print() print("Tracker options:") @@ -129,9 +129,9 @@ def main(): print(" BotSortWrapper β€” GMC-enabled for panning/tilting cameras") print() print("Expected output structure (per frame):") - print(" result['final'].instances β†’ detected objects with track IDs") - print(" result['final'].tracks β†’ tracking state information") - print(" result['final'].image β†’ annotated frame with track visualizations") + print(" result['final'].instances > detected objects with track IDs") + print(" result['final'].tracks > tracking state information") + print(" result['final'].image > annotated frame with track visualizations") print() print("Video Processing Pattern:") print(" 1. Create tracker instance ONCE before video loop") diff --git a/examples/graph/scenarios/manufacturing_assembly_verify.py b/examples/graph/scenarios/manufacturing_assembly_verify.py index f956c9c..9dd6e5b 100644 --- a/examples/graph/scenarios/manufacturing_assembly_verify.py +++ b/examples/graph/scenarios/manufacturing_assembly_verify.py @@ -18,7 +18,7 @@ - Detector: facebook/detr-resnet-50 (component detection) Graph Flow: - Parallel(VLMQuery, Detect) β†’ Filter β†’ Fuse + Parallel(VLMQuery, Detect) > Filter > Fuse Usage: python manufacturing_assembly_verify.py # Mock mode @@ -70,12 +70,12 @@ def main(): else: print("=== Manufacturing: Assembly Verification with VLM (Mock) ===") print() - print("Graph: Parallel(VLMQuery, Detect) β†’ Filter β†’ Fuse") + print("Graph: Parallel(VLMQuery, Detect) > Filter > Fuse") print("Models: Qwen3-VL (VLM) + DETR (detector)") print() print("Expected output structure:") - print(" result['vlm_assessment'] β†’ VLM's holistic inspection report") - print(" result['final'].instances β†’ detected components with counts") + print(" result['vlm_assessment'] > VLM's holistic inspection report") + print(" result['final'].instances > detected components with counts") print() print("Example VLM output:") print(' "Assembly appears complete. All 4 screws present and torqued correctly."') diff --git a/examples/graph/scenarios/manufacturing_component_inspect.py b/examples/graph/scenarios/manufacturing_component_inspect.py index 9d9375c..ea68731 100644 --- a/examples/graph/scenarios/manufacturing_component_inspect.py +++ b/examples/graph/scenarios/manufacturing_component_inspect.py @@ -17,7 +17,7 @@ - VLM: Qwen/Qwen3-VL-2B-Instruct (detailed per-component inspection) Graph Flow: - Detect β†’ Filter β†’ ExtractROIs β†’ VLMQuery β†’ Fuse + Detect > Filter > ExtractROIs > VLMQuery > Fuse Usage: python manufacturing_component_inspect.py # Mock mode @@ -69,12 +69,12 @@ def main(): else: print("=== Manufacturing: Per-Component Detailed Inspection (Mock) ===") print() - print("Graph: Detect β†’ Filter β†’ ExtractROIs β†’ VLMQuery β†’ Fuse") + print("Graph: Detect > Filter > ExtractROIs > VLMQuery > Fuse") print("Models: DETR (detector) + Qwen3-VL (VLM)") print() print("Expected output structure:") - print(" result['final'].instances β†’ detected components") - print(" result['final'].rois β†’ cropped images of each component") + print(" result['final'].instances > detected components") + print(" result['final'].rois > cropped images of each component") print(" Each component gets individual VLM inspection report:") print(' "Component shows minor surface wear on upper edge. No critical defects."') print(' "Excellent condition. No visible defects or contamination."') diff --git a/examples/graph/scenarios/manufacturing_defect_classify.py b/examples/graph/scenarios/manufacturing_defect_classify.py index 543ab77..74a68c9 100644 --- a/examples/graph/scenarios/manufacturing_defect_classify.py +++ b/examples/graph/scenarios/manufacturing_defect_classify.py @@ -16,7 +16,7 @@ - Classifier: openai/clip-vit-base-patch32 (zero-shot classification) Graph Flow: - Detect("scratch . crack . dent") β†’ Filter β†’ ExtractROIs β†’ Classify β†’ Fuse + Detect("scratch . crack . dent") > Filter > ExtractROIs > Classify > Fuse Usage: python manufacturing_defect_classify.py # Mock mode @@ -53,13 +53,13 @@ def main(): else: print("=== Manufacturing: Defect Detection & Classification (Mock) ===") print() - print("Graph: Detect β†’ Filter β†’ ExtractROIs β†’ Classify β†’ Fuse") + print("Graph: Detect > Filter > ExtractROIs > Classify > Fuse") print("Models: GroundingDINO (detector) + CLIP (classifier)") print() print("Expected output structure:") - print(" result['final'].instances β†’ list of defect instances with bboxes") - print(" result['final'].rois β†’ cropped images of each defect") - print(" result['final'].classifications β†’ defect type per crop") + print(" result['final'].instances > list of defect instances with bboxes") + print(" result['final'].rois > cropped images of each defect") + print(" result['final'].classifications > defect type per crop") print() # Verify preset construction diff --git a/examples/graph/scenarios/manufacturing_defect_segment.py b/examples/graph/scenarios/manufacturing_defect_segment.py index b31c638..537c79c 100644 --- a/examples/graph/scenarios/manufacturing_defect_segment.py +++ b/examples/graph/scenarios/manufacturing_defect_segment.py @@ -16,7 +16,7 @@ - Segmenter: facebook/sam-vit-base (prompt-based segmentation) Graph Flow: - Detect("scratch . crack . dent") β†’ PromptBoxes(SAM) β†’ RefineMask β†’ MaskToBox + Detect("scratch . crack . dent") > PromptBoxes(SAM) > RefineMask > MaskToBox Usage: python manufacturing_defect_segment.py # Mock mode @@ -64,11 +64,11 @@ def main(): else: print("=== Manufacturing: Defect Segmentation & Area Measurement (Mock) ===") print() - print("Graph: Detect β†’ PromptBoxes(SAM) β†’ RefineMask β†’ MaskToBox") + print("Graph: Detect > PromptBoxes(SAM) > RefineMask > MaskToBox") print("Models: GroundingDINO (detector) + SAM (segmenter)") print() print("Expected output structure:") - print(" result['final'].instances β†’ list of defect instances") + print(" result['final'].instances > list of defect instances") print(" Each instance has:") print(" - bbox: bounding box derived from segmentation mask") print(" - mask: pixel-precise segmentation mask (RLE format)") diff --git a/examples/graph/scenarios/medical_pathology_triage.py b/examples/graph/scenarios/medical_pathology_triage.py index 4c053f3..077a5ff 100644 --- a/examples/graph/scenarios/medical_pathology_triage.py +++ b/examples/graph/scenarios/medical_pathology_triage.py @@ -30,7 +30,7 @@ - VLM: Qwen/Qwen3-VL-2B (detailed analysis for flagged regions) Graph Flow: - Detect β†’ Filter β†’ ExtractROIs β†’ Classify β†’ [Conditional VLM Query] β†’ Fuse + Detect > Filter > ExtractROIs > Classify > [Conditional VLM Query] > Fuse Usage: python medical_pathology_triage.py # Mock mode @@ -129,25 +129,25 @@ def main(): print("Complex Conditional Pipeline (Example-Only Pattern):") print() print(" Flow:") - print(" 1. Detect β†’ identify regions of interest") - print(" 2. ExtractROIs β†’ crop each region") - print(" 3. Classify β†’ triage as [normal, benign, atypical, uncertain]") + print(" 1. Detect > identify regions of interest") + print(" 2. ExtractROIs > crop each region") + print(" 3. Classify > triage as [normal, benign, atypical, uncertain]") print(" 4. Conditional Logic:") print(" - If 'atypical' or 'uncertain' score > 0.3:") - print(" β†’ Flag for VLM detailed analysis (research review queue)") + print(" > Flag for VLM detailed analysis (research review queue)") print(" - Otherwise:") - print(" β†’ Mark as routine") + print(" > Mark as routine") print() - print("Graph: Detect β†’ Filter β†’ ExtractROIs β†’ Classify β†’ [Conditional] β†’ Fuse") + print("Graph: Detect > Filter > ExtractROIs > Classify > [Conditional] > Fuse") print("Models: DETR (detector) + CLIP (classifier) + Qwen3-VL (optional VLM)") print() print("This demonstrates the most complex scenario: conditional branching") print("with real-world research triage logic for prioritizing expert review.") print() print("Expected output structure:") - print(" result['final'].instances β†’ detected regions") - print(" result['final'].rois β†’ cropped region images") - print(" result['final'].classifications β†’ triage classifications") + print(" result['final'].instances > detected regions") + print(" result['final'].rois > cropped region images") + print(" result['final'].classifications > triage classifications") print(" Conditional logic identifies which regions need expert review") print() diff --git a/examples/graph/scenarios/medical_report_generation.py b/examples/graph/scenarios/medical_report_generation.py index cc11e7d..f3c39a5 100644 --- a/examples/graph/scenarios/medical_report_generation.py +++ b/examples/graph/scenarios/medical_report_generation.py @@ -24,7 +24,7 @@ - Depth: depth-anything/Depth-Anything-V2-Small-hf (depth estimation) Graph Flow: - Parallel(VLMDescribe, Detect, EstimateDepth) β†’ Filter β†’ Fuse + Parallel(VLMDescribe, Detect, EstimateDepth) > Filter > Fuse Usage: python medical_report_generation.py # Mock mode @@ -92,14 +92,14 @@ def main(): print("⚠️ DISCLAIMER: Research and demonstration purposes only.") print(" NOT for clinical diagnosis or treatment decisions.") print() - print("Graph: Parallel(VLMDescribe, Detect, EstimateDepth) β†’ Filter β†’ Fuse") + print("Graph: Parallel(VLMDescribe, Detect, EstimateDepth) > Filter > Fuse") print("Models: Qwen3-VL + DETR + Depth Anything") print("VLM prompt: Medical abnormality description") print() print("Expected output structure:") - print(" result['description'] β†’ VLM natural language description") - print(" result['final'].instances β†’ detected objects with bboxes") - print(" result['depth'] β†’ depth map for spatial understanding") + print(" result['description'] > VLM natural language description") + print(" result['final'].instances > detected objects with bboxes") + print(" result['depth'] > depth map for spatial understanding") print() print("This preset demonstrates reusing vlm_scene_understanding()") print("with domain-specific medical prompts for research applications.") diff --git a/examples/graph/scenarios/medical_roi_segmentation.py b/examples/graph/scenarios/medical_roi_segmentation.py index 0b07f8c..4c4f17e 100644 --- a/examples/graph/scenarios/medical_roi_segmentation.py +++ b/examples/graph/scenarios/medical_roi_segmentation.py @@ -22,7 +22,7 @@ - Segmenter: facebook/sam-vit-base (prompt-based segmentation) Graph Flow: - Detect("lesion . nodule . mass") β†’ Filter β†’ PromptBoxes(SAM) β†’ RefineMask β†’ Fuse + Detect("lesion . nodule . mass") > Filter > PromptBoxes(SAM) > RefineMask > Fuse Usage: python medical_roi_segmentation.py # Mock mode @@ -87,12 +87,12 @@ def main(): print("⚠️ DISCLAIMER: Research and demonstration purposes only.") print(" NOT for clinical diagnosis or treatment decisions.") print() - print("Graph: Detect β†’ Filter β†’ PromptBoxes(SAM) β†’ RefineMask β†’ Fuse") + print("Graph: Detect > Filter > PromptBoxes(SAM) > RefineMask > Fuse") print("Models: GroundingDINO (detector) + SAM (segmenter)") print("Text prompts: 'lesion . nodule . mass . abnormality'") print() print("Expected output structure:") - print(" result['final'].instances β†’ list of ROI instances") + print(" result['final'].instances > list of ROI instances") print(" Each instance has:") print(" - bbox: bounding box derived from segmentation mask") print(" - mask: pixel-precise segmentation mask (RLE format)") diff --git a/examples/graph/scenarios/retail_product_search.py b/examples/graph/scenarios/retail_product_search.py index ecc7e27..b539e8a 100644 --- a/examples/graph/scenarios/retail_product_search.py +++ b/examples/graph/scenarios/retail_product_search.py @@ -19,7 +19,7 @@ - Segmenter: facebook/sam-vit-base (segment-anything) Graph Flow: - Detect(text_prompts) β†’ Filter β†’ PromptBoxes(SAM) β†’ RefineMask β†’ Fuse + Detect(text_prompts) > Filter > PromptBoxes(SAM) > RefineMask > Fuse Use Cases: - Visual product search in inventory photos @@ -103,7 +103,7 @@ def main(): print(" for precise boundary extraction.") print() print("Graph Flow:") - print(" Detect(text_prompts) β†’ Filter β†’ PromptBoxes(SAM) β†’ RefineMask β†’ Fuse") + print(" Detect(text_prompts) > Filter > PromptBoxes(SAM) > RefineMask > Fuse") print() print("Models:") print(" - Detector: IDEA-Research/grounding-dino-tiny (zero-shot)") @@ -115,7 +115,7 @@ def main(): print(' "organic product . gluten_free label . sale tag"') print() print("Expected output structure:") - print(" result['final'].instances β†’ list of detected products") + print(" result['final'].instances > list of detected products") print(" - Each instance has: bbox, mask (segmentation), score, label_name") print(" Mask area can be used for shelf space allocation calculations") print() diff --git a/examples/graph/scenarios/retail_shelf_analysis.py b/examples/graph/scenarios/retail_shelf_analysis.py index b44bf05..2c6d9e9 100644 --- a/examples/graph/scenarios/retail_shelf_analysis.py +++ b/examples/graph/scenarios/retail_shelf_analysis.py @@ -17,7 +17,7 @@ - Classifier: openai/clip-vit-base-patch32 (zero-shot brand matching) Graph Flow: - Detect β†’ Filter β†’ NMS β†’ ExtractROIs β†’ Classify β†’ Fuse + Detect > Filter > NMS > ExtractROIs > Classify > Fuse Use Cases: - Planogram compliance verification @@ -102,17 +102,17 @@ def main(): print(" 4. CLIP classifies each crop into brand/category") print() print("Graph Flow:") - print(" Detect β†’ Filter β†’ NMS β†’ ExtractROIs β†’ Classify β†’ Fuse") + print(" Detect > Filter > NMS > ExtractROIs > Classify > Fuse") print() print("Models:") print(" - Detector: torchvision/fasterrcnn_resnet50_fpn_v2") print(" - Classifier: openai/clip-vit-base-patch32") print() print("Expected output structure:") - print(" result['final'].instances β†’ list of product instances") + print(" result['final'].instances > list of product instances") print(" - Each instance has: bbox, label_name (brand), score") - print(" result['rois'] β†’ cropped images of each product") - print(" result['classes'] β†’ classification results per crop") + print(" result['rois'] > cropped images of each product") + print(" result['classes'] > classification results per crop") print() # Verify preset construction diff --git a/examples/graph/scenarios/retail_stock_level.py b/examples/graph/scenarios/retail_stock_level.py index 112e3c4..b3120ca 100644 --- a/examples/graph/scenarios/retail_stock_level.py +++ b/examples/graph/scenarios/retail_stock_level.py @@ -24,7 +24,7 @@ - Classifier: openai/clip-vit-base-patch32 (stock level categorization) Graph Flow: - Parallel(VLMDescribe, Detect, Classify) β†’ Filter β†’ Fuse + Parallel(VLMDescribe, Detect, Classify) > Filter > Fuse Use Cases: - Automated restocking alerts @@ -141,7 +141,7 @@ def main(): print(" 3. CLIP categories stock level (fully/partially/low/empty)") print() print("Graph Flow:") - print(" Parallel(VLMDescribe, Detect, Classify) β†’ Filter β†’ Fuse") + print(" Parallel(VLMDescribe, Detect, Classify) > Filter > Fuse") print() print("Models:") print(" - VLM: Qwen/Qwen3-VL-2B-Instruct") @@ -149,9 +149,9 @@ def main(): print(" - Classifier: openai/clip-vit-base-patch32") print() print("Expected output structure:") - print(" result['final'].meta['vlm_description'] β†’ semantic assessment") - print(" result['final'].instances β†’ detected products (count)") - print(" result['final'].meta['classifications'] β†’ stock level category") + print(" result['final'].meta['vlm_description'] > semantic assessment") + print(" result['final'].instances > detected products (count)") + print(" result['final'].meta['classifications'] > stock level category") print() print("Benefits:") print(" β€’ Quantitative: Exact product count from detector") diff --git a/examples/graph/scenarios/security_crowd_monitoring.py b/examples/graph/scenarios/security_crowd_monitoring.py index a0a3c6c..13a6e1b 100644 --- a/examples/graph/scenarios/security_crowd_monitoring.py +++ b/examples/graph/scenarios/security_crowd_monitoring.py @@ -19,7 +19,7 @@ - Tracker: ByteTrackWrapper (ByteTrack) or BotSortWrapper (BotSort + GMC) Graph Flow: - Detect β†’ Filter(person) β†’ Track β†’ Annotate β†’ Fuse + Detect > Filter(person) > Track > Annotate > Fuse Usage: python security_crowd_monitoring.py # Mock mode @@ -84,7 +84,7 @@ def main(): else: print("=== Security: Crowd Monitoring with Alert System (Mock) ===") print() - print("Graph: Detect β†’ Filter(person) β†’ Track β†’ Annotate β†’ Fuse") + print("Graph: Detect > Filter(person) > Track > Annotate > Fuse") print("Models: DETR (detector) + ByteTrackWrapper or BotSortWrapper (tracker)") print() print("Tracker options:") @@ -92,7 +92,7 @@ def main(): print(" BotSortWrapper β€” GMC-enabled, better for panning/zooming cameras") print() print("Expected output structure:") - print(" result['final'].instances β†’ list of tracked person instances") + print(" result['final'].instances > list of tracked person instances") print(" Each instance has:") print(" - track_id: unique persistent ID across frames") print(" - bbox: bounding box coordinates") diff --git a/examples/graph/scenarios/security_situational_awareness.py b/examples/graph/scenarios/security_situational_awareness.py index 847e64f..6070883 100644 --- a/examples/graph/scenarios/security_situational_awareness.py +++ b/examples/graph/scenarios/security_situational_awareness.py @@ -27,7 +27,7 @@ - Depth: depth-anything/Depth-Anything-V2-Small-hf (spatial understanding) Graph Flow: - Parallel(VLMDescribe, Detect, EstimateDepth) β†’ Fuse + Parallel(VLMDescribe, Detect, EstimateDepth) > Fuse Usage: python security_situational_awareness.py # Mock mode @@ -115,7 +115,7 @@ def main(): else: print("=== Security: Comprehensive Situational Awareness (Mock) ===") print() - print("Graph: Parallel(VLMDescribe, Detect, EstimateDepth) β†’ Fuse") + print("Graph: Parallel(VLMDescribe, Detect, EstimateDepth) > Fuse") print("Models: Qwen3-VLM + DETR + Depth-Anything") print() print("Key Design Pattern: PRESET REUSABILITY") @@ -126,9 +126,9 @@ def main(): print("architecture now provides security-specific insights.") print() print("Expected output structure:") - print(" result['scene'].description β†’ VLM's security assessment") - print(" result['scene'].instances β†’ Detected objects (people, vehicles, etc.)") - print(" result['scene'].depth_map β†’ Spatial depth information") + print(" result['scene'].description > VLM's security assessment") + print(" result['scene'].instances > Detected objects (people, vehicles, etc.)") + print(" result['scene'].depth_map > Spatial depth information") print() print("Security use cases:") print(" - Perimeter security: Detect intrusions, unusual activity") diff --git a/examples/graph/scenarios/security_suspicious_object.py b/examples/graph/scenarios/security_suspicious_object.py index 52230ed..90564d1 100644 --- a/examples/graph/scenarios/security_suspicious_object.py +++ b/examples/graph/scenarios/security_suspicious_object.py @@ -24,7 +24,7 @@ - VLM: Qwen/Qwen3-VL-2B-Instruct (contextual reasoning) Graph Flow: - Detect β†’ Filter β†’ PromptBoxes(SAM) β†’ RefineMask β†’ VLMQuery β†’ Fuse + Detect > Filter > PromptBoxes(SAM) > RefineMask > VLMQuery > Fuse Usage: python security_suspicious_object.py # Mock mode @@ -78,13 +78,13 @@ def main(): else: print("=== Security: Suspicious Unattended Object Detection (Mock) ===") print() - print("Graph: Detect β†’ Filter β†’ PromptBoxes β†’ RefineMask β†’ VLMQuery β†’ Fuse") + print("Graph: Detect > Filter > PromptBoxes > RefineMask > VLMQuery > Fuse") print("Models: GroundingDINO + SAM + Qwen3-VL (3-model chain)") print() print("Expected output structure:") - print(" result['final'].instances β†’ list of detected suspicious objects") - print(" result['final'].masks β†’ precise segmentation masks for each object") - print(" result['final'].vlm_analysis β†’ VLM reasoning about each object") + print(" result['final'].instances > list of detected suspicious objects") + print(" result['final'].masks > precise segmentation masks for each object") + print(" result['final'].vlm_analysis > VLM reasoning about each object") print() print("Why 3 models?") print(" 1. GroundingDINO: Zero-shot detection via text prompts") @@ -98,7 +98,7 @@ def main(): print("Real-world deployment considerations:") print(" - Temporal analysis: Track if object remains unattended over time") print(" - Owner proximity detection: Use person detection + spatial proximity") - print(" - Alert escalation: High-risk objects β†’ immediate human review") + print(" - Alert escalation: High-risk objects > immediate human review") print() # Verify preset construction diff --git a/examples/graph/simple_pipeline.py b/examples/graph/simple_pipeline.py index 998b659..e31345d 100644 --- a/examples/graph/simple_pipeline.py +++ b/examples/graph/simple_pipeline.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Simple two-stage pipeline: Detect β†’ Filter β†’ Segment β†’ Fuse. +"""Simple two-stage pipeline: Detect > Filter > Segment > Fuse. Demonstrates the core MATA graph workflow: 1. Load provider models (detector + segmenter) @@ -22,7 +22,7 @@ # Setup paths IMAGE_DIR = Path(__file__).parent.parent / "images" IMAGE_1 = IMAGE_DIR / "000000039769.jpg" -IMAGE_2 = IMAGE_DIR / "hanvin-cheong-tuR2XRPdtYI-unsplash.jpg" +IMAGE_2 = IMAGE_DIR / "street_001.jpg" # --------------------------------------------------------------------------- # Mock providers for standalone demo (no model downloads required) @@ -127,12 +127,19 @@ def main(): # Access results via channel names print(f"Result type: {type(result).__name__}") print(f"Channels: {list(result.channels.keys())}") - if result.has_channel("detections"): - dets = result.get_channel("detections") - print(f"Detections: {len(dets.instances)} objects") - for inst in dets.instances: - print(f" - {inst.label_name}: score={inst.score:.2f}, bbox={inst.bbox}") - + if result.has_channel("final"): + final = result.get_channel("final") + if final.has_channel("detections"): + dets = final.get_channel("detections") + print(f"Final detections: {len(dets.instances)} objects") + for inst in dets.instances: + print(f" - {inst.label_name}: score={inst.score:.2f}, bbox={inst.bbox}") + if final.has_channel("masks"): + masks = final.get_channel("masks") + print(f"Final masks: {len(masks.instances)} objects") + for inst in masks.instances: + print(f" - {inst.label_name}: mask shape={inst.mask}") + # ----------------------------------------------------------------------- # Option B: Build a Graph with the fluent builder API # ----------------------------------------------------------------------- @@ -157,6 +164,25 @@ def main(): print(f"Graph '{graph.name}' completed") print(f"Channels: {list(result_b.channels.keys())}") + # ----------------------------------------------------------------------- + # Option C: Build a Graph with the fluent builder API (hidden: not documented yet :D) + # ----------------------------------------------------------------------- + print("\n=== Option C: Graph builder ===") + + result_c = ( + Graph("detect_and_segment") + .then(Detect(using="detector", out="dets", text_prompts="cat . dog . person")) + .then(Filter(src="dets", score_gt=0.3, out="filtered")) + .then(PromptBoxes(using="segmenter", dets_src="filtered", out="masks")) + .then(Fuse(detections="filtered", masks="masks", out="final")) + ).run( + image=str(IMAGE_1), + providers=providers, + ) + + print(f"Graph '{graph.name}' completed") + print(f"Channels: {list(result_c.channels.keys())}") + print("\nβœ“ Simple pipeline example complete!") diff --git a/examples/graph/video_tracking.py b/examples/graph/video_tracking.py index 7cc811e..80423fb 100644 --- a/examples/graph/video_tracking.py +++ b/examples/graph/video_tracking.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Video processing with object tracking: Detect β†’ Track across frames. +"""Video processing with object tracking: Detect > Track across frames. Demonstrates: 1. Building a detection + tracking graph @@ -11,6 +11,7 @@ Usage: python examples/graph/video_tracking.py python examples/graph/video_tracking.py --real + python examples/graph/video_tracking.py --video_path /path/to/video.mp4 Note: Requires OpenCV for video file processing: @@ -21,7 +22,6 @@ import os import sys -import tempfile # --------------------------------------------------------------------------- # Mock providers @@ -32,6 +32,7 @@ def create_mock_providers(): from unittest.mock import Mock from mata.core.types import Instance, VisionResult + from mata.nodes.track import SimpleIOUTracker # Simulate a detector that returns slightly different bboxes each frame # (mimicking real object movement) @@ -59,39 +60,26 @@ def mock_predict(image, **kwargs): mock_detector = Mock() mock_detector.predict = mock_predict - # Simple mock tracker (Track node uses its own internal tracker - # via ByteTrackWrapper / SimpleIOUTracker, so we use "tracker" as provider name) - mock_tracker = Mock() + # Use SimpleIOUTracker (built-in) as the tracker β€” returns real Tracks artifacts + tracker = SimpleIOUTracker() - return {"detector": mock_detector, "tracker": mock_tracker} + return { + "detect": {"detector": mock_detector}, + "track": {"tracker": tracker}, + } -def create_mock_video(num_frames: int = 30, fps: float = 30.0) -> str: - """Create a short mock video file for demo purposes. +DEFAULT_VIDEO_PATH = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "videos", "cup.mp4" +) - Returns path to temporary video file. - """ - try: - import cv2 - except ImportError: - return "" - import numpy as np - - path = os.path.join(tempfile.gettempdir(), "mata_demo_video.avi") - fourcc = cv2.VideoWriter_fourcc(*"MJPG") - writer = cv2.VideoWriter(path, fourcc, fps, (640, 480)) - - for i in range(num_frames): - # Simple frame with a moving rectangle (simulating object motion) - frame = np.zeros((480, 640, 3), dtype=np.uint8) - x = 50 + i * 10 - cv2.rectangle(frame, (x, 100), (x + 100, 300), (0, 255, 0), -1) - cv2.rectangle(frame, (400, 150), (500, 250), (255, 0, 0), -1) - writer.write(frame) - - writer.release() - return path +def get_video_path() -> str: + """Return video path from --video_path arg or the default cup.mp4.""" + for i, arg in enumerate(sys.argv): + if arg == "--video_path" and i + 1 < len(sys.argv): + return sys.argv[i + 1] + return DEFAULT_VIDEO_PATH # --------------------------------------------------------------------------- @@ -112,13 +100,17 @@ def main(): use_real = "--real" in sys.argv if use_real: import mata + from mata.nodes.track import SimpleIOUTracker print("Loading real models...") + detector = mata.load("detect", "PekingU/rtdetr_r50vd") + # VideoProcessor needs nested {capability: {name: adapter}} format providers = { - "detector": mata.load("detect", "PekingU/rtdetr_v2_r18vd"), - "tracker": "simple_iou", # Track node has built-in tracker + "detect": {"detector": detector}, + "track": {"tracker": SimpleIOUTracker()}, } else: print("Running with mock providers") + # VideoProcessor needs nested {capability: {name: adapter}} format providers = create_mock_providers() # ----------------------------------------------------------------------- @@ -140,8 +132,8 @@ def main(): # Every 3rd frame β€” good for offline video processing policy_every3 = FramePolicyEveryN(n=3) print(f"EveryN(3): frame 0={policy_every3.should_process(0)}, " - f"frame 1={policy_every3.should_process(1)}, " - f"frame 3={policy_every3.should_process(3)}") + f"frame 1={policy_every3.should_process(1)}, " + f"frame 3={policy_every3.should_process(3)}") # Latest frame only β€” good for real-time (RTSP/webcam) policy_latest = FramePolicyLatest() @@ -156,14 +148,20 @@ def main(): # ----------------------------------------------------------------------- print("\n=== Video File Processing ===") - # Create a mock video for demo - video_path = create_mock_video(num_frames=15, fps=15.0) - if not video_path: - print("OpenCV not installed β€” skipping video demo") - print("Install with: pip install opencv-python") + video_path = get_video_path() + if not os.path.exists(video_path): + print(f"Video not found: {video_path}") + print("Provide a path with: --video_path /path/to/video.mp4") else: - # Compile the graph for VideoProcessor - compiled = graph.compile(providers=providers) + print(f"Video: {video_path}") + # providers is nested {capability: {name: adapter}} for VideoProcessor/ExecutionContext. + # graph.compile() validator needs flat {name: adapter}, so flatten for it. + flat_providers = { + name: prov + for cap_dict in providers.values() + for name, prov in cap_dict.items() + } + compiled = graph.compile(providers=flat_providers) # Process with EveryN policy (every 3rd frame) processor = VideoProcessor( @@ -177,7 +175,7 @@ def main(): max_frames=15, # Limit to 15 frames for demo ) - print(f"Processed {len(results)} frames (every 3rd of 15)") + print(f"Processed {len(results)} frames (every 3rd, up to 15)") for i, frame_result in enumerate(results): channels = list(frame_result.channels.keys()) print(f" Frame {i}: channels={channels}") @@ -186,9 +184,6 @@ def main(): tracks = frame_result.get_channel("tracks") print(f" Active tracks: {len(tracks.tracks) if hasattr(tracks, 'tracks') else 'N/A'}") - # Clean up temp video - os.unlink(video_path) - # ----------------------------------------------------------------------- # Example 3: Using the detect_and_track preset # ----------------------------------------------------------------------- @@ -197,9 +192,9 @@ def main(): preset_graph = detect_and_track( detection_threshold=0.5, - track_thresh=0.4, + track_threshold=0.4, track_buffer=30, - match_thresh=0.8, + match_threshold=0.8, ) print(f"Preset graph: {preset_graph.name}") print(f"Nodes: {len(preset_graph._nodes)}") @@ -236,7 +231,7 @@ def on_frame_result(result, frame_num): # From another thread: stop_event.set() to stop """) - print("βœ“ Video tracking example complete!") + print("Video tracking example complete!") if __name__ == "__main__": diff --git a/examples/graph/vlm_workflows.py b/examples/graph/vlm_workflows.py index 1e4d8e3..dd4e24a 100644 --- a/examples/graph/vlm_workflows.py +++ b/examples/graph/vlm_workflows.py @@ -2,10 +2,10 @@ """VLM (Vision-Language Model) graph workflows. Demonstrates: -1. VLM grounded detection: VLM semantic β†’ GroundingDINO spatial β†’ promoted instances +1. VLM grounded detection: VLM semantic > GroundingDINO spatial > promoted instances 2. VLM scene understanding: parallel VLM + detection + depth 3. Multi-image comparison with VLMQuery -4. Entity β†’ Instance promotion with PromoteEntities +4. Entity > Instance promotion with PromoteEntities 5. Using VLM presets Usage: diff --git a/examples/segment/grounding_sam_pipeline.py b/examples/segment/grounding_sam_pipeline.py index d0c47cb..1d5d557 100644 --- a/examples/segment/grounding_sam_pipeline.py +++ b/examples/segment/grounding_sam_pipeline.py @@ -1,11 +1,11 @@ """GroundingDINO + SAM Pipeline Examples β€” MATA Framework -Demonstrates text-prompt-based instance segmentation using the GroundingDINOβ†’SAM +Demonstrates text-prompt-based instance segmentation using the GroundingDINO>SAM pipeline. Combines zero-shot object detection with zero-shot instance segmentation for precise masks. Pipeline Flow: - Text Prompts β†’ GroundingDINO (bboxes) β†’ SAM (masks) β†’ VisionResult + Text Prompts > GroundingDINO (bboxes) > SAM (masks) > VisionResult Run: python examples/segment/grounding_sam_pipeline.py Requirements: pip install mata transformers pillow numpy @@ -84,12 +84,12 @@ def visualize_instances(image, result, output_path): def example_basic_pipeline(): - """Example 1: Basic GroundingDINOβ†’SAM pipeline with text prompts.""" + """Example 1: Basic GroundingDINO>SAM pipeline with text prompts.""" print("\n" + "=" * 70) - print("Example 1: Basic Pipeline - Text β†’ BBox β†’ Mask") + print("Example 1: Basic Pipeline - Text > BBox > Mask") print("=" * 70) - print("\n[1/4] Loading GroundingDINOβ†’SAM pipeline...") + print("\n[1/4] Loading GroundingDINO>SAM pipeline...") pipeline = mata.load( "pipeline", detector_model_id="IDEA-Research/grounding-dino-tiny", diff --git a/examples/segment/sam_segment.py b/examples/segment/sam_segment.py index 4f05c95..6273c52 100644 --- a/examples/segment/sam_segment.py +++ b/examples/segment/sam_segment.py @@ -62,7 +62,7 @@ "segment", str(IMG_PATH), model=SAM3_MODEL, - text_prompts="cat", + text_prompts=["cat"], box_prompts=[(0, 0, 100, 100)], # Exclude top-left region box_labels=[0], # 0 = negative box (exclude) threshold=0.5, @@ -71,8 +71,10 @@ print(f"Found {len(result_text_refined.masks)} cats (excluded top-left region)") print("Use case: Remove false positives by excluding specific areas") -except Exception: - print("SAM3 not available (skipping)") +except Exception as e: + print(f"SAM3 not available: {e}") + print("Install with: pip install -U transformers>=4.46.0") + print("Continuing with original SAM examples (visual prompts)...") # ================================================================= # Example 1: Basic Point Prompt (Foreground Click) diff --git a/examples/tools/save_results.py b/examples/tools/save_results.py index 97a99f6..0467a8c 100644 --- a/examples/tools/save_results.py +++ b/examples/tools/save_results.py @@ -239,7 +239,7 @@ def main(): print("DEMONSTRATION COMPLETE") print("=" * 70) print("\nKey takeaways:") - print(" β€’ Extension auto-detection: .json β†’ JSON, .csv β†’ CSV, .png β†’ image") + print(" β€’ Extension auto-detection: .json > JSON, .csv > CSV, .png > image") print(" β€’ Image path auto-stored when using file paths") print(" β€’ PIL/numpy inputs require explicit image for overlay/crops") print(" β€’ Customize overlays with show_boxes/show_labels/show_scores/alpha") diff --git a/examples/track/stream_tracking.py b/examples/track/stream_tracking.py index 6ba3823..348b550 100644 --- a/examples/track/stream_tracking.py +++ b/examples/track/stream_tracking.py @@ -174,7 +174,7 @@ def main(argv: list[str] | None = None) -> None: print(f" Tracker: {args.tracker} conf={args.conf}") print() - # stream=True β†’ returns a generator, never accumulates full result list + # stream=True > returns a generator, never accumulates full result list frame_gen = mata.track( source, model=args.model, diff --git a/examples/validation.py b/examples/validation.py index 14be0ff..7c4d439 100644 --- a/examples/validation.py +++ b/examples/validation.py @@ -5,7 +5,7 @@ 2. Segmentation (COCO) 3. Classification (ImageNet) 4. Depth (DIODE / NYU) - 5. Standalone (pre-run predictions β†’ metrics without re-running inference) + 5. Standalone (pre-run predictions > metrics without re-running inference) Dataset path setup ------------------ @@ -68,7 +68,7 @@ def _save_metrics_json(metrics: object, save_dir: str, filename: str = "metrics. out_dir.mkdir(parents=True, exist_ok=True) out_path = out_dir / filename out_path.write_text(metrics.to_json(), encoding="utf-8") - print(f"Metrics saved β†’ {out_path}") + print(f"Metrics saved to {out_path}") # --------------------------------------------------------------------------- diff --git a/examples/vlm/ocr.py b/examples/vlm/ocr.py index 27296c9..e3feb69 100644 --- a/examples/vlm/ocr.py +++ b/examples/vlm/ocr.py @@ -125,7 +125,7 @@ def demo_load_once(): if not _check_image(img_path): continue result = adapter.predict(img_path) - print(f" [{img_path.name}] β†’ {result.full_text[:80]!r}") + print(f" [{img_path.name}] > {result.full_text[:80]!r}") # === Section 7: Export Results === diff --git a/src/mata/adapters/huggingface_sam_adapter.py b/src/mata/adapters/huggingface_sam_adapter.py index aff0685..1acab99 100644 --- a/src/mata/adapters/huggingface_sam_adapter.py +++ b/src/mata/adapters/huggingface_sam_adapter.py @@ -453,7 +453,8 @@ def predict( # SAM3 uses different format: List[List[List[float]]] for batching input_boxes = [[list(box) for box in box_prompts]] # SAM3 uses box labels (1=positive, 0=negative) - input_boxes_labels = [[box_labels or ([1] * len(box_prompts))]] + # Format: List[List[int]] = [image_batch, [per_box_labels]] + input_boxes_labels = [box_labels or ([1] * len(box_prompts))] else: # Original SAM: List[List[List[float]]] - batch, num_boxes, 4 # All boxes in a single batch diff --git a/src/mata/core/tool_registry.py b/src/mata/core/tool_registry.py index 2e62823..5273515 100644 --- a/src/mata/core/tool_registry.py +++ b/src/mata/core/tool_registry.py @@ -151,7 +151,7 @@ def __init__(self, ctx: ExecutionContext, tool_names: list[str]): # Provider-based tool - search all capabilities capability, provider = self._resolve_provider(tool_name) self._tool_map[tool_name] = (capability, provider) - self._schemas[tool_name] = self._schema_for_capability(capability, tool_name) + self._schemas[tool_name] = self._schema_for_capability(capability, tool_name, provider) logger.debug(f"Registered provider tool: {tool_name} (capability: {capability})") def _resolve_provider(self, tool_name: str) -> tuple[str, Any]: @@ -196,18 +196,38 @@ def _resolve_provider(self, tool_name: str) -> tuple[str, Any]: f"Available built-in tools: {builtin_str}" ) - def _schema_for_capability(self, capability: str, tool_name: str) -> ToolSchema: - """Generate ToolSchema for a capability. + def _is_zero_shot_provider(self, provider: Any) -> bool: + """Return True if the provider wraps a zero-shot adapter. - Uses the default task schemas from TASK_SCHEMA_DEFAULTS, but customizes - the name to match the provider name (e.g., "detr" instead of "detect"). + Unwraps one level of wrapper (e.g. DetectorWrapper) then checks the + class name for the 'ZeroShot' marker used by all MATA zero-shot adapters. + + Args: + provider: Provider instance (may be a wrapper or raw adapter) + + Returns: + True if the underlying adapter is a zero-shot model + """ + # Unwrap through a single wrapper layer (DetectorWrapper, ClassifierWrapper, etc.) + adapter = getattr(provider, "adapter", provider) + return "ZeroShot" in type(adapter).__name__ + + def _schema_for_capability(self, capability: str, tool_name: str, provider: Any) -> ToolSchema: + """Generate ToolSchema for a capability, tailored to the actual provider. + + Uses the default task schemas from TASK_SCHEMA_DEFAULTS as a base, but + customizes the name and β€” for zero-shot providers β€” upgrades + ``text_prompts`` to ``required=True`` so the VLM's system prompt + correctly instructs the model to always supply the parameter. Args: capability: Task capability ("detect", "classify", "segment", "depth") - tool_name: Provider name to use in schema + tool_name: Provider name to use in schema (e.g. "detector", "detr") + provider: The resolved provider instance (wrapper or raw adapter) Returns: - ToolSchema for this capability + ToolSchema for this capability, with text_prompts required when the + provider is a zero-shot model. """ if capability not in TASK_SCHEMA_DEFAULTS: # For unknown capabilities (e.g., "vlm"), create a minimal schema @@ -221,11 +241,34 @@ def _schema_for_capability(self, capability: str, tool_name: str) -> ToolSchema: # Clone the default schema but use the provider name default = TASK_SCHEMA_DEFAULTS[capability] + + # For zero-shot providers, upgrade text_prompts to required so the VLM + # knows it must always supply the classes it wants to detect/classify. + params = default.parameters + if self._is_zero_shot_provider(provider): + params = [ + ToolParameter( + p.name, + p.type, + ( + ( + "Object classes to detect, dot-separated (e.g. 'cat . dog . person'). " + "REQUIRED β€” this is a zero-shot model and cannot run without class names." + ) + if p.name == "text_prompts" + else p.description + ), + required=True if p.name == "text_prompts" else p.required, + default=None if p.name == "text_prompts" else p.default, + ) + for p in default.parameters + ] + return ToolSchema( name=tool_name, # Use provider name, not capability description=default.description, task=default.task, - parameters=default.parameters, + parameters=params, builtin=False, ) diff --git a/tests/test_sam_adapter.py b/tests/test_sam_adapter.py index 0df0661..20c1361 100644 --- a/tests/test_sam_adapter.py +++ b/tests/test_sam_adapter.py @@ -592,7 +592,7 @@ def test_sam3_text_with_negative_boxes(mock_transformers, mock_pycocotools): assert call_kwargs["text"] == "handle" assert "input_boxes" in call_kwargs assert "input_boxes_labels" in call_kwargs - assert call_kwargs["input_boxes_labels"] == [[[0]]] # Negative + assert call_kwargs["input_boxes_labels"] == [[0]] # Negative assert isinstance(result, VisionResult) assert len(result.masks) == 1 diff --git a/tests/test_tool_registry.py b/tests/test_tool_registry.py index 3e5b116..d69a06c 100644 --- a/tests/test_tool_registry.py +++ b/tests/test_tool_registry.py @@ -475,6 +475,92 @@ def test_schema_for_capability_uses_provider_name(): assert schema.task == "detect" +def test_schema_marks_text_prompts_required_for_zeroshot_provider(): + """Zero-shot provider: text_prompts must be required=True in generated schema. + + When a zero-shot detector (class name contains 'ZeroShot') is registered, + ToolRegistry must upgrade text_prompts to required so the VLM system prompt + tells the agent to always supply object class names. + """ + + # Build a mock provider whose underlying adapter class name contains 'ZeroShot' + class MockZeroShotDetectAdapter: + pass + + mock_adapter = MockZeroShotDetectAdapter() + mock_provider = Mock() + mock_provider.adapter = mock_adapter # Simulates DetectorWrapper.adapter + + ctx = ExecutionContext(providers={"detect": {"grounding_dino": mock_provider}}) + registry = ToolRegistry(ctx, ["grounding_dino"]) + schema = registry.get_schema("grounding_dino") + + text_prompts_param = next((p for p in schema.parameters if p.name == "text_prompts"), None) + assert text_prompts_param is not None + assert text_prompts_param.required is True, ( + "text_prompts must be required=True for zero-shot providers so the VLM " + "knows it must always supply class names" + ) + assert "REQUIRED" in text_prompts_param.description + + +def test_schema_marks_text_prompts_required_for_unwrapped_zeroshot_adapter(): + """Zero-shot adapter without wrapper: text_prompts still required=True.""" + + class HuggingFaceZeroShotDetectAdapter: + pass + + mock_provider = HuggingFaceZeroShotDetectAdapter() # No wrapper β€” raw adapter + ctx = ExecutionContext(providers={"detect": {"zs_detector": mock_provider}}) + registry = ToolRegistry(ctx, ["zs_detector"]) + schema = registry.get_schema("zs_detector") + + text_prompts_param = next(p for p in schema.parameters if p.name == "text_prompts") + assert text_prompts_param.required is True + + +def test_schema_keeps_text_prompts_optional_for_standard_provider(): + """Standard (supervised) provider: text_prompts stays optional in schema. + + Non-zero-shot detectors (DETR, YOLO, RT-DETR, etc.) do not use text prompts, + so the parameter must remain optional in the schema. + """ + mock_provider = Mock() # Class name is 'Mock', no 'ZeroShot' + ctx = ExecutionContext(providers={"detect": {"detr": mock_provider}}) + registry = ToolRegistry(ctx, ["detr"]) + schema = registry.get_schema("detr") + + text_prompts_param = next((p for p in schema.parameters if p.name == "text_prompts"), None) + assert text_prompts_param is not None + assert text_prompts_param.required is False, "text_prompts must remain optional for supervised/standard detectors" + + +def test_is_zero_shot_provider_detects_via_wrapper(): + """_is_zero_shot_provider() unwraps through wrapper.adapter.""" + + class HuggingFaceZeroShotDetectAdapter: + pass + + class DetectorWrapper: + def __init__(self, adapter): + self.adapter = adapter + + wrapper = DetectorWrapper(HuggingFaceZeroShotDetectAdapter()) + ctx = ExecutionContext(providers={"detect": {"zs": wrapper}}) + registry = ToolRegistry(ctx, ["zs"]) + + assert registry._is_zero_shot_provider(wrapper) is True + + +def test_is_zero_shot_provider_false_for_standard(): + """_is_zero_shot_provider() returns False for non-zero-shot providers.""" + mock_provider = Mock() # 'Mock' has no 'ZeroShot' in name + ctx = ExecutionContext(providers={"detect": {"detr": mock_provider}}) + registry = ToolRegistry(ctx, ["detr"]) + + assert registry._is_zero_shot_provider(mock_provider) is False + + # ============================================================================ # BUILTIN_SCHEMAS Tests # ============================================================================