diff --git a/.github/workflows/check.yaml b/.github/workflows/check.yaml
index 12ae0c9..68b4c7c 100644
--- a/.github/workflows/check.yaml
+++ b/.github/workflows/check.yaml
@@ -15,13 +15,12 @@ jobs:
- name: Checkout Repository
uses: actions/checkout@v4
- - name: Install Poetry
+ - name: Install uv
run: |
- curl -sSL https://install.python-poetry.org | python3 -
- echo "$HOME/.local/bin" >> $GITHUB_PATH
+ curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Install Dependencies
- run: poetry install
+ run: uv sync && uv sync --group dev
- name: Ensure check.sh exists and is executable
run: |
diff --git a/README.md b/README.md
index d55ea8d..8c7b98e 100644
--- a/README.md
+++ b/README.md
@@ -1,85 +1,305 @@
-# GRAID: Generating Reasoning questions from Analysis of Images via Discriminative artificial intelligence
+## GRAID: Generating Reasoning questions from Analysis of Images via Discriminative artificial intelligence
[Design Doc](https://docs.google.com/document/d/1zgb1odK3zfwLg2zKts2eC1uQcQfUd6q_kKeMzd1q-m4/edit?tab=t.0)
## Quick Start
### Installation
+0. Install uv (skip if you already have it): `curl -LsSf https://astral.sh/uv/install.sh | sh` (or see [uv installation guide](https://docs.astral.sh/uv/getting-started/installation/))
1. Create a virtual environment: `uv venv`
2. Activate it: `source .venv/bin/activate` (or use direnv with the provided .envrc)
3. Install dependencies: `uv sync`
4. Install all backends: `uv run install_all`
-### Using GRAID CLI
+### HuggingFace Dataset Generation
-**Interactive Mode (Recommended):**
+**Generate high-quality VQA datasets for modern ML workflows:**
```bash
-# Using conda environment
-/work/ke/miniconda3/envs/scenic_reason/bin/python scenic_reasoning/src/scenic_reasoning/graid_cli.py generate
+# Interactive mode with step-by-step guidance
+graid generate-dataset
-# Using uv (after installation)
-uv run graid generate
+# Or via uv run (equivalent; useful when the virtual environment is not activated)
+uv run graid generate-dataset
```
-**Non-Interactive Mode:**
+**Key Features:**
+- **Object Filtering**: Smart allowable sets for focused object detection
+- **Multi-Model Ensemble**: Weighted Boxes Fusion (WBF) for improved accuracy
+- **Flexible Configuration**: JSON configs for reproducible experiments
+- **HuggingFace Hub Integration**: Direct upload to share datasets
+- **PIL Image Support**: Ready for modern vision-language models
+- **Rich Metadata**: Comprehensive dataset documentation
+
+**Quick Examples:**
```bash
-# Generate ground truth database
-uv run graid generate --dataset bdd --split val --interactive false
+# Generate with specific object types (autonomous driving focus)
+uv run graid generate-dataset --allowable-set "person,car,truck,bicycle,traffic light"
+
+# Multi-model ensemble for enhanced accuracy
+uv run graid generate-dataset --config examples/wbf_ensemble.json
+
+# Upload directly to HuggingFace Hub
+uv run graid generate-dataset --upload-to-hub --hub-repo-id "your-org/dataset-name"
-# Use pre-configured model
-uv run graid generate --dataset nuimage --split train --backend ultralytics --model yolov8x --conf 0.3 --interactive false
+# List all valid COCO objects
+uv run graid generate-dataset --list-objects
```
-**Available Commands:**
-```bash
-uv run graid --help # Show help
-uv run graid list-models # List available models
-uv run graid info # Show project information
+### Configuration-Driven Workflows
+
+**Create reusable configurations for systematic experiments:**
+
+**Basic Configuration:**
+```json
+{
+ "dataset_name": "bdd",
+ "split": "val",
+ "models": [
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "confidence_threshold": 0.7
+ },
+ {
+ "backend": "mmdetection",
+ "model_name": "co_detr",
+ "confidence_threshold": 0.6
+ }
+ ],
+ "use_wbf": true,
+ "wbf_config": {
+ "iou_threshold": 0.6,
+ "model_weights": [1.0, 1.2]
+ },
+ "allowable_set": ["person", "car", "truck", "bus", "motorcycle", "bicycle"],
+ "confidence_threshold": 0.5,
+ "batch_size": 4
+}
+```
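+
+If you prefer to stay in Python, the same JSON file can be loaded programmatically. The snippet below is a minimal sketch that assumes the `load_config_from_file` helper from `graid.data.config_support` and a hypothetical config path:
+
+```python
+# Sketch: load a GRAID JSON config and inspect it before launching a run.
+from graid.data.config_support import load_config_from_file
+
+config = load_config_from_file("configs/basic_bdd.json")  # hypothetical path
+print(config.dataset_name, config.split)  # e.g. "bdd", "val"
+print(config.allowable_set)               # filtered object classes, if any
+```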
+
+**Advanced Configuration with Custom Questions and Transforms:**
+```json
+{
+ "dataset_name": "bdd",
+ "split": "val",
+ "models": [
+ {
+ "backend": "ultralytics",
+ "model_name": "yolov8x.pt",
+ "confidence_threshold": 0.6
+ }
+ ],
+ "use_wbf": false,
+ "allowable_set": ["person", "car", "bicycle", "motorcycle", "traffic light"],
+ "confidence_threshold": 0.5,
+ "batch_size": 2,
+
+ "questions": [
+ {
+ "name": "HowMany",
+ "params": {}
+ },
+ {
+ "name": "Quadrants",
+ "params": {
+ "N": 3,
+ "M": 3
+ }
+ },
+ {
+ "name": "WidthVsHeight",
+ "params": {
+ "threshold": 0.4
+ }
+ },
+ {
+ "name": "LargestAppearance",
+ "params": {
+ "threshold": 0.35
+ }
+ },
+ {
+ "name": "MostClusteredObjects",
+ "params": {
+ "threshold": 80
+ }
+ }
+ ],
+
+ "transforms": {
+ "type": "yolo_bdd",
+ "new_shape": [640, 640]
+ },
+
+ "save_path": "./datasets/custom_bdd_vqa",
+ "upload_to_hub": true,
+ "hub_repo_id": "your-org/bdd-reasoning-dataset",
+ "hub_private": false
+}
+```
+
+**Custom Model Configuration:**
+```json
+{
+ "dataset_name": "custom",
+ "split": "train",
+ "models": [
+ {
+ "backend": "detectron",
+ "model_name": "custom_retinanet",
+ "custom_config": {
+ "config": "path/to/config.yaml",
+ "weights": "path/to/model.pth"
+ }
+ },
+ {
+ "backend": "ultralytics",
+ "model_name": "custom_yolo",
+ "custom_config": {
+ "model_path": "path/to/custom_yolo.pt"
+ }
+ }
+ ],
+ "transforms": {
+ "type": "yolo_bdd",
+ "new_shape": [832, 832]
+ },
+ "questions": [
+ {
+ "name": "IsObjectCentered",
+ "params": {}
+ },
+ {
+ "name": "LeftOf",
+ "params": {}
+ },
+ {
+ "name": "RightOf",
+ "params": {}
+ }
+ ]
+}
+```
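+
+For reference, here is a rough Python sketch of what the `"ultralytics"` entry with a `custom_config` above boils down to, using the `Yolo` wrapper that appears elsewhere in this repo (GRAID's actual handling of `custom_config` may differ):
+
+```python
+# Sketch: constructing a custom Ultralytics model directly (assumed equivalent).
+from graid.models.Ultralytics import Yolo
+
+model = Yolo(model="path/to/custom_yolo.pt")  # same weights path as in the JSON above
+```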
+
+### Custom Dataset Support
+
+**Bring Your Own Data**: GRAID supports any PyTorch-compatible dataset:
+
+```python
+from graid.data.generate_dataset import generate_dataset
+from torch.utils.data import Dataset
+
+class CustomDataset(Dataset):
+ """Your custom dataset implementation"""
+ def __getitem__(self, idx):
+ # Return: (image_tensor, optional_annotations, metadata)
+ # Annotations are only needed for mAP/mAR evaluation
+ # For VQA generation, only images are required
+ pass
+
+# Generate HuggingFace dataset from your data
+dataset = generate_dataset(
+ dataset_name="custom",
+ split="train",
+ models=your_models,
+ allowable_set=["person", "vehicle"],
+ save_path="./datasets/custom_vqa"
+)
```
-## Status
+**Key Point**: Custom datasets only require images for VQA generation. Annotations are optional and only needed if you want to evaluate model performance with mAP/mAR metrics.
+
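+A complete custom dataset also needs a `__len__` method so PyTorch can iterate it. Once generation finishes, the saved output can be inspected like any other HuggingFace dataset; the snippet below is a sketch that assumes the result was written in `save_to_disk` format at the `save_path` used above:
+
+```python
+# Sketch: inspect a generated GRAID dataset (assumes a save_to_disk layout).
+from datasets import load_from_disk
+
+ds = load_from_disk("./datasets/custom_vqa")
+print(ds)                     # splits and row counts
+first_split = next(iter(ds))  # e.g. "train"
+print(ds[first_split][0])     # one generated question/answer record
+```
+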
+## Advanced Features
+
+### **Multi-Model Ensemble with WBF**
+Combine predictions from multiple models using Weighted Boxes Fusion for enhanced detection accuracy (see the sketch after this list):
+- Improved precision through model consensus
+- Configurable fusion parameters and model weights
+- Supports mixed backends (Detectron2 + MMDetection + Ultralytics)
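+
+For intuition, here is a minimal standalone illustration of weighted boxes fusion using the `ensemble-boxes` package; GRAID wires WBF up internally, so this is only a sketch, and the boxes, scores, and weights below are made-up values mirroring the `wbf_config` fields shown earlier:
+
+```python
+# Sketch: fuse one overlapping detection from two models with WBF.
+from ensemble_boxes import weighted_boxes_fusion
+
+# Per-model lists of normalized [x1, y1, x2, y2] boxes, scores, and class labels.
+boxes = [[[0.10, 0.10, 0.50, 0.50]], [[0.12, 0.11, 0.52, 0.49]]]
+scores = [[0.90], [0.75]]
+labels = [[0], [0]]
+
+fused_boxes, fused_scores, fused_labels = weighted_boxes_fusion(
+    boxes, scores, labels, weights=[1.0, 1.2], iou_thr=0.6
+)
+print(fused_boxes, fused_scores)  # a single consensus box with a blended score
+```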
+
+### **Intelligent Object Filtering**
+Focus datasets on specific object categories:
+- **Common presets**: Autonomous driving, indoor scenes, animals
+- **Interactive selection**: Visual picker from 80 COCO categories
+- **Manual specification**: Comma-separated object lists
+- **Validation**: Automatic checking against COCO standard
+
+### **Production-Ready Outputs**
+Generated datasets include:
+- **PIL Images**: Direct compatibility with vision-language models
+- **Rich Annotations**: Bounding boxes, confidence scores, object classes
+- **Structured QA Pairs**: Question templates with precise answers
+- **Comprehensive Metadata**: Model info, generation parameters, statistics
+
+## Supported Models & Datasets
### Backends
-| | Ultralytics | Detectron | MMDetection |
-|-----------------------|-------------|-----------|-------------|
-| Object Detection      | ✅ | ✅ | ✅ |
-| Instance Segmentation | ✅ | ✅ | ✅ |
+| | Detectron2 | MMDetection | Ultralytics |
+|-----------------------|-------------|-------------|-------------|
+| Object Detection      | ✅ | ✅ | ✅ |
+| Instance Segmentation | ✅ | ✅ | ✅ |
+| WBF Ensemble          | ✅ | ✅ | ✅ |
-### Datasets
+### Built-in Datasets
-| | BDD100K | Waymo | NuImages |
-|-----------------------|-------------|-----------|-------------|
-| Object Detection      | ✅ | ✅ | ✅ |
-| Instance Segmentation | ✅ | ✅ | ✅ |
+| | BDD100K | NuImages | Waymo |
+|-----------------------|-------------|-------------|-------------|
+| Object Detection      | ✅ | ✅ | ✅ |
+| Instance Segmentation | ✅ | ✅ | ✅ |
+| HuggingFace Export    | ✅ | ✅ | ✅ |
-## Supported Models
+### Example Models
-**Detectron2:** `retinanet_R_101_FPN_3x`, `faster_rcnn_R_50_FPN_3x`
-**MMDetection:** `co_detr`, `dino`
+**Detectron2:** `faster_rcnn_R_50_FPN_3x`, `retinanet_R_101_FPN_3x`
+**MMDetection:** `co_detr`, `dino`, `rtmdet`
**Ultralytics:** `yolov8x`, `yolov10x`, `yolo11x`, `rtdetr-x`
-## GRAID Features
-
-- **Interactive CLI**: User-friendly prompts for dataset and model selection
-- **Multiple Backends**: Support for Detectron2, MMDetection, and Ultralytics
-- **Custom Models**: Bring your own model configurations
-- **Ground Truth Support**: Generate databases using original annotations
-- **Batch Processing**: Support for non-interactive scripted usage
+## Research Applications
-## Project Structure
+This framework enables systematic evaluation of:
+- **Vision-Language Models**: Generate targeted VQA benchmarks
+- **Object Detection Methods**: Compare model performance on specific object types
+- **Reasoning Capabilities**: Create challenging spatial and counting questions
+- **Domain Adaptation**: Generate domain-specific evaluation sets
+- **Ensemble Methods**: Evaluate fusion strategies across detection models
-The project has been renamed from `scenic-reasoning` to **GRAID**. Key components:
+## Quality Assurance
-- **Package**: `scenic_reasoning/src/graid/` (new GRAID package)
-- **CLI**: `scenic_reasoning/src/scenic_reasoning/graid_cli.py`
-- **Original**: `scenic_reasoning/src/scenic_reasoning/` (backward compatibility)
+Generated datasets undergo comprehensive validation:
+- **Model Verification**: Automatic testing of model loading and inference
+- **Annotation Quality**: Confidence score filtering and duplicate removal
+- **Metadata Integrity**: Complete provenance tracking for reproducibility
+- **Format Compliance**: COCO-standard annotations with HuggingFace compatibility
-## Databases
+## Legacy Support
-Generated databases are saved in:
+**Interactive CLI**: User-friendly prompts for dataset and model selection
+```bash
+uv run graid generate
```
-data/databases_ablations/{dataset}_{split}_{conf}_{backend}_{model}.sqlite
+
+**Available Commands:**
+```bash
+uv run graid --help # Show help
+uv run graid list-models # List available models
+uv run graid list-questions # List available question types with parameters
+uv run graid info # Show project information
+uv run graid generate-dataset # Modern HuggingFace generation
+
+# Interactive features
+uv run graid generate-dataset --interactive-questions # Select questions interactively
+uv run graid generate-dataset --list-questions # Show available questions
```
-✅ **Ready to use!**
+## Key Advantages
+
+- **Modern Format**: HuggingFace datasets for seamless ML integration
+- **Targeted Generation**: Focus on relevant object categories
+- **Ensemble Support**: Multi-model fusion for enhanced accuracy
+- **Reproducible**: Configuration-driven experiments
+- **Shareable**: Direct HuggingFace Hub integration
+- **Comprehensive**: Rich metadata and quality metrics
+- **Extensible**: Support for custom datasets and models
+
+**✅ Ready for production VQA research and applications!**
diff --git a/YOLO_BDD_Workflow.md b/YOLO_BDD_Workflow.md
new file mode 100644
index 0000000..a0639a1
--- /dev/null
+++ b/YOLO_BDD_Workflow.md
@@ -0,0 +1,86 @@
+# BDD100K → YOLOv9 Training Workflow
+
+This short guide explains **(1) how to export BDD100K annotations to YOLO-compatible `.txt` files** and **(2) how to start a multi-GPU YOLOv9 training run logged to Weights & Biases (wandb)**. Two launch modes are covered: *local workstation* and *Slurm cluster*.
+
+---
+## 1 Export YOLO labels
+
+```bash
+# Activate the project environment
+conda activate scenic_reason
+
+# Run the exporter (1-3 min per split on a V100)
+python graid/src/graid/data/export_bdd_to_yolo.py
+```
+
+The script:
+* Instantiates `Bdd100kDataset(split, use_original_categories=True, use_time_filtered=False)`.
+* Converts each `ObjectDetectionResultI` to YOLO-style **normalized** `x_center y_center width height` using `as_xywhn()` (see the sketch after this list).
+* Writes one `.txt` per image **only when at least one object is present**.
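+
+The normalization done by `as_xywhn()` boils down to the arithmetic below. This is a standalone sketch of the math only, assuming absolute `[x1, y1, x2, y2]` pixel boxes as input; it is not the actual method on `ObjectDetectionResultI`:
+
+```python
+def xyxy_to_xywhn(box, img_w, img_h):
+    """Convert an absolute [x1, y1, x2, y2] box to YOLO-normalized (xc, yc, w, h)."""
+    x1, y1, x2, y2 = box
+    return (
+        (x1 + x2) / 2 / img_w,  # x_center, normalized by image width
+        (y1 + y2) / 2 / img_h,  # y_center, normalized by image height
+        (x2 - x1) / img_w,      # width
+        (y2 - y1) / img_h,      # height
+    )
+
+print(xyxy_to_xywhn([100, 200, 300, 400], img_w=1280, img_h=720))
+```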
+
+### Output structure (expected by Ultralytics)
+```
+data/bdd100k/
+├── images/100k/train/  *.jpg   ← original images
+├── images/100k/val/    *.jpg
+└── labels/100k/
+    ├── train/  *.txt   ← generated by script
+    └── val/    *.txt
+```
+If you initially generated files under `yolo_labels/`, move them into the parallel `labels/100k/{train,val}/` directories:
+
+```bash
+mkdir -p data/bdd100k/labels/100k/{train,val}
+mv data/bdd100k/yolo_labels/train/*.txt data/bdd100k/labels/100k/train/
+mv data/bdd100k/yolo_labels/val/*.txt data/bdd100k/labels/100k/val/
+```
+
+> **Tip:** set `root_output_dir = Path("data/bdd100k/labels/100k")` in the script if you prefer the exporter to write directly to the correct place.
+
+---
+## 2 Start training (wandb logging)
+
+### 2.1 Local / interactive
+```bash
+python train_yolo_bdd.py
+```
+`train_yolo_bdd.py` will:
+1. Detect available GPUs via `CUDA_VISIBLE_DEVICES` (or fall back to all GPUs).
+2. Set `WANDB_PROJECT=yolo_bdd`, `WANDB_NAME=yolo_bdd_train` (edit if you like).
+3. Launch Ultralytics training (see the sketch after this list):
+ * Model : `yolov9e.pt` (pre-trained)
+ * Dataset: `bdd_ultra.yaml`
+ * Epochs : 100 • Image size : 1080 • Batch : 32
+4. Artifacts & logs are stored in `runs/detect/yolo_bdd/` and on wandb.
+
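+In Ultralytics terms, the launch in step 3 corresponds roughly to the call below. This is a hedged sketch based on the parameters listed above; the actual `train_yolo_bdd.py` may set additional options (device list, wandb naming, etc.):
+
+```python
+# Sketch: what the training launch roughly amounts to with the Ultralytics API.
+from ultralytics import YOLO
+
+model = YOLO("yolov9e.pt")     # pre-trained YOLOv9 weights
+model.train(
+    data="bdd_ultra.yaml",     # dataset definition from this repo
+    epochs=100,
+    imgsz=1080,
+    batch=32,
+    project="runs/detect",
+    name="yolo_bdd",
+)
+```
+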
+Requirements
+* `wandb login` once beforehand (or set `WANDB_MODE=offline`).
+* ~128 GB RAM, 4 GPUs, < 10 h wall-time (V100).
+
+### 2.2 Slurm cluster
+1. **Inspect / edit resources** in `train_yolo_bdd.slurm` (default: 4 GPUs, 16 CPU, 128 GB, 10 h).
+2. Submit the job:
+ ```bash
+ sbatch train_yolo_bdd.slurm
+ ```
+3. Logs stream to `slurm-<jobid>.out` and wandb.
+
+The batch script:
+* Loads your conda env (`scenic_reason`).
+* Exports `WANDB_MODE=online` (remove or change to *offline* if desired).
+* Executes `python train_yolo_bdd.py`. Ultralytics launches DDP automatically across the 4 GPUs allocated by Slurm.
+
+---
+## 3 Monitoring & artefacts
+* **wandb:**
+ ```bash
+ wandb online # if you disabled it earlier
+ # view runs at https://wandb.ai/<your-entity>/yolo_bdd
+ ```
+* **TensorBoard:**
+ ```bash
+ tensorboard --logdir runs/detect/yolo_bdd
+ ```
+* **Best weights:** `runs/detect/yolo_bdd/weights/best.pt` (see the snippet below)
+
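+To sanity-check a finished run, the best checkpoint can be loaded back with the standard Ultralytics API. A minimal sketch (the image path is a placeholder):
+
+```python
+# Sketch: quick inference with the best checkpoint from this run.
+from ultralytics import YOLO
+
+best = YOLO("runs/detect/yolo_bdd/weights/best.pt")
+results = best.predict("data/bdd100k/images/100k/val/example.jpg", conf=0.25)
+print(results[0].boxes.xyxy)  # predicted boxes in pixel coordinates
+```
+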
+Happy training!
\ No newline at end of file
diff --git a/bdd_train_gt_dataset_config.json b/bdd_train_gt_dataset_config.json
new file mode 100644
index 0000000..63e5f68
--- /dev/null
+++ b/bdd_train_gt_dataset_config.json
@@ -0,0 +1,158 @@
+{
+ "dataset_name": "bdd",
+ "split": "train+val",
+ "models": [],
+ "use_wbf": false,
+ "confidence_threshold": 0.0,
+ "batch_size": 128,
+ "device": null,
+ "allowable_set": null,
+ "num_workers": 16,
+ "qa_workers": 32,
+ "num_samples": 0,
+ "question_configs": [
+ {
+ "name": "IsObjectCentered",
+ "params": {
+ "buffer_ratio": 0.05
+ }
+ },
+ {
+ "name": "WidthVsHeight",
+ "params": {
+ "threshold": 0.75,
+ "non_articulated_classes": ["car", "truck", "bus", "person", "bicycle", "motorcycle"]
+ }
+ },
+ {
+ "name": "Quadrants",
+ "params": {
+ "N": 2,
+ "M": 2,
+ "margin_ratio": 0.1
+ }
+ },
+ {
+ "name": "Quadrants",
+ "params": {
+ "N": 2,
+ "M": 3,
+ "margin_ratio": 0.1
+ }
+ },
+ {
+ "name": "Quadrants",
+ "params": {
+ "N": 3,
+ "M": 2,
+ "margin_ratio": 0.1
+ }
+ },
+ {
+ "name": "Quadrants",
+ "params": {
+ "N": 3,
+ "M": 3,
+ "margin_ratio": 0.1
+ }
+ },
+ {
+ "name": "LargestAppearance",
+ "params": {
+ "threshold": 0.3
+ }
+ },
+ {
+ "name": "RankLargestK",
+ "params": {
+ "k": 2,
+ "margin_ratio": 0.3
+ }
+ },
+ {
+ "name": "RankLargestK",
+ "params": {
+ "k": 3,
+ "margin_ratio": 0.3
+ }
+ },
+ {
+ "name": "MostAppearance",
+ "params": {
+ "margin_ratio": 0.2
+ }
+ },
+ {
+ "name": "LeastAppearance",
+ "params": {
+ "margin_ratio": 0.2
+ }
+ },
+ {
+ "name": "LeftOf",
+ "params": {}
+ },
+ {
+ "name": "RightOf",
+ "params": {}
+ },
+ {
+ "name": "LeftMost",
+ "params": {}
+ },
+ {
+ "name": "RightMost",
+ "params": {}
+ },
+ {
+ "name": "HowMany",
+ "params": {}
+ },
+ {
+ "name": "AreMore",
+ "params": {
+ "margin_ratio": 0.2
+ }
+ },
+ {
+ "name": "WhichMore",
+ "params": {
+ "margin_ratio": 0.2
+ }
+ },
+ {
+ "name": "LeftMostWidthVsHeight",
+ "params": {
+ "threshold": 0.75,
+ "spatial_margin_ratio": 0.05
+ }
+ },
+ {
+ "name": "RightMostWidthVsHeight",
+ "params": {
+ "threshold": 0.75,
+ "spatial_margin_ratio": 0.05
+ }
+ },
+ {
+ "name": "MoreThanThresholdHowMany",
+ "params": {
+ "threshold": 1.5
+ }
+ },
+ {
+ "name": "LessThanThresholdHowMany",
+ "params": {
+ "threshold": 0.5
+ }
+ },
+ {
+ "name": "MultiChoiceHowMany",
+ "params": {}
+ }
+ ],
+ "save_path": "./bdd_train_gt_dataset",
+ "upload_to_hub": true,
+ "hub_repo_id": "kd7/graid-bdd100k-ground-truth",
+ "hub_private": false
+}
\ No newline at end of file
diff --git a/bdd_ultra.yaml b/bdd_ultra.yaml
new file mode 100644
index 0000000..513a834
--- /dev/null
+++ b/bdd_ultra.yaml
@@ -0,0 +1,19 @@
+path: /work/ke/research/scenic-reasoning/data/bdd100k
+train: images/100k/train
+val: images/100k/val
+test: # optional
+
+# classes
+names:
+ 0: pedestrian
+ 1: person
+ 2: rider
+ 3: car
+ 4: truck
+ 5: bus
+ 6: train
+ 7: motorcycle
+ 8: bicycle
+ 9: traffic light
+ 10: traffic sign
+ 11: sidewalk
\ No newline at end of file
diff --git a/evals/coco_stream.py b/evals/coco_stream.py
index b11da74..ba44a4d 100644
--- a/evals/coco_stream.py
+++ b/evals/coco_stream.py
@@ -12,9 +12,6 @@
NuImagesDataset,
WaymoDataset,
)
-from graid.models.Detectron import Detectron_obj
-from graid.models.MMDetection import MMdetection_obj
-from graid.models.Ultralytics import RT_DETR, Yolo
from graid.utilities.common import (
project_root_dir,
yolo_bdd_transform,
@@ -43,17 +40,18 @@
"--model",
"-m",
type=str,
- default="yolo11x",
+ default="yolov8x-world",
choices=[
"DINO",
"Co_DETR",
"yolov10x",
- "yolo11x",
+ "yolov8x-world",
"rtdetr",
"retinanet_R_101_FPN_3x",
"faster_rcnn_R_50_FPN_3x",
"X101_FPN",
"faster_rcnn_R_101_FPN_3x",
+ "vitdet",
],
help="Model to use",
)
@@ -81,9 +79,15 @@
dataset = args.dataset
if dataset == "bdd":
+ # For vitdet, we don't apply a transform here because the model's
+ # preprocessing is handled inside its class. Other models get the YOLO BDD resize transform.
+ transform = None
+ if args.model != "vitdet":
+ transform = lambda i, l: yolo_bdd_transform(i, l, new_shape=(768, 1280))
+
dataset = Bdd100kDataset(
- split="train",
- transform=lambda i, l: yolo_bdd_transform(i, l, new_shape=(768, 1280)),
+ split="val", # Use validation set for evaluation
+ transform=transform,
use_original_categories=False,
use_extended_annotations=False,
)
@@ -111,7 +115,7 @@
# Initialize the model
"""
-Yolo(model="yolo11n.pt")"
+Yolo(model="yolov8x-world.pt")"
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.094
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.161
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.093
@@ -127,10 +131,16 @@
"""
model = args.model
if "yolov6" in model:
+ from graid.models.Ultralytics import Yolo
+
model = Yolo(model=f"{model}.yaml")
elif "yolo" in model:
+ from graid.models.Ultralytics import Yolo
+
model = Yolo(model=f"{model}.pt")
elif model == "DINO":
+ from graid.models.MMDetection import MMdetection_obj
+
MMDETECTION_PATH = project_root_dir() / "install" / "mmdetection"
DINO_config = str(
MMDETECTION_PATH / "configs/dino/dino-5scale_swin-l_8xb2-12e_coco.py"
@@ -141,6 +151,8 @@
model = MMdetection_obj(DINO_config, DINO_checkpoint)
BATCH_SIZE = 1 # MMDetection does not support batch size > 1
elif model == "Co_DETR":
+ from graid.models.MMDetection import MMdetection_obj
+
MMDETECTION_PATH = project_root_dir() / "install" / "mmdetection"
Co_DETR_config = str(
MMDETECTION_PATH
@@ -152,6 +164,8 @@
model = MMdetection_obj(Co_DETR_config, Co_DETR_checkpoint)
BATCH_SIZE = 1 # MMDetection does not support batch size > 1
elif model == "retinanet_R_101_FPN_3x":
+ from graid.models.Detectron import Detectron_obj
+
retinanet_R_101_FPN_3x_config = (
"COCO-Detection/retinanet_R_101_FPN_3x.yaml" # 228MB
)
@@ -161,6 +175,8 @@
weights_file=retinanet_R_101_FPN_3x_weights,
)
elif model == "faster_rcnn_R_50_FPN_3x":
+ from graid.models.Detectron import Detectron_obj
+
faster_rcnn_R_50_FPN_3x_config = (
"COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml" # 167MB
)
@@ -170,9 +186,13 @@
weights_file=faster_rcnn_R_50_FPN_3x_weights,
)
elif model == "rtdetr":
+ from graid.models.Ultralytics import RT_DETR
+
model = RT_DETR("rtdetr-x.pt")
elif model == "X101_FPN":
+ from graid.models.Detectron import Detectron_obj
+
X101_FPN_config = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml" # 167MB
X101_FPN_weights = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
model = Detectron_obj(
@@ -180,12 +200,25 @@
weights_file=X101_FPN_weights,
)
elif model == "faster_rcnn_R_101_FPN_3x":
+ from graid.models.Detectron import Detectron_obj
+
faster_rcnn_R_101_FPN_3x_config = "COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml"
faster_rcnn_R_101_FPN_3x_weights = "COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml"
model = Detectron_obj(
config_file=faster_rcnn_R_101_FPN_3x_config,
weights_file=faster_rcnn_R_101_FPN_3x_weights,
)
+elif model == "vitdet":
+ from graid.models.Detectron import DetectronLazy
+
+ CONFIG_FILE = "install/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py"
+ CHECKPOINT_FILE = "checkpoints/detectron2/model_final_f05665.pkl"
+ model = DetectronLazy(
+ config_file=CONFIG_FILE,
+ weights_file=CHECKPOINT_FILE,
+ threshold=args.conf,
+ device=device,
+ )
model.to(device)
@@ -304,7 +337,15 @@
# Now, we build the final COCO ground-truth file by streaming through the temporary files with ijson.
with open(coco_gt_path, "w") as f_out:
- f_out.write('{"images": ')
+ info_record = {
+ "description": f"COCO-style dataset generated for {args.dataset} evaluation",
+ "version": "1.0",
+ "year": 2024,
+ }
+ f_out.write('{"info": ')
+ f_out.write(json.dumps(info_record))
+ f_out.write(', "images": ')
+
# Stream images from the temporary images file.
with open(gt_images_temp_path, "r") as f_images:
f_out.write("[")
diff --git a/evals/eval_vlms.py b/evals/eval_vlms.py
index 78f6601..d75cb85 100644
--- a/evals/eval_vlms.py
+++ b/evals/eval_vlms.py
@@ -269,16 +269,17 @@ def iterate_sqlite_db(db_path, my_vlm, my_metric, my_prompt, use_batch=False, sa
if use_batch and image_path is not None:
questions = ", ".join([item for i, item in enumerate(questions)])
answers = ", ".join([item for i, item in enumerate(answers)])
- _, prompt = my_prompt.generate_prompt(image_path, questions)
+ annotated_image, messages = my_prompt.generate_prompt(image_path, questions)
if image_path is not None:
- cache_key = f"{my_vlm}_{my_prompt}_{image_path}_{prompt}" + (
+ messages_str = str(messages)
+ cache_key = f"{my_vlm}_{my_prompt}_{image_path}_{messages_str}" + (
"_SoM" if "SetOfMarkPrompt" == str(my_prompt) else ""
) + "_batch"
if cache_key in vlm_cache:
preds = vlm_cache[cache_key]
else:
- preds, prompt = my_vlm.generate_answer(
- image_path, questions, my_prompt
+ preds, returned_messages = my_vlm.generate_answer(
+ annotated_image, messages
)
vlm_cache[cache_key] = preds
else:
@@ -292,15 +293,16 @@ def iterate_sqlite_db(db_path, my_vlm, my_metric, my_prompt, use_batch=False, sa
if len(q) < 5: # check for "D" and "y"
raise ValueError(f"Question too short: {q}")
- # the cache key should be image_path + prompt
- _, prompt = my_prompt.generate_prompt(image, q)
- cache_key = f"{my_vlm}_{my_prompt}_{image_path}_{prompt}" + (
+ # Generate prompt and unique cache key
+ annotated_image, messages = my_prompt.generate_prompt(image, q)
+ messages_str = str(messages)
+ cache_key = f"{my_vlm}_{my_prompt}_{image_path}_{messages_str}" + (
"_SoM" if "SetOfMarkPrompt" == str(my_prompt) else ""
)
if cache_key in vlm_cache:
pred = vlm_cache[cache_key]
else:
- pred, prompt = my_vlm.generate_answer(image_path, q, my_prompt)
+ pred, returned_messages = my_vlm.generate_answer(annotated_image, messages)
vlm_cache[cache_key] = pred
vlm_cache.commit()
correct = my_metric.evaluate(pred, a)
diff --git a/graid/src/graid/__init__.py b/graid/src/graid/__init__.py
index c028093..9f76143 100644
--- a/graid/src/graid/__init__.py
+++ b/graid/src/graid/__init__.py
@@ -1,3 +1,32 @@
-from graid.graid import app
+"""
+GRAID: Generating Reasoning questions from Analysis of Images via Discriminative artificial intelligence
+"""
+
+import logging
+import os
+import warnings
+
+# Suppress common warnings for better user experience
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings(
+ "ignore", message=".*TorchScript.*functional optimizers.*deprecated.*"
+)
+warnings.filterwarnings("ignore", message=".*Adafactor is already registered.*")
+
+# Suppress mmengine info messages
+logging.getLogger("mmengine").setLevel(logging.ERROR)
+
+# Set environment variable to reduce mmengine verbosity
+os.environ["MMENGINE_LOGGING_LEVEL"] = "ERROR"
+
+
+def __getattr__(name):
+ """Lazy import for the CLI app to avoid loading it on every import."""
+ if name == "app":
+ from graid.graid import app
+
+ return app
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
__all__ = ["app"]
diff --git a/graid/src/graid/cli/__init__.py b/graid/src/graid/cli/__init__.py
new file mode 100644
index 0000000..513123b
--- /dev/null
+++ b/graid/src/graid/cli/__init__.py
@@ -0,0 +1,27 @@
+"""
+GRAID CLI Components
+
+Modular CLI architecture for better maintainability and testing.
+"""
+
+from .config_manager import ConfigurationManager
+from .validators import ArgumentValidator
+from .error_handler import ErrorHandler
+from .exceptions import (
+ CLIError, ValidationError, DatasetValidationError, COCOValidationError,
+ SplitValidationError, ConfigurationError, ProcessingError, UploadError
+)
+
+__all__ = [
+ "ConfigurationManager",
+ "ArgumentValidator",
+ "ErrorHandler",
+ "CLIError",
+ "ValidationError",
+ "DatasetValidationError",
+ "COCOValidationError",
+ "SplitValidationError",
+ "ConfigurationError",
+ "ProcessingError",
+ "UploadError",
+]
diff --git a/graid/src/graid/cli/config_manager.py b/graid/src/graid/cli/config_manager.py
new file mode 100644
index 0000000..9c5812d
--- /dev/null
+++ b/graid/src/graid/cli/config_manager.py
@@ -0,0 +1,155 @@
+"""
+Configuration management for CLI commands.
+
+Handles loading configuration from files and merging with CLI arguments.
+"""
+
+import os
+from pathlib import Path
+from typing import Optional, Dict, Any
+
+import typer
+
+from graid.data.config_support import load_config_from_file, DatasetGenerationConfig
+from graid.cli.exceptions import ConfigurationError
+
+
+class ConfigurationManager:
+ """Handles loading and merging configuration from files and CLI."""
+
+ @staticmethod
+ def load_from_file(config_file: str) -> DatasetGenerationConfig:
+ """Load configuration from JSON file."""
+ try:
+ typer.secho(
+ "π Loading configuration from file...", fg=typer.colors.BLUE, bold=True
+ )
+ return load_config_from_file(config_file)
+ except FileNotFoundError:
+ raise ConfigurationError(f"Configuration file not found: {config_file}")
+ except Exception as e:
+ raise ConfigurationError(f"Failed to load configuration: {e}")
+
+ @staticmethod
+ def apply_cli_overrides(config: DatasetGenerationConfig, **cli_args) -> DatasetGenerationConfig:
+ """Apply CLI arguments that override config file values."""
+ # Remove None values to avoid overriding config with None
+ cli_args = {k: v for k, v in cli_args.items() if v is not None}
+
+ # Map CLI arguments to config attributes
+ if 'force' in cli_args and cli_args['force']:
+ config.force = True
+ if 'save_path' in cli_args:
+ config.save_path = cli_args['save_path']
+ if 'upload_to_hub' in cli_args and cli_args['upload_to_hub']:
+ config.upload_to_hub = True
+ if 'hub_repo_id' in cli_args:
+ config.hub_repo_id = cli_args['hub_repo_id']
+ if 'hub_private' in cli_args and cli_args['hub_private']:
+ config.hub_private = True
+ if 'dataset' in cli_args:
+ config.dataset_name = cli_args['dataset']
+ if 'split' in cli_args:
+ config.split = cli_args['split']
+ if 'num_workers' in cli_args:
+ config.num_workers = cli_args['num_workers']
+ if 'qa_workers' in cli_args:
+ config.qa_workers = cli_args['qa_workers']
+ if 'allowable_set' in cli_args:
+ # Parse comma-separated allowable set
+ allowable_str = cli_args['allowable_set']
+ if allowable_str:
+ config.allowable_set = [obj.strip() for obj in allowable_str.split(',') if obj.strip()]
+ else:
+ config.allowable_set = None
+
+ return config
+
+ @staticmethod
+ def create_interactive_config(**cli_args) -> DatasetGenerationConfig:
+ """Create configuration through interactive prompts."""
+ from graid.graid import (
+ get_dataset_name, get_split, get_model_selection,
+ get_confidence_threshold, interactive_question_selection
+ )
+
+ # Interactive configuration gathering
+ typer.secho("π οΈ Interactive Configuration", fg=typer.colors.BLUE, bold=True)
+ typer.echo("Let's configure your dataset generation step by step.")
+ typer.echo()
+
+ dataset_name = get_dataset_name()
+ split = get_split()
+ backend_name, model_name, custom_config = get_model_selection()
+ confidence_threshold = get_confidence_threshold()
+
+ # Get question selection
+ if cli_args.get('interactive_questions'):
+ # Local import to avoid heavy dependencies
+ from graid.data.generate_dataset import interactive_question_selection
+ question_configs = interactive_question_selection()
+ else:
+ # Use default question set
+ question_configs = [
+ {"name": "HowMany", "params": {}},
+ {"name": "IsObjectCentered", "params": {"buffer_ratio": 0.05}},
+ ]
+
+ # Create configuration object
+ config = DatasetGenerationConfig(
+ dataset_name=dataset_name,
+ split=split,
+ models=[model_name] if model_name else [],
+ confidence_threshold=confidence_threshold,
+ question_configs=question_configs,
+ save_path=cli_args.get('save_path'),
+ upload_to_hub=cli_args.get('upload_to_hub', False),
+ hub_repo_id=cli_args.get('hub_repo_id'),
+ hub_private=cli_args.get('hub_private', False),
+ num_workers=cli_args.get('num_workers', 4),
+ qa_workers=cli_args.get('qa_workers', 4),
+ force=cli_args.get('force', False),
+ )
+
+ # Handle allowable set if provided via CLI
+ if cli_args.get('allowable_set'):
+ allowable_str = cli_args['allowable_set']
+ config.allowable_set = [obj.strip() for obj in allowable_str.split(',') if obj.strip()]
+
+ return config
+
+ @staticmethod
+ def validate_configuration(config: DatasetGenerationConfig):
+ """Validate final configuration for consistency."""
+ # Import validators locally
+ from graid.cli.validators import ArgumentValidator
+
+ validator = ArgumentValidator()
+
+ # Validate core parameters
+ validator.require_valid_dataset(config.dataset_name)
+ validator.require_valid_split(config.split)
+
+ if config.allowable_set:
+ validator.require_valid_coco_objects(config.allowable_set)
+
+ # Validate upload configuration
+ if config.upload_to_hub and not config.hub_repo_id:
+ raise ConfigurationError("hub_repo_id is required when upload_to_hub=True")
+
+ # Validate save path
+ if not config.save_path:
+ raise ConfigurationError("save_path is required")
+
+ # Validate question configuration
+ if not config.question_configs:
+ raise ConfigurationError("At least one question type must be configured")
+
+ typer.secho("β Configuration validated successfully", fg=typer.colors.GREEN)
+ typer.echo(f" Dataset: {config.dataset_name}")
+ typer.echo(f" Split: {config.split}")
+ typer.echo(f" Questions: {len(config.question_configs)} types")
+ typer.echo(f" Save path: {config.save_path}")
+ if config.upload_to_hub:
+ typer.echo(f" Hub upload: {config.hub_repo_id}")
+ typer.echo()
diff --git a/graid/src/graid/cli/error_handler.py b/graid/src/graid/cli/error_handler.py
new file mode 100644
index 0000000..3d6479b
--- /dev/null
+++ b/graid/src/graid/cli/error_handler.py
@@ -0,0 +1,129 @@
+"""
+Standardized error handling for CLI commands.
+
+Provides consistent error messaging and eliminates silent failures.
+"""
+
+import os
+import traceback
+
+import typer
+
+from graid.cli.exceptions import (
+ CLIError, ValidationError, ConfigurationError, ProcessingError, UploadError
+)
+
+
+class ErrorHandler:
+ """Standardized error handling for CLI commands."""
+
+ @staticmethod
+ def handle_validation_error(error: ValidationError):
+ """Handle validation errors with appropriate user messaging."""
+ typer.secho(f"β Validation Error: {error.message}", fg=typer.colors.RED, bold=True)
+ typer.echo("π‘ Use --help for usage information or check your configuration.")
+ raise typer.Exit(error.exit_code)
+
+ @staticmethod
+ def handle_configuration_error(error: ConfigurationError):
+ """Handle configuration errors with helpful suggestions."""
+ typer.secho(f"β Configuration Error: {error.message}", fg=typer.colors.RED, bold=True)
+ typer.echo("π‘ Check your configuration file format and required parameters.")
+ raise typer.Exit(error.exit_code)
+
+ @staticmethod
+ def handle_processing_error(error: ProcessingError):
+ """Handle processing errors with debugging information."""
+ typer.secho(f"β Processing Error: {error.message}", fg=typer.colors.RED, bold=True)
+ typer.echo("π‘ This usually indicates an issue with the dataset or model configuration.")
+
+ # Show traceback in debug mode
+ if os.getenv("GRAID_DEBUG_VERBOSE"):
+ typer.echo("\nπ Debug traceback:")
+ traceback.print_exc()
+ else:
+ typer.echo("π‘ Set GRAID_DEBUG_VERBOSE=1 for detailed error information.")
+
+ raise typer.Exit(error.exit_code)
+
+ @staticmethod
+ def handle_upload_error(error: UploadError):
+ """Handle upload errors with network/authentication hints."""
+ typer.secho(f"β Upload Error: {error.message}", fg=typer.colors.RED, bold=True)
+ typer.echo("π‘ Check your HuggingFace Hub authentication and network connection.")
+ typer.echo("π‘ Run 'huggingface-cli login' if not authenticated.")
+ raise typer.Exit(error.exit_code)
+
+ @staticmethod
+ def handle_unexpected_error(error: Exception):
+ """Handle unexpected errors with full debugging information."""
+ typer.secho(f"β Unexpected Error: {error}", fg=typer.colors.RED, bold=True)
+ typer.echo("π‘ This is likely a bug. Please report it with the traceback below.")
+ typer.echo()
+
+ # Always show traceback for unexpected errors
+ typer.echo("π Full traceback:")
+ traceback.print_exc()
+
+ raise typer.Exit(1)
+
+ @staticmethod
+ def handle_cli_error(error: CLIError):
+ """Route CLI errors to appropriate handlers."""
+ if isinstance(error, ValidationError):
+ ErrorHandler.handle_validation_error(error)
+ elif isinstance(error, ConfigurationError):
+ ErrorHandler.handle_configuration_error(error)
+ elif isinstance(error, ProcessingError):
+ ErrorHandler.handle_processing_error(error)
+ elif isinstance(error, UploadError):
+ ErrorHandler.handle_upload_error(error)
+ else:
+ # Generic CLI error
+ typer.secho(f"β Error: {error.message}", fg=typer.colors.RED, bold=True)
+ raise typer.Exit(error.exit_code)
+
+ @staticmethod
+ def safe_operation(operation, error_message: str = "Operation failed"):
+ """
+ Execute operation safely, converting silent failures to explicit errors.
+
+ This replaces patterns like:
+ try:
+ some_operation()
+ except Exception:
+ pass # Silent failure - BAD!
+
+ With:
+ ErrorHandler.safe_operation(some_operation, "Failed to perform operation")
+ """
+ try:
+ return operation()
+ except Exception as e:
+ # Convert silent failure to explicit error
+ typer.secho(f"β {error_message}: {e}", fg=typer.colors.RED)
+ if os.getenv("GRAID_DEBUG_VERBOSE"):
+ traceback.print_exc()
+ raise CLIError(f"{error_message}: {e}")
+
+ @staticmethod
+ def validate_and_execute(validation_func, operation_func, operation_name: str):
+ """
+ Execute validation followed by operation with proper error handling.
+
+ This ensures validation errors are caught and handled appropriately
+ before attempting the main operation.
+ """
+ try:
+ # Validation phase
+ validation_func()
+
+ # Operation phase
+ return operation_func()
+
+ except CLIError:
+ # Re-raise CLI errors as-is (already have proper context)
+ raise
+ except Exception as e:
+ # Convert unexpected errors
+ ErrorHandler.handle_unexpected_error(e)
diff --git a/graid/src/graid/cli/exceptions.py b/graid/src/graid/cli/exceptions.py
new file mode 100644
index 0000000..3411498
--- /dev/null
+++ b/graid/src/graid/cli/exceptions.py
@@ -0,0 +1,47 @@
+"""
+CLI-specific exceptions for better error handling and user messaging.
+"""
+
+
+class CLIError(Exception):
+ """Base exception for CLI-related errors."""
+
+ def __init__(self, message: str, exit_code: int = 1):
+ super().__init__(message)
+ self.message = message
+ self.exit_code = exit_code
+
+
+class ValidationError(CLIError):
+ """Base exception for validation errors."""
+ pass
+
+
+class DatasetValidationError(ValidationError):
+ """Dataset name validation error."""
+ pass
+
+
+class COCOValidationError(ValidationError):
+ """COCO object validation error."""
+ pass
+
+
+class SplitValidationError(ValidationError):
+ """Split specification validation error."""
+ pass
+
+
+class ConfigurationError(CLIError):
+ """Configuration loading/parsing errors."""
+ pass
+
+
+class ProcessingError(CLIError):
+ """Dataset processing errors."""
+ pass
+
+
+class UploadError(CLIError):
+ """HuggingFace Hub upload errors."""
+ pass
diff --git a/graid/src/graid/cli/validators.py b/graid/src/graid/cli/validators.py
new file mode 100644
index 0000000..050595d
--- /dev/null
+++ b/graid/src/graid/cli/validators.py
@@ -0,0 +1,88 @@
+"""
+Centralized validation logic for CLI arguments.
+
+Eliminates duplicate validation code and provides consistent error messages.
+"""
+
+from typing import List, Union
+
+from graid.cli.exceptions import (
+ DatasetValidationError, COCOValidationError, SplitValidationError
+)
+
+
+class ArgumentValidator:
+ """Centralized validation logic for CLI arguments."""
+
+ # Supported datasets (can be extended)
+ SUPPORTED_DATASETS = ["bdd", "nuimage", "waymo"]
+
+ # Valid split names
+ VALID_SPLITS = ["train", "val", "test", "train+val", "both", "all", "trainval"]
+
+ @classmethod
+ def require_valid_dataset(cls, dataset_name: str):
+ """Validate dataset name or raise DatasetValidationError."""
+ if dataset_name not in cls.SUPPORTED_DATASETS:
+ raise DatasetValidationError(
+ f"Invalid dataset: {dataset_name}. Supported datasets: {', '.join(cls.SUPPORTED_DATASETS)}"
+ )
+
+ @classmethod
+ def require_valid_split(cls, split_value: Union[str, List[str]]):
+ """Validate split specification or raise SplitValidationError."""
+ if isinstance(split_value, (list, tuple)):
+ splits = list(split_value)
+ else:
+ splits = [str(split_value)]
+
+ for split in splits:
+ split_lower = split.lower()
+ # Allow individual splits or combined formats
+ if split_lower not in cls.VALID_SPLITS and split_lower not in ["train", "val", "test"]:
+ raise SplitValidationError(
+ f"Invalid split: {split}. Valid splits: {', '.join(cls.VALID_SPLITS)}"
+ )
+
+ @classmethod
+ def require_valid_coco_objects(cls, objects: List[str]):
+ """Validate COCO objects or raise COCOValidationError."""
+ # Local import to avoid heavy dependencies
+ from graid.utilities.coco import validate_coco_objects
+
+ is_valid, error_msg = validate_coco_objects(objects)
+ if not is_valid:
+ raise COCOValidationError(f"Invalid COCO objects: {error_msg}")
+
+ @classmethod
+ def parse_and_validate_split(cls, split_value: str) -> List[str]:
+ """Parse and validate split specification, returning normalized list."""
+ cls.require_valid_split(split_value)
+
+ # Normalize split specification
+ if isinstance(split_value, (list, tuple)):
+ return list(split_value)
+
+ value = str(split_value).lower()
+ if value in {"train+val", "both", "all", "trainval"}:
+ return ["train", "val"]
+
+ return [str(split_value)]
+
+ @classmethod
+ def validate_numeric_range(cls, value: float, min_val: float, max_val: float, name: str):
+ """Validate numeric value is within specified range."""
+ if not (min_val <= value <= max_val):
+ raise ValueError(f"{name} must be between {min_val} and {max_val}, got {value}")
+
+ @classmethod
+ def validate_positive_int(cls, value: int, name: str):
+ """Validate integer is positive."""
+ if value <= 0:
+ raise ValueError(f"{name} must be positive, got {value}")
+
+ @classmethod
+ def require_non_empty_string(cls, value: str, name: str):
+ """Validate string is not empty."""
+ if not value or not value.strip():
+ raise ValueError(f"{name} cannot be empty")
diff --git a/graid/src/graid/cli_helpers.py b/graid/src/graid/cli_helpers.py
new file mode 100644
index 0000000..8fd073e
--- /dev/null
+++ b/graid/src/graid/cli_helpers.py
@@ -0,0 +1,860 @@
+"""
+CLI Helper Classes for GRAID
+
+Provides centralized configuration management, validation, and error handling
+for all GRAID CLI commands.
+"""
+
+import logging
+import os
+from pathlib import Path
+from typing import Any, List, Optional, Union
+
+import typer
+
+logger = logging.getLogger(__name__)
+
+
+class GraidError(Exception):
+ """Base exception for GRAID CLI errors."""
+ pass
+
+
+class ValidationError(GraidError):
+ """Configuration or argument validation error."""
+ pass
+
+
+class DatasetValidationError(ValidationError):
+ """Dataset name validation error."""
+ pass
+
+
+class COCOValidationError(ValidationError):
+ """COCO object validation error."""
+ pass
+
+
+class ConfigurationError(GraidError):
+ """Configuration loading or processing error."""
+ pass
+
+
+class ProcessingError(GraidError):
+ """Dataset generation or processing error."""
+ pass
+
+
+class ArgumentValidator:
+ """Centralized validation logic for CLI arguments."""
+
+ @staticmethod
+ def validate_dataset_name(dataset_name: str) -> str:
+ """
+ Validate dataset name against supported options.
+
+ Args:
+ dataset_name: Dataset name to validate
+
+ Returns:
+ Validated dataset name
+
+ Raises:
+ DatasetValidationError: If dataset name is not supported
+ """
+ # Local import to avoid heavy dependencies
+ from graid.data.loaders import DatasetLoaderFactory
+
+ supported_datasets = DatasetLoaderFactory.get_supported_datasets()
+ if dataset_name not in supported_datasets:
+ raise DatasetValidationError(
+ f"Invalid dataset: '{dataset_name}'. Supported: {supported_datasets}"
+ )
+ return dataset_name
+
+ @staticmethod
+ def validate_split_format(split_value: str) -> List[str]:
+ """
+ Parse and validate split specification.
+
+ Args:
+ split_value: Split value (e.g., "train", "train+val", "train,val")
+
+ Returns:
+ List of validated split names
+
+ Raises:
+ ValidationError: If split format is invalid
+ """
+ if not split_value:
+ raise ValidationError("Split value cannot be empty")
+
+ # Handle different split formats
+ if "+" in split_value:
+ splits = split_value.split("+")
+ elif "," in split_value:
+ splits = split_value.split(",")
+ else:
+ splits = [split_value]
+
+ # Clean and validate each split
+ valid_splits = ["train", "val", "validation", "test"]
+ cleaned_splits = []
+
+ for split in splits:
+ split = split.strip()
+ if not split:
+ continue
+ if split not in valid_splits:
+ raise ValidationError(f"Invalid split: '{split}'. Valid splits: {valid_splits}")
+ cleaned_splits.append(split)
+
+ if not cleaned_splits:
+ raise ValidationError("No valid splits found")
+
+ return cleaned_splits
+
+ @staticmethod
+ def validate_coco_objects(allowable_set: List[str]) -> List[str]:
+ """
+ Validate COCO object class names.
+
+ Args:
+ allowable_set: List of COCO object class names
+
+ Returns:
+ Validated list of COCO object names
+
+ Raises:
+ COCOValidationError: If any object names are invalid
+ """
+ if not allowable_set:
+ return allowable_set
+
+ # Local import for COCO validation
+ from graid.utilities.coco import validate_coco_objects
+
+ is_valid, error_msg = validate_coco_objects(allowable_set)
+ if not is_valid:
+ raise COCOValidationError(error_msg)
+
+ return allowable_set
+
+ @staticmethod
+ def validate_path(path: Optional[str], must_exist: bool = False) -> Optional[Path]:
+ """
+ Validate and convert path string to Path object.
+
+ Args:
+ path: Path string to validate
+ must_exist: Whether the path must already exist
+
+ Returns:
+ Validated Path object or None
+
+ Raises:
+ ValidationError: If path validation fails
+ """
+ if not path:
+ return None
+
+ try:
+ path_obj = Path(path)
+
+ if must_exist and not path_obj.exists():
+ raise ValidationError(f"Path does not exist: {path}")
+
+ return path_obj
+
+ except Exception as e:
+ raise ValidationError(f"Invalid path '{path}': {e}")
+
+ @staticmethod
+ def validate_hub_config(upload_to_hub: bool, hub_repo_id: Optional[str]) -> None:
+ """
+ Validate HuggingFace Hub configuration.
+
+ Args:
+ upload_to_hub: Whether uploading to hub is requested
+ hub_repo_id: Hub repository ID
+
+ Raises:
+ ValidationError: If hub configuration is invalid
+ """
+ if upload_to_hub and not hub_repo_id:
+ raise ValidationError("hub_repo_id is required when upload_to_hub=True")
+
+ if hub_repo_id and "/" not in hub_repo_id:
+ raise ValidationError(
+ "hub_repo_id must be in format 'username/repo-name' or 'org/repo-name'"
+ )
+
+
+class ConfigurationManager:
+ """Handles loading and merging configuration from files and CLI."""
+
+ @staticmethod
+ def load_from_file(config_file: str):
+ """
+ Load configuration from JSON file.
+
+ Args:
+ config_file: Path to configuration file
+
+ Returns:
+ DatasetGenerationConfig object
+
+ Raises:
+ ConfigurationError: If config loading fails
+ """
+ try:
+ # Local import for config support
+ from graid.data.config_support import load_config_from_file
+
+ config_path = Path(config_file)
+ if not config_path.exists():
+ raise ConfigurationError(f"Configuration file not found: {config_file}")
+
+ return load_config_from_file(config_file)
+
+ except Exception as e:
+ if isinstance(e, ConfigurationError):
+ raise
+ raise ConfigurationError(f"Failed to load configuration from {config_file}: {e}")
+
+ @staticmethod
+ def apply_cli_overrides(config, **cli_args):
+ """
+ Apply CLI arguments that override config file values.
+
+ Args:
+ config: Configuration object to modify
+ **cli_args: CLI arguments to apply
+
+ Returns:
+ Modified configuration object
+ """
+ # Override config values with CLI arguments (only if CLI arg is not None)
+ override_fields = [
+ 'save_path', 'upload_to_hub', 'hub_repo_id', 'hub_private',
+ 'dataset', 'split', 'num_workers', 'qa_workers', 'allowable_set'
+ ]
+
+ for field in override_fields:
+ cli_value = cli_args.get(field)
+ if cli_value is not None:
+ # Map CLI field names to config attribute names
+ config_attr = field
+ if field == 'dataset':
+ config_attr = 'dataset_name'
+
+ setattr(config, config_attr, cli_value)
+
+ return config
+
+ @staticmethod
+ def validate_final_config(config) -> None:
+ """
+ Validate final configuration for consistency.
+
+ Args:
+ config: Configuration object to validate
+
+ Raises:
+ ValidationError: If configuration is invalid
+ """
+ validator = ArgumentValidator()
+
+ # Validate required fields
+ if not config.dataset_name:
+ raise ValidationError("Dataset name is required")
+
+ if not config.split:
+ raise ValidationError("Split is required")
+
+ # Validate individual fields
+ validator.validate_dataset_name(config.dataset_name)
+ validator.validate_split_format(config.split)
+
+ if config.allowable_set:
+ validator.validate_coco_objects(config.allowable_set)
+
+ validator.validate_hub_config(config.upload_to_hub, config.hub_repo_id)
+
+ # Validate numeric fields
+ if config.batch_size <= 0:
+ raise ValidationError("batch_size must be positive")
+
+ if config.num_workers < 0:
+ raise ValidationError("num_workers must be non-negative")
+
+ if config.qa_workers <= 0:
+ raise ValidationError("qa_workers must be positive")
+
+
+class ErrorHandler:
+ """Standardized error handling for CLI commands."""
+
+ @staticmethod
+ def handle_validation_error(error: ValidationError) -> None:
+ """
+ Handle validation errors with appropriate user messaging.
+
+ Args:
+ error: The validation error to handle
+ """
+ typer.secho(f"β Validation Error: {error}", fg=typer.colors.RED)
+ typer.echo("Use --help for usage information.")
+ raise typer.Exit(1)
+
+ @staticmethod
+ def handle_configuration_error(error: ConfigurationError) -> None:
+ """
+ Handle configuration errors.
+
+ Args:
+ error: The configuration error to handle
+ """
+ typer.secho(f"β Configuration Error: {error}", fg=typer.colors.RED)
+ typer.echo("Check your configuration file and try again.")
+ raise typer.Exit(1)
+
+ @staticmethod
+ def handle_processing_error(error: ProcessingError) -> None:
+ """
+ Handle processing errors with debugging information.
+
+ Args:
+ error: The processing error to handle
+ """
+ typer.secho(f"β Processing Error: {error}", fg=typer.colors.RED)
+
+ # Show traceback in debug mode
+ if os.getenv("GRAID_DEBUG_VERBOSE"):
+ import traceback
+ typer.echo("\nDetailed traceback:")
+ typer.echo(traceback.format_exc())
+ else:
+ typer.echo("Use GRAID_DEBUG_VERBOSE=1 for detailed error information.")
+
+ raise typer.Exit(1)
+
+ @staticmethod
+ def handle_unexpected_error(error: Exception) -> None:
+ """
+ Handle unexpected errors.
+
+ Args:
+ error: The unexpected error to handle
+ """
+ typer.secho(f"β Unexpected Error: {error}", fg=typer.colors.RED)
+ typer.echo("This is likely a bug. Please report it with the following information:")
+
+ import traceback
+ typer.echo("\nFull traceback:")
+ typer.echo(traceback.format_exc())
+
+ raise typer.Exit(1)
+
+
+class DatasetProcessor:
+ """Handles the actual dataset generation workflow."""
+
+ @staticmethod
+ def process_single_split(config) -> Any:
+ """
+ Process single split dataset generation.
+
+ Args:
+ config: Configuration object
+
+ Returns:
+ Generated DatasetDict
+
+ Raises:
+ ProcessingError: If processing fails
+ """
+ try:
+ # Local import for dataset generation
+ from graid.data.generate_dataset import generate_dataset
+
+ # Normalize split
+ splits = ArgumentValidator.validate_split_format(config.split)
+ if len(splits) != 1:
+ raise ProcessingError("Single split processing called with multiple splits")
+
+ result = generate_dataset(
+ dataset_name=config.dataset_name,
+ split=splits[0],
+ models=getattr(config, 'models', []),
+ use_wbf=getattr(config, 'use_wbf', False),
+ wbf_config=getattr(config, 'wbf_config', None),
+ conf_threshold=getattr(config, 'confidence_threshold', 0.0),
+ batch_size=getattr(config, 'batch_size', 32),
+ device=getattr(config, 'device', None),
+ allowable_set=getattr(config, 'allowable_set', None),
+ question_configs=getattr(config, 'question_configs', []),
+ num_workers=getattr(config, 'num_workers', 4),
+ qa_workers=getattr(config, 'qa_workers', 4),
+ save_path=getattr(config, 'save_path', "./graid-datasets"),
+ upload_to_hub=getattr(config, 'upload_to_hub', False),
+ hub_repo_id=getattr(config, 'hub_repo_id', None),
+ hub_private=getattr(config, 'hub_private', False),
+ num_samples=getattr(config, 'num_samples', None),
+ use_original_filenames=getattr(config, 'use_original_filenames', True),
+ filename_prefix=getattr(config, 'filename_prefix', 'img'),
+ )
+
+ return result
+
+ except Exception as e:
+ if isinstance(e, ProcessingError):
+ raise
+ raise ProcessingError(f"Single split processing failed: {e}")
+
+ @staticmethod
+ def process_multiple_splits(config) -> Any:
+ """
+ Process multi-split dataset generation with combined upload.
+
+ Args:
+ config: Configuration object
+
+ Returns:
+ Combined DatasetDict
+
+ Raises:
+ ProcessingError: If processing fails
+ """
+ try:
+ # Local imports
+ from datasets import DatasetDict
+
+ splits = ArgumentValidator.validate_split_format(config.split)
+ if len(splits) <= 1:
+ raise ProcessingError("Multiple split processing called with single split")
+
+ combined_dict = DatasetDict()
+
+ # Aggregate statistics across all splits
+ aggregated_question_counts = {}
+ aggregated_detailed_stats = {}
+
+ # Process each split separately
+ for split_name in splits:
+ logger.info(f"Processing split: {split_name}")
+
+ # Create temporary config for this split
+ config_dict = config.to_dict()
+ config_dict['split'] = split_name
+ config_dict['upload_to_hub'] = False # Don't upload individual splits
+
+ split_config = config.__class__(**config_dict)
+
+ split_result, split_stats = DatasetProcessor.process_single_split(split_config)
+
+ # Add to combined dict
+ for key, dataset in split_result.items():
+ combined_dict[key] = dataset
+
+ # Aggregate question statistics
+ if split_stats:
+ for qtype, count in split_stats.get('question_counts', {}).items():
+ aggregated_question_counts[qtype] = aggregated_question_counts.get(qtype, 0) + count
+
+ # Aggregate detailed stats (timings, etc.)
+ for qtype, stats in split_stats.get('detailed_stats', {}).items():
+ if qtype not in aggregated_detailed_stats:
+ aggregated_detailed_stats[qtype] = {
+ "is_applicable_time": (0.0, 0),
+ "is_applicable_true_count": 0,
+ "apply_time": (0.0, 0),
+ "apply_empty_results": 0,
+ "total_qa_generated": 0,
+ "question_text": stats.get("question_text", qtype)
+ }
+
+ # Aggregate timing data
+ agg_stats = aggregated_detailed_stats[qtype]
+
+ # Add is_applicable times
+ is_app_time, is_app_count = agg_stats["is_applicable_time"]
+ split_is_app_time, split_is_app_count = stats.get("is_applicable_time", (0.0, 0))
+ agg_stats["is_applicable_time"] = (is_app_time + split_is_app_time, is_app_count + split_is_app_count)
+
+ # Add apply times
+ apply_time, apply_count = agg_stats["apply_time"]
+ split_apply_time, split_apply_count = stats.get("apply_time", (0.0, 0))
+ agg_stats["apply_time"] = (apply_time + split_apply_time, apply_count + split_apply_count)
+
+ # Add other counters
+ agg_stats["is_applicable_true_count"] += stats.get("is_applicable_true_count", 0)
+ agg_stats["apply_empty_results"] += stats.get("apply_empty_results", 0)
+ agg_stats["total_qa_generated"] += stats.get("total_qa_generated", 0)
+
+ # Prepare aggregated statistics for README
+ question_stats = {
+ 'question_counts': aggregated_question_counts,
+ 'detailed_stats': aggregated_detailed_stats
+ } if aggregated_question_counts else None
+
+ # Log aggregated profiling statistics for multi-split processing
+ if question_stats and 'detailed_stats' in question_stats:
+ from graid.utils.profiling import log_profiling_statistics
+ log_profiling_statistics(question_stats, "Multi-Split Aggregated Question Processing Statistics")
+ logger.info("Notes: Aggregated across all processed splits")
+
+ # Handle combined upload if requested
+ if config.upload_to_hub and config.hub_repo_id:
+ DatasetProcessor.handle_hub_upload(combined_dict, config, question_stats)
+
+ # Clean up image files after successful multi-split upload
+ # DatasetProcessor._cleanup_multi_split_images(config.save_path)
+
+ return combined_dict
+
+ except Exception as e:
+ if isinstance(e, ProcessingError):
+ raise
+ raise ProcessingError(f"Multiple split processing failed: {e}")
+
+ @staticmethod
+ def _cleanup_multi_split_images(save_path: str) -> None:
+ """Clean up temporary image files after multi-split processing."""
+ try:
+ import shutil
+ from pathlib import Path
+
+ save_path_obj = Path(save_path)
+
+ # Clean up all split image directories
+ for split_dir in save_path_obj.iterdir():
+ if split_dir.is_dir():
+ images_dir = split_dir / "images"
+ if images_dir.exists():
+ logger.debug(f"Cleaning up images directory: {images_dir}")
+ shutil.rmtree(images_dir)
+ logger.debug(f"β
Cleaned up {images_dir}")
+
+ logger.info("β
Multi-split image cleanup completed successfully")
+
+ except Exception as e:
+ logger.warning(f"Failed to cleanup multi-split image files: {e}")
+
+ @staticmethod
+ def handle_hub_upload(dataset_dict: Any, config, question_stats=None) -> None:
+ """
+ Handle HuggingFace Hub upload workflow with comprehensive README.
+
+ Args:
+ dataset_dict: DatasetDict to upload
+ config: Configuration object
+ question_stats: Optional aggregated question statistics
+
+ Raises:
+ ProcessingError: If upload fails
+ """
+ try:
+ # Local import for Hub utilities
+ from huggingface_hub import create_repo, upload_file
+ from pathlib import Path
+
+ logger.info(f"Uploading to HuggingFace Hub: {config.hub_repo_id}")
+
+ # Create repository
+ create_repo(
+ config.hub_repo_id,
+ repo_type="dataset",
+ private=config.hub_private,
+ exist_ok=True
+ )
+
+ # Generate and upload README if statistics are available
+ if question_stats:
+ readme_content = DatasetProcessor._create_dataset_readme(
+ dataset_dict, config, question_stats
+ )
+
+ # Write README to temporary file
+ readme_path = Path("README.md")
+ readme_path.write_text(readme_content, encoding='utf-8')
+
+ logger.info("π Generated comprehensive README with question statistics")
+
+ # Upload README first
+ upload_file(
+ path_or_fileobj=str(readme_path),
+ path_in_repo="README.md",
+ repo_id=config.hub_repo_id,
+ repo_type="dataset",
+ commit_message="Add comprehensive README with GRAID statistics"
+ )
+
+ # Clean up temporary README file
+ readme_path.unlink()
+ logger.info("π README uploaded successfully")
+
+ # Push dataset with large shard size to avoid 10k file limit
+ dataset_dict.push_to_hub(
+ repo_id=config.hub_repo_id,
+ private=config.hub_private,
+ # embed_external_files=False,
+ commit_message=f"Upload {config.dataset_name} dataset",
+ max_shard_size="5GB", # Large shards to minimize file count
+ )
+
+ logger.info(f"Successfully uploaded to Hub: {config.hub_repo_id}")
+
+ except Exception as e:
+ raise ProcessingError(f"Hub upload failed: {e}")
+
+ @staticmethod
+ def _create_dataset_readme(dataset_dict, config, question_stats):
+ """
+ Generate comprehensive README content for HuggingFace Hub.
+
+ Args:
+ dataset_dict: DatasetDict with the generated dataset
+ config: Configuration object
+ question_stats: Dictionary with aggregated statistics
+
+ Returns:
+ str: Complete README content in markdown format
+ """
+ from datetime import datetime
+
+ # Calculate total QA pairs across all splits
+ total_qa_pairs = sum(len(dataset_dict[split]) for split in dataset_dict.keys())
+
+ # Dataset-specific configuration
+ dataset_configs = {
+ "bdd": {
+ "full_name": "BDD100K",
+ "description": "Berkeley DeepDrive autonomous driving dataset",
+ "license": "bsd-3-clause",
+ "tags": ["autonomous-driving", "bdd100k"],
+ "source_info": "BDD100K (Berkeley DeepDrive)",
+ "citation": """@inproceedings{bdd100k,
+ title={BDD100K: A Diverse Driving Dataset for Heterogeneous Multitask Learning},
+ author={Yu, Fisher and Chen, Haofeng and Wang, Xin and Xian, Wenqi and Chen, Yingying and Liu, Fangchen and Madhavan, Vashisht and Darrell, Trevor},
+ booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year={2020}
+}""",
+ "license_text": "This dataset is derived from the BDD100K dataset. Please refer to the [BDD100K license terms](https://github.com/bdd100k/bdd100k) for usage restrictions."
+ },
+ "waymo": {
+ "full_name": "Waymo Open Dataset",
+ "description": "Waymo autonomous driving dataset",
+ "license": "other",
+ "tags": ["autonomous-driving", "waymo"],
+ "source_info": "Waymo Open Dataset",
+ "citation": """@inproceedings{waymo,
+ title={Scalability in Perception for Autonomous Driving: Waymo Open Dataset},
+ author={Sun, Pei and Kretzschmar, Henrik and Dotiwalla, Xerxes and Chouard, Aurelien and Patnaik, Vijaysai and Tsui, Paul and Guo, James and Zhou, Yin and Chai, Yuning and Caine, Benjamin and others},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={2446--2454},
+ year={2020}
+}""",
+ "license_text": "This dataset is derived from the Waymo Open Dataset. Please refer to the [Waymo Open Dataset license terms](https://waymo.com/open/terms/) for usage restrictions."
+ },
+ "nuimage": {
+ "full_name": "NuImages",
+ "description": "Large-scale autonomous driving dataset from nuTonomy",
+ "license": "other",
+ "tags": ["autonomous-driving", "nuimages"],
+ "source_info": "NuImages Dataset",
+ "citation": """@article{nuimages,
+ title={nuImages: A large-scale multimodal dataset for autonomous driving},
+ author={Caesar, Holger and Kabzan, Juraj and Tan, Kok Seang and Fong, Whye Kit and Wolff, Eric and Lang, Alex and Fletcher, Luke and Beijbom, Oscar and Omari, Sammy},
+ journal={arXiv preprint arXiv:2008.09969},
+ year={2020}
+}""",
+ "license_text": "This dataset is derived from the nuImages dataset. Please refer to the [nuImages license terms](https://www.nuscenes.org/terms-of-use) for usage restrictions."
+ },
+ "custom": {
+ "full_name": "Custom Dataset",
+ "description": "User-provided custom dataset",
+ "license": "other",
+ "tags": ["custom-dataset"],
+ "source_info": "Custom user dataset",
+ "citation": """@dataset{graid_custom,
+ title={GRAID Custom Question-Answer Dataset},
+ author={GRAID Framework},
+ year={2024},
+ note={Generated using GRAID framework with custom dataset}
+}""",
+ "license_text": "Please refer to your original dataset's license terms for usage restrictions."
+ }
+ }
+
+ # Get dataset config, default to custom if not found
+ dataset_name = config.dataset_name.lower()
+ dataset_config = dataset_configs.get(dataset_name, dataset_configs["custom"])
+
+ readme_content = f"""---
+pretty_name: "GRAID {dataset_config['full_name']} Question-Answer Dataset"
+language:
+- en
+license: "{dataset_config['license']}"
+task_categories:
+- visual-question-answering
+- object-detection
+tags:
+- visual-reasoning
+- spatial-reasoning
+- object-detection
+- computer-vision"""
+
+ # Add dataset-specific tags
+ for tag in dataset_config['tags']:
+ readme_content += f"\n- {tag}"
+
+ readme_content += f"""
+---
+
+# GRAID {dataset_config['full_name']} Question-Answer Dataset
+
+## Overview
+
+This dataset was generated using **GRAID** (**G**enerating **R**easoning questions from **A**nalysis of **I**mages via **D**iscriminative artificial intelligence), a framework for creating spatial reasoning datasets from object detection annotations.
+
+**GRAID** transforms raw object detection data into structured question-answer pairs that test various aspects of object localization, visual reasoning, spatial reasoning, and object relationship comprehension.
+
+## Dataset Details
+
+"""
+
+ # Add basic dataset information
+ readme_content += f"""- **Total QA Pairs**: {total_qa_pairs:,}
+- **Source Dataset**: {dataset_config['source_info']}
+- **Generation Date**: {datetime.now().strftime('%Y-%m-%d')}
+- **Image Format**: Embedded in parquet files (no separate image files)
+- **Question Types**: {len(question_stats.get('question_counts', {})) if question_stats else 'Multiple'} different reasoning patterns
+
+## Dataset Splits
+
+"""
+
+ # Add split information
+ for split_name in sorted(dataset_dict.keys()):
+ split_size = len(dataset_dict[split_name])
+ percentage = (split_size / total_qa_pairs) * 100
+ readme_content += f"- **{split_name}**: {split_size:,} ({percentage:.1f}%)\n"
+
+ readme_content += "\n## Question Type Distribution\n\n"
+
+ # Add question statistics across all splits
+ if question_stats and 'question_counts' in question_stats:
+ total_questions = sum(question_stats['question_counts'].values())
+ sorted_counts = sorted(question_stats['question_counts'].items(), key=lambda x: x[1], reverse=True)
+
+ for qtype, count in sorted_counts:
+ percentage = (count / total_questions) * 100
+ # Get full question text if available
+ question_text = qtype
+ if 'detailed_stats' in question_stats and qtype in question_stats['detailed_stats']:
+ question_text = question_stats['detailed_stats'][qtype].get("question_text", qtype)
+
+ readme_content += f"- **{question_text}**: {count:,} ({percentage:.1f}%)\n"
+
+ # Add performance profiling information if available
+ if question_stats and 'detailed_stats' in question_stats:
+ from graid.utils.profiling import format_profiling_table, format_profiling_notes
+ readme_content += "\n## Performance Analysis\n\n"
+ readme_content += "### Question Processing Efficiency\n\n"
+ readme_content += format_profiling_table(question_stats, "markdown")
+ readme_content += format_profiling_notes("markdown")
+
+ # Add usage information
+ readme_content += f"""
+## Usage
+
+```python
+from datasets import load_dataset
+
+# Load the complete dataset
+dataset = load_dataset("{config.hub_repo_id}")
+
+# Access individual splits"""
+
+ for split_name in sorted(dataset_dict.keys()):
+ readme_content += f"\n{split_name}_data = dataset[\"{split_name}\"]"
+
+ readme_content += f"""
+
+# Example of accessing a sample
+sample = dataset["train"][0] # or "val"
+print(f"Question: {{sample['question']}}")
+print(f"Answer: {{sample['answer']}}")
+print(f"Question Type: {{sample['question_type']}}")
+
+# The image is embedded as a PIL Image object
+image = sample["image"]
+image.show() # Display the image
+```
+
+## Dataset Schema
+
+- **image**: PIL Image object (embedded, no separate files)
+- **annotations**: COCO-style bounding box annotations
+- **question**: Generated question text
+- **answer**: Corresponding answer text
+- **reasoning**: Additional reasoning information (if applicable)
+- **question_type**: Type of question (e.g., "HowMany", "LeftOf", "Quadrants")
+- **source_id**: Original image identifier from {dataset_config['source_info']}
+
+## License
+
+"""
+
+ readme_content += f"""{dataset_config['license_text']} The GRAID-generated questions and metadata are provided under the same terms.
+
+## Citation
+
+If you use this dataset in your research, please cite both the original dataset and the GRAID framework:
+
+```bibtex
+@dataset{{graid_{dataset_name},
+ title={{GRAID {dataset_config['full_name']} Question-Answer Dataset}},
+ author={{GRAID Framework}},
+ year={{2024}},
+ note={{Generated using GRAID: Generating Reasoning questions from Analysis of Images via Discriminative artificial intelligence}}
+}}
+
+{dataset_config['citation']}
+```
+
+## Contact
+
+For questions about this dataset or the GRAID framework, please open an issue in the repository.
+"""
+
+ return readme_content
+
+
+def safe_mkdir(path: Union[str, Path], description: str = "directory") -> Path:
+ """
+ Safely create directory with proper error handling.
+
+ Args:
+ path: Directory path to create
+ description: Description for error messages
+
+ Returns:
+ Created Path object
+
+ Raises:
+ ProcessingError: If directory creation fails
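+
+ Example (illustrative; the path and description are placeholders):
+
+ out_dir = safe_mkdir("./output/images", description="image output directory")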
+ """
+ try:
+ path_obj = Path(path)
+ path_obj.mkdir(parents=True, exist_ok=True)
+ return path_obj
+ except PermissionError:
+ raise ProcessingError(f"Permission denied creating {description}: {path}")
+ except OSError as e:
+ raise ProcessingError(f"Failed to create {description} '{path}': {e}")
+ except Exception as e:
+ raise ProcessingError(f"Unexpected error creating {description} '{path}': {e}")
diff --git a/graid/src/graid/data/Datasets.py b/graid/src/graid/data/Datasets.py
index aa54184..0c7025e 100644
--- a/graid/src/graid/data/Datasets.py
+++ b/graid/src/graid/data/Datasets.py
@@ -9,18 +9,15 @@
import numpy as np
import torch
from PIL import Image
-from graid.data.ImageLoader import (
- Bdd100kDataset,
- NuImagesDataset,
- WaymoDataset,
-)
-from graid.interfaces.ObjectDetectionI import ObjectDetectionModelI
-from graid.questions.ObjectDetectionQ import ALL_QUESTIONS
-from graid.utilities.common import project_root_dir
from sqlitedict import SqliteDict
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
+from graid.data.ImageLoader import Bdd100kDataset, NuImagesDataset, WaymoDataset
+from graid.interfaces.ObjectDetectionI import ObjectDetectionModelI
+from graid.questions.ObjectDetectionQ import ALL_QUESTIONS
+from graid.utilities.common import project_root_dir
+
lock = threading.Lock()
diff --git a/graid/src/graid/data/ImageLoader.py b/graid/src/graid/data/ImageLoader.py
index 12a8f2f..f9a16c6 100644
--- a/graid/src/graid/data/ImageLoader.py
+++ b/graid/src/graid/data/ImageLoader.py
@@ -6,7 +6,7 @@
import pickle
import re
from datetime import datetime, time, timezone
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Literal, Optional, Union
import numpy as np
import pandas as pd
@@ -39,7 +39,7 @@ def __init__(
target_transform: Union[Callable, None] = None,
merge_transform: Union[Callable, None] = None,
use_extended_annotations: bool = False,
- img_labels: Optional[List[Dict]] = None,
+ img_labels: Optional[list[dict]] = None,
):
self.img_dir = img_dir
self.transform = transform
@@ -54,7 +54,7 @@ def __init__(
if annotations_file:
self.img_labels = self.load_annotations(annotations_file)
- def load_annotations(self, annotations_file: str) -> List[Dict]:
+ def load_annotations(self, annotations_file: str) -> list[dict]:
"""Load annotations from a JSON file."""
with open(annotations_file, "r") as file:
return json.load(file)
@@ -143,36 +143,14 @@ def __init__(
img_dir = root_dir / "images" / "10k" / split
rle = root_dir / "labels" / "ins_seg" / "rles" / f"ins_seg_{split}.json"
- def merge_transform(image: Tensor, labels, timestamp):
- results = []
- attributes = []
-
- for instance_id, label in enumerate(labels):
- rle = label["rle"]
- mask = cocomask.decode(rle)
- class_label = label["category"]
- class_id = self.category_to_coco_cls(class_label)
- result = InstanceSegmentationResultI(
- score=1.0,
- cls=int(class_id),
- label=class_label,
- instance_id=int(instance_id),
- image_hw=rle["size"],
- mask=torch.from_numpy(mask).unsqueeze(0),
- )
- results.append(result)
- attributes.append(label["attributes"])
-
- return image, results, timestamp
-
super().__init__(
img_dir=str(img_dir),
annotations_file=rle,
- merge_transform=merge_transform,
+ merge_transform=self.merge_transform,
**kwargs,
)
- def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
+ def __getitem__(self, idx: int) -> Union[Any, tuple[Tensor, dict, dict, str]]:
data = self.img_labels["frames"][idx]
img_path = os.path.join(self.img_dir, data["name"])
labels = data["labels"]
@@ -190,7 +168,8 @@ def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
if self.target_transform:
labels = self.target_transform(labels)
if self.merge_transform:
- image, labels, timestamp = self.merge_transform(image, labels, timestamp)
+ image, labels, timestamp = self.merge_transform(
+ image, labels, timestamp)
return {
"name": data["name"],
@@ -200,6 +179,28 @@ def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
"timestamp": timestamp,
}
+ def merge_transform(self, image: Tensor, labels, timestamp):
+ results = []
+ attributes = []
+
+ for instance_id, label in enumerate(labels):
+ rle = label["rle"]
+ mask = cocomask.decode(rle)
+ class_label = label["category"]
+ class_id = self.category_to_coco_cls(class_label)
+ result = InstanceSegmentationResultI(
+ score=1.0,
+ cls=int(class_id),
+ label=class_label,
+ instance_id=int(instance_id),
+ image_hw=rle["size"],
+ mask=torch.from_numpy(mask).unsqueeze(0),
+ )
+ results.append(result)
+ attributes.append(label["attributes"])
+
+ return image, results, timestamp
+
class Bdd100kDataset(ImageDataset):
"""
@@ -324,58 +325,7 @@ def __init__(
)
self.img_labels = self.load_annotations(annotations_file)
-
- def merge_transform(
- image: Tensor,
- labels: List[Dict[str, Any]],
- timestamp: str,
- ) -> Union[
- Tuple[
- Tensor, List[Union[ObjectDetectionResultI, InstanceSegmentationResultI]]
- ],
- Tuple[
- Tensor,
- List[
- Tuple[
- Union[ObjectDetectionResultI, InstanceSegmentationResultI],
- Dict[str, Any],
- str,
- ]
- ],
- Dict[str, Any],
- str,
- ],
- ]:
- results = []
-
- for label in labels:
- channels, height, width = image.shape
- if use_original_categories:
- cls = self.category_to_cls(label["category"])
- res_label = label["category"]
- else:
- cls = self.category_to_coco_cls(label["category"])
- # handle the case where exact category is not in COCO aka different names for people
- res_label = label["category"] if cls != 0 else "person"
-
- result = ObjectDetectionResultI(
- score=1.0,
- cls=cls,
- label=res_label,
- bbox=[
- label["box2d"]["x1"],
- label["box2d"]["y1"],
- label["box2d"]["x2"],
- label["box2d"]["y2"],
- ],
- image_hw=(height, width),
- bbox_format=BBox_Format.XYXY,
- attributes=[label["attributes"]],
- )
-
- results.append(result)
-
- return image, results, timestamp
+ self.use_original_categories = use_original_categories
# finally, filter out following labels
# 'other person', 'other vehicle' and 'trail'
@@ -443,10 +393,39 @@ def merge_transform(
# No mapping file and not rebuilding, so use empty mapping
self.filtered_to_orig_mapping = {}
+ # Ensure per-image pickle files exist for this split. If missing, build them.
+ pkl_root = project_root_dir() / "data" / f"bdd_{self.split}"
+ try:
+ pkl_root.mkdir(parents=True, exist_ok=True)
+ os.chmod(pkl_root, 0o777)
+ except Exception:
+ pass
+ need_build = True
+ # Quick existence check for index 0
+ if (pkl_root / "0.pkl").exists():
+ need_build = False
+ if need_build:
+ print(f"Building per-image pickle cache for BDD100K {self.split}...")
+ for idx, label in tqdm(enumerate(self.img_labels), total=len(self.img_labels), desc=f"Indexing BDD100K {self.split}..."):
+ # respect filtering flag when deciding to include
+ if self.use_time_filtered and (not self._meets_filtering_criteria(label)):
+ continue
+ name = label.get("name")
+ timestamp = label.get("timestamp", 0)
+ labels = label.get("labels", [])
+ save_path = pkl_root / f"{idx}.pkl"
+ try:
+ with open(save_path, "wb") as f:
+ pickle.dump({"name": name, "labels": labels, "timestamp": timestamp}, f)
+ os.chmod(save_path, 0o777)
+ except Exception:
+ # best-effort; skip on failure
+ continue
+
super().__init__(
annotations_file=str(annotations_file),
img_dir=str(img_dir),
- merge_transform=merge_transform,
+ merge_transform=self.merge_transform,
use_extended_annotations=use_extended_annotations,
**kwargs,
)
@@ -463,18 +442,20 @@ def __len__(self) -> int:
# Fallback to original dataset size if no mapping exists
return len(self.img_labels)
- def _meets_filtering_criteria(self, label: Dict[str, Any]) -> bool:
+ def _meets_filtering_criteria(self, label: dict[str, Any]) -> bool:
"""
Check if an image meets the filtering criteria:
- timeofday must be 'daytime'
- weather must not be 'foggy', 'snowy', or 'rainy'
+ Object category filtering is always applied separately.
- Args:
- label: Image label dictionary containing attributes
-
- Returns:
- True if the image meets filtering criteria, False otherwise
+ When self.use_time_filtered is False, we should skip the time & weather
+ checks entirely and keep all images.
"""
+ # If we're not applying time/weather filtering, accept all images.
+ if not self.use_time_filtered:
+ return True
+
if "attributes" not in label:
return False
@@ -490,7 +471,7 @@ def _meets_filtering_criteria(self, label: Dict[str, Any]) -> bool:
return True
- def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
+ def __getitem__(self, idx: int) -> Union[Any, tuple[Tensor, dict, dict, str]]:
# If we're using filtered dataset and have a mapping
if self.use_time_filtered and hasattr(self, "filtered_to_orig_mapping"):
if not self.filtered_to_orig_mapping and os.path.exists(self.mapping_file):
@@ -513,7 +494,11 @@ def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
# Use original dataset if not filtered
orig_path = project_root_dir() / "data" / f"bdd_{self.split}" / f"{idx}.pkl"
- # Load the data from the original file
+ if not orig_path.exists():
+ raise FileNotFoundError(
+ f"Pickle file {orig_path} not found. Set rebuild=True to generate it."
+ )
+
with open(orig_path, "rb") as f:
data = pickle.load(f)
@@ -550,7 +535,8 @@ def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
if self.target_transform:
labels = self.target_transform(labels)
if self.merge_transform:
- image, labels, timestamp = self.merge_transform(image, labels, timestamp)
+ image, labels, timestamp = self.merge_transform(
+ image, labels, timestamp)
return {
"name": data["name"],
@@ -560,6 +546,61 @@ def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
"timestamp": timestamp,
}
+ def merge_transform(
+ self,
+ image: Tensor,
+ labels: list[dict[str, Any]],
+ timestamp: str,
+ ) -> Union[
+ tuple[
+ Tensor, list[Union[ObjectDetectionResultI,
+ InstanceSegmentationResultI]]
+ ],
+ tuple[
+ Tensor,
+ list[
+ tuple[
+ Union[ObjectDetectionResultI,
+ InstanceSegmentationResultI],
+ dict[str, Any],
+ str,
+ ]
+ ],
+ dict[str, Any],
+ str,
+ ],
+ ]:
+ results = []
+
+ for label in labels:
+ channels, height, width = image.shape
+ if self.use_original_categories:
+ cls = self.category_to_cls(label["category"])
+ res_label = label["category"]
+ else:
+ cls = self.category_to_coco_cls(label["category"])
+ # handle categories whose exact name is not in COCO (person-like classes map to "person")
+ res_label = label["category"] if cls != 0 else "person"
+
+ result = ObjectDetectionResultI(
+ score=1.0,
+ cls=cls,
+ label=res_label,
+ bbox=[
+ label["box2d"]["x1"],
+ label["box2d"]["y1"],
+ label["box2d"]["x2"],
+ label["box2d"]["y2"],
+ ],
+ image_hw=(height, width),
+ bbox_format=BBox_Format.XYXY,
+ attributes=[label["attributes"]],
+ )
+
+ results.append(result)
+
+ return image, results, timestamp
+
class NuImagesDataset(ImageDataset):
"""
@@ -692,8 +733,8 @@ def category_to_coco(self, category: str):
return self._CATEGORIES_TO_COCO[category]
def filter_by_token(
- self, data: List[Dict[str, Any]], field: str, match_value: str
- ) -> List[Dict[str, Any]]:
+ self, data: list[dict[str, Any]], field: str, match_value: str
+ ) -> list[dict[str, Any]]:
filtered_list = []
for item in data:
if item.get(field) == match_value:
@@ -774,15 +815,9 @@ def __init__(
self.category_labels = json.load(open(categories_file))
self.obj_annotations = json.load(open(obj_annotations_file))
- if rebuild:
- from nuimages import NuImages
-
- self.nuim = NuImages(
- dataroot=img_dir,
- version=subdir,
- verbose=False,
- lazy=True, # verbose off to avoid excessive print statement
- )
+ # Auto-generate per-image pickle cache when missing or when rebuild is requested
+ if True:  # always enter; need_build below decides whether the cache is rebuilt
+ # If cache already exists, we can skip heavy work unless rebuild=True
save_path_parent = (
project_root_dir()
/ "data"
@@ -798,106 +833,81 @@ def __init__(
except Exception as e:
logger.warning(f"Failed to set permissions on {save_path_parent}: {e}")
- empty_count = 0
- idx = 0
- for i in tqdm(
- range(len(self.nuim.sample)),
- desc="Processing NuImage dataset...", # len(self.nuim.sample)
- ):
- # see: https://www.nuscenes.org/tutorials/nuimages_tutorial.html
- sample = self.nuim.sample[i]
- sample_token = sample["token"]
- key_camera_token = sample["key_camera_token"]
- object_tokens, surface_tokens = self.nuim.list_anns(
- sample_token, verbose=False
- ) # verbose off to avoid excessive print statement
- if object_tokens == []:
- empty_count += 1
- continue
+ need_build = rebuild or not (save_path_parent / "0.pkl").exists()
+ if need_build:
+ from nuimages import NuImages
- object_data = []
- for object_token in object_tokens:
- obj = self.nuim.get("object_ann", object_token)
- category_token = obj["category_token"]
- attribute_tokens = obj["attribute_tokens"]
- attributes = []
- for attribute_token in attribute_tokens:
- attribute = self.nuim.get("attribute", attribute_token)
- attributes.append(attribute)
-
- category = self.nuim.get("category", category_token)["name"]
- obj["category"] = category
- obj["attributes"] = attributes
- object_data.append(obj)
-
- sample_data = self.nuim.get("sample_data", key_camera_token)
- img_filename = sample_data["filename"]
- timestamp = sample_data["timestamp"]
-
- save_path = save_path_parent / f"{idx}.pkl"
- print("creating idx... ", idx)
- # if save_path.exists():
- # continue
- if self.use_time_filtered and not self.is_time_in_working_hours(
- img_filename
- ):
- print("invalid")
- continue
+ self.nuim = NuImages(
+ dataroot=img_dir,
+ version=subdir,
+ verbose=False,
+ lazy=True,
+ )
- with open(save_path, "wb") as f:
- pickle.dump(
- {
- "filename": img_filename,
- "labels": object_data,
- "timestamp": timestamp,
- },
- f,
+ empty_count = 0
+ idx = 0
+ for i in tqdm(
+ range(len(self.nuim.sample)),
+ desc=f"Processing NuImages {self.split}...",
+ ):
+ sample = self.nuim.sample[i]
+ sample_token = sample["token"]
+ key_camera_token = sample["key_camera_token"]
+ object_tokens, surface_tokens = self.nuim.list_anns(
+ sample_token, verbose=False
)
- os.chmod(save_path, 0o777)
- idx += 1
-
- # TODO: add error catching logic in case of empty token or token mismatch.
+ if not object_tokens:
+ empty_count += 1
+ continue
- print(
- f"{split} has {empty_count} out of {len(self.nuim.sample)} empty samples."
- )
+ object_data = []
+ for object_token in object_tokens:
+ obj = self.nuim.get("object_ann", object_token)
+ category_token = obj["category_token"]
+ attribute_tokens = obj["attribute_tokens"]
+ attributes = []
+ for attribute_token in attribute_tokens:
+ attribute = self.nuim.get("attribute", attribute_token)
+ attributes.append(attribute)
+ category = self.nuim.get("category", category_token)["name"]
+ obj["category"] = category
+ obj["attributes"] = attributes
+ object_data.append(obj)
+
+ sample_data = self.nuim.get("sample_data", key_camera_token)
+ img_filename = sample_data["filename"]
+ timestamp = sample_data["timestamp"]
+
+ # Apply time filtering if enabled
+ if self.use_time_filtered and not self.is_time_in_working_hours(
+ img_filename
+ ):
+ continue
- def merge_transform(
- image: Tensor, labels: List[Dict[str, Any]], timestamp: str
- ) -> Tuple[
- Tensor,
- List[Tuple[ObjectDetectionResultI, Dict[str, Any], str]],
- List[Dict[str, Any]],
- str,
- ]:
- results = []
- attributes = []
+ save_path = save_path_parent / f"{idx}.pkl"
+ try:
+ with open(save_path, "wb") as f:
+ pickle.dump(
+ {
+ "filename": img_filename,
+ "labels": object_data,
+ "timestamp": timestamp,
+ },
+ f,
+ )
+ os.chmod(save_path, 0o777)
+ except Exception:
+ # best-effort; skip on failure
+ pass
+ idx += 1
- for obj_label in labels:
- _, height, width = image.shape
- obj_category = obj_label["category"]
- obj_attributes = obj_label["attributes"]
- label = self.category_to_coco(obj_category)
- cls = self.category_to_cls(label)
-
- results.append(
- ObjectDetectionResultI(
- score=1.0,
- cls=cls,
- label=label,
- bbox=obj_label["bbox"],
- image_hw=(height, width),
- bbox_format=BBox_Format.XYXY,
- attributes=obj_attributes,
- )
+ print(
+ f"{self.split} has {empty_count} out of {len(self.nuim.sample)} empty samples."
)
- attributes.append(obj_attributes)
-
- return (image, results, attributes, timestamp)
super().__init__(
img_dir=img_dir,
- merge_transform=merge_transform,
+ merge_transform=self.merge_transform,
**kwargs,
)
@@ -913,7 +923,7 @@ def __len__(self) -> int:
)
return len(os.listdir(save_path))
- def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
+ def __getitem__(self, idx: int) -> Union[Any, tuple[Tensor, dict, dict, str]]:
# if isinstance(idx, slice):
# img_filename = self.img_labels[idx][0]["filename"]
# labels = self.img_labels[idx][0]["labels"]
@@ -957,7 +967,7 @@ def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
)
return {
- "name": data["name"],
+ "name": img_filename,
"path": img_path,
"image": image,
"labels": labels,
@@ -965,6 +975,42 @@ def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
"timestamp": timestamp,
}
+ def merge_transform(
+ self,
+ image: Tensor,
+ labels: list[dict[str, Any]],
+ timestamp: str
+ ) -> tuple[
+ Tensor,
+ list[tuple[ObjectDetectionResultI, dict[str, Any], str]],
+ list[dict[str, Any]],
+ str,
+ ]:
+ results = []
+ attributes = []
+
+ for obj_label in labels:
+ _, height, width = image.shape
+ obj_category = obj_label["category"]
+ obj_attributes = obj_label["attributes"]
+ label = self.category_to_coco(obj_category)
+ cls = self.category_to_cls(label)
+
+ results.append(
+ ObjectDetectionResultI(
+ score=1.0,
+ cls=cls,
+ label=label,
+ bbox=obj_label["bbox"],
+ image_hw=(height, width),
+ bbox_format=BBox_Format.XYXY,
+ attributes=obj_attributes,
+ )
+ )
+ attributes.append(obj_attributes)
+
+ return (image, results, attributes, timestamp)
+
class NuImagesDataset_seg(ImageDataset):
"""
@@ -1094,8 +1140,8 @@ def category_to_cls(self, category: str) -> int:
return self._CATEGORIES[category]
def filter_by_token(
- self, data: List[Dict[str, Any]], field: str, match_value: str
- ) -> List[Dict[str, Any]]:
+ self, data: list[dict[str, Any]], field: str, match_value: str
+ ) -> list[dict[str, Any]]:
filtered_list = []
for item in data:
if item.get(field) == match_value:
@@ -1115,7 +1161,8 @@ def __init__(
img_dir = root_dir
mask_annotations_file = root_dir / f"v1.0-{split}" / "object_ann.json"
categories_file = root_dir / f"v1.0-{split}" / "category.json"
- sample_data_labels_file = root_dir / f"v1.0-{split}" / "sample_data.json"
+ sample_data_labels_file = root_dir / \
+ f"v1.0-{split}" / "sample_data.json"
attributes_file = root_dir / f"v1.0-{split}" / "attribute.json"
self.nuim = NuImages(
@@ -1165,48 +1212,14 @@ def __init__(
}
)
- def merge_transform(
- image: Tensor, labels: List[Dict[str, Any]], timestamp: str
- ) -> Tuple[
- Tensor,
- List[Tuple[InstanceSegmentationResultI, Dict[str, Any], str]],
- Dict[str, Any],
- str,
- ]:
- results = []
- attributes = []
-
- for instance_id, obj_label in enumerate(labels):
- _, height, width = image.shape
- obj_category = obj_label["category"]
- obj_attributes = obj_label["attributes"]
- new_mask = obj_label["mask"].copy()
- new_mask["counts"] = base64.b64decode(new_mask["counts"])
- mask = cocomask.decode(new_mask)
-
- results.append(
- InstanceSegmentationResultI(
- score=1.0,
- cls=self.category_to_cls(obj_category),
- label=obj_category,
- instance_id=instance_id,
- image_hw=(height, width),
- mask=torch.from_numpy(mask).unsqueeze(0),
- mask_format=Mask_Format.BITMASK,
- )
- )
- attributes.append(obj_attributes)
-
- return (image, results, attributes, timestamp)
-
super().__init__(
img_labels=img_labels,
img_dir=img_dir,
- merge_transform=merge_transform,
+ merge_transform=self.merge_transform,
**kwargs,
)
- def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
+ def __getitem__(self, idx: int) -> Union[Any, tuple[Tensor, dict, dict, str]]:
img_filename = self.img_labels[idx]["filename"]
labels = self.img_labels[idx]["labels"]
timestamp = self.img_labels[idx]["timestamp"]
@@ -1237,6 +1250,43 @@ def __getitem__(self, idx: int) -> Union[Any, Tuple[Tensor, Dict, Dict, str]]:
"timestamp": timestamp,
}
+ def merge_transform(
+ self,
+ image: Tensor,
+ labels: list[dict[str, Any]],
+ timestamp: str
+ ) -> tuple[
+ Tensor,
+ list[tuple[InstanceSegmentationResultI, dict[str, Any], str]],
+ dict[str, Any],
+ str,
+ ]:
+ results = []
+ attributes = []
+
+ for instance_id, obj_label in enumerate(labels):
+ _, height, width = image.shape
+ obj_category = obj_label["category"]
+ obj_attributes = obj_label["attributes"]
+ new_mask = obj_label["mask"].copy()
+ new_mask["counts"] = base64.b64decode(new_mask["counts"])
+ mask = cocomask.decode(new_mask)
+
+ results.append(
+ InstanceSegmentationResultI(
+ score=1.0,
+ cls=self.category_to_cls(obj_category),
+ label=obj_category,
+ instance_id=instance_id,
+ image_hw=(height, width),
+ mask=torch.from_numpy(mask).unsqueeze(0),
+ mask_format=Mask_Format.BITMASK,
+ )
+ )
+ attributes.append(obj_attributes)
+
+ return (image, results, attributes, timestamp)
+
class WaymoDataset(ImageDataset):
"""
@@ -1282,6 +1332,8 @@ class WaymoDataset(ImageDataset):
_CATEGORIES_R = {v: k for k, v in _CATEGORIES.items()}
+ _CLS_TO_CATEGORIES = {str(v): k for k, v in _CATEGORIES.items()}
+
_CLS_TO_COCO_CLS = {
"TYPE_UNKNOWN": "undefined",
"TYPE_VEHICLE": "car",
@@ -1459,34 +1511,11 @@ def __init__(
)
idx += 1
- def merge_transform(image, labels, attributes, timestamp):
- results = []
-
- for label in labels:
- cls = label["type"]
- bbox = label["bbox"]
- class_label = self.cls_to_category(cls)
- cls = self.category_to_cls(class_label)
- label = self.category_to_coco(class_label)
-
- result = ObjectDetectionResultI(
- score=1.0,
- cls=cls,
- label=label,
- bbox=list(bbox),
- image_hw=image.shape,
- attributes=[attributes],
- bbox_format=BBox_Format.XYXY,
- )
- results.append(result)
-
- return (image, results, attributes, timestamp)
-
# Call the parent class constructor (no annotations_file argument)
super().__init__(
annotations_file=None,
img_dir=str(self.camera_img_dir),
- merge_transform=merge_transform,
+ merge_transform=self.merge_transform,
**kwargs,
)
@@ -1502,7 +1531,7 @@ def __len__(self) -> int:
)
return len(os.listdir(save_path))
- def __getitem__(self, idx: int) -> Dict:
+ def __getitem__(self, idx: int) -> dict:
"""Retrieve an image and its annotations."""
if idx >= self.__len__():
raise IndexError(
@@ -1553,6 +1582,29 @@ def __getitem__(self, idx: int) -> Dict:
"timestamp": timestamp,
}
+ def merge_transform(self, image, labels, attributes, timestamp):
+ results = []
+
+ for label in labels:
+ cls = label["type"]
+ bbox = label["bbox"]
+ class_label = self.cls_to_category(cls)
+ cls = self.category_to_cls(class_label)
+ label = self.category_to_coco(class_label)
+
+ result = ObjectDetectionResultI(
+ score=1.0,
+ cls=cls,
+ label=label,
+ bbox=list(bbox),
+ image_hw=image.shape,
+ attributes=[attributes],
+ bbox_format=BBox_Format.XYXY,
+ )
+ results.append(result)
+
+ return (image, results, attributes, timestamp)
+
class WaymoDataset_seg(ImageDataset):
@@ -1723,47 +1775,19 @@ def __init__(
f"No valid data found in {self.camera_img_dir} and {self.camera_box_dir}"
)
- def merge_transform(image, labels, attributes, timestamp):
- masks_bytes = labels[0]["masks"]
- divisor = labels[0]["divisor"]
- instance_id = labels[0]["instance_id"]
- masks = transforms.ToTensor()(Image.open(io.BytesIO(masks_bytes)))
- instance_masks = masks % divisor
- semantic_masks = masks // divisor
-
- results = []
- for i in instance_id:
- semantic_id = self.get_semantic_class(instance_masks, semantic_masks, i)
- if len(semantic_id) == 0:
- # Skip if semantic class could not be determined
- continue
- class_id = int(semantic_id[0])
- instance_mask = instance_masks == i
- result = InstanceSegmentationResultI(
- score=1.0,
- cls=class_id,
- label=self.cls_to_category(class_id),
- instance_id=i,
- image_hw=image.shape,
- mask=instance_mask,
- )
- results.append(result)
-
- return image, results, attributes, timestamp
-
# Call the parent class constructor (no annotations_file argument)
super().__init__(
annotations_file=None,
img_dir=str(self.camera_img_dir),
img_labels=self.img_labels,
- merge_transform=merge_transform,
+ merge_transform=self.merge_transform,
**kwargs,
)
def __len__(self) -> int:
return len(self.img_labels)
- def __getitem__(self, idx: int) -> Dict:
+ def __getitem__(self, idx: int) -> dict:
"""Retrieve an image and its annotations."""
if idx >= len(self.img_labels):
raise IndexError(
@@ -1800,3 +1824,31 @@ def __getitem__(self, idx: int) -> Dict:
"attributes": attributes,
"timestamp": timestamp,
}
+
+ def merge_transform(self, image, labels, attributes, timestamp):
+ masks_bytes = labels[0]["masks"]
+ divisor = labels[0]["divisor"]
+ instance_id = labels[0]["instance_id"]
+ masks = transforms.ToTensor()(Image.open(io.BytesIO(masks_bytes)))
+ instance_masks = masks % divisor
+ semantic_masks = masks // divisor
+
+ results = []
+ for i in instance_id:
+ semantic_id = self.get_semantic_class(instance_masks, semantic_masks, i)
+ if len(semantic_id) == 0:
+ # Skip if semantic class could not be determined
+ continue
+ class_id = int(semantic_id[0])
+ instance_mask = instance_masks == i
+ result = InstanceSegmentationResultI(
+ score=1.0,
+ cls=class_id,
+ label=self.cls_to_category(class_id),
+ instance_id=i,
+ image_hw=image.shape,
+ mask=instance_mask,
+ )
+ results.append(result)
+
+ return image, results, attributes, timestamp
diff --git a/graid/src/graid/data/config_support.py b/graid/src/graid/data/config_support.py
new file mode 100644
index 0000000..aa66b18
--- /dev/null
+++ b/graid/src/graid/data/config_support.py
@@ -0,0 +1,455 @@
+"""
+Configuration File Support for HuggingFace Dataset Generation
+
+This module provides configuration file support for specifying models and WBF settings
+without using CLI arguments. It supports JSON configuration files with validation.
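+
+Example (illustrative sketch; "my_config.json" is a placeholder path):
+
+    from graid.data.config_support import load_config_from_file
+
+    config = load_config_from_file("my_config.json")
+    models = config.create_models()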
+"""
+
+import json
+import logging
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from graid.data.generate_db import create_model
+from graid.utilities.coco import coco_labels
+from graid.utilities.common import get_default_device
+
+logger = logging.getLogger(__name__)
+
+
+class ConfigurationError(Exception):
+ """Exception raised for configuration-related errors."""
+
+ pass
+
+
+class ModelConfig:
+ """Configuration for a single model."""
+
+ def __init__(
+ self,
+ backend: str,
+ model_name: str,
+ custom_config: Optional[dict[str, Any]] = None,
+ confidence_threshold: float = 0.2,
+ device: Optional[str] = None,
+ ):
+ self.backend = backend
+ self.model_name = model_name
+ self.custom_config = custom_config
+ self.confidence_threshold = confidence_threshold
+ self.device = device
+
+ # Validate configuration
+ self._validate()
+
+ def _validate(self):
+ """Validate the model configuration."""
+ # Check if backend is supported
+ supported_backends = ["detectron", "mmdetection", "ultralytics"]
+ if self.backend not in supported_backends:
+ raise ConfigurationError(
+ f"Unsupported backend: {self.backend}. Supported: {supported_backends}"
+ )
+
+ # For detectron and mmdetection, custom config is required
+ if self.backend in ["detectron", "mmdetection"] and self.custom_config is None:
+ raise ConfigurationError(
+ f"Custom config is required for {self.backend} backend. "
+ f"Use create_model() with custom_config parameter."
+ )
+
+ # Validate custom config structure
+ if self.custom_config:
+ self._validate_custom_config()
+
+ def _validate_custom_config(self):
+ """Validate custom configuration structure."""
+ if self.custom_config is None:
+ return
+
+ if self.backend == "detectron":
+ if (
+ "config" not in self.custom_config
+ or "weights" not in self.custom_config
+ ):
+ raise ConfigurationError(
+ "Detectron custom config must have 'config' and 'weights' keys"
+ )
+ elif self.backend == "mmdetection":
+ if (
+ "config" not in self.custom_config
+ or "checkpoint" not in self.custom_config
+ ):
+ raise ConfigurationError(
+ "MMDetection custom config must have 'config' and 'checkpoint' keys"
+ )
+ elif self.backend == "ultralytics":
+ if (
+ isinstance(self.custom_config, dict)
+ and "model_file" not in self.custom_config
+ ):
+ raise ConfigurationError(
+ "Ultralytics custom config must have 'model_file' key"
+ )
+
+ def create_model(self):
+ """Create a model instance from this configuration."""
+ device = self.device or get_default_device()
+
+ return create_model(
+ backend=self.backend,
+ model_name=self.model_name,
+ device=device,
+ threshold=self.confidence_threshold,
+ custom_config=self.custom_config,
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary representation."""
+ return {
+ "backend": self.backend,
+ "model_name": self.model_name,
+ "custom_config": self.custom_config,
+ "confidence_threshold": self.confidence_threshold,
+ "device": self.device,
+ }
+
+
+class WBFConfig:
+ """Configuration for Weighted Boxes Fusion."""
+
+ def __init__(
+ self,
+ iou_threshold: float = 0.55,
+ skip_box_threshold: float = 0.0,
+ model_weights: Optional[list[float]] = None,
+ ):
+ self.iou_threshold = iou_threshold
+ self.skip_box_threshold = skip_box_threshold
+ self.model_weights = model_weights
+
+ # Validate configuration
+ self._validate()
+
+ def _validate(self):
+ """Validate WBF configuration."""
+ if not 0.0 <= self.iou_threshold <= 1.0:
+ raise ConfigurationError(
+ f"iou_threshold must be between 0.0 and 1.0, got {self.iou_threshold}"
+ )
+
+ if not 0.0 <= self.skip_box_threshold <= 1.0:
+ raise ConfigurationError(
+ f"skip_box_threshold must be between 0.0 and 1.0, got {self.skip_box_threshold}"
+ )
+
+ if self.model_weights is not None:
+ if not all(w > 0 for w in self.model_weights):
+ raise ConfigurationError("All model weights must be positive")
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary representation."""
+ return {
+ "iou_threshold": self.iou_threshold,
+ "skip_box_threshold": self.skip_box_threshold,
+ "model_weights": self.model_weights,
+ }
+
+
+class DatasetGenerationConfig:
+ """Complete configuration for dataset generation."""
+
+ def __init__(
+ self,
+ dataset_name: str,
+ split: str,
+ models: list[ModelConfig],
+ use_wbf: bool = False,
+ wbf_config: Optional[WBFConfig] = None,
+ confidence_threshold: float = 0.2,
+ batch_size: int = 1,
+ device: Optional[str] = None,
+ allowable_set: Optional[list[str]] = None,
+ question_configs: Optional[list[dict[str, Any]]] = None,
+ num_workers: int = 4,
+ qa_workers: int = 4,
+ save_path: Optional[str] = None,
+ upload_to_hub: bool = False,
+ hub_repo_id: Optional[str] = None,
+ hub_private: bool = False,
+ num_samples: Optional[int] = None,
+ use_original_filenames: bool = True,
+ filename_prefix: str = "img",
+ ):
+ self.dataset_name = dataset_name
+ self.split = split
+ self.models = models
+ self.use_wbf = use_wbf
+ self.wbf_config = wbf_config
+ self.confidence_threshold = confidence_threshold
+ self.batch_size = batch_size
+ self.device = device
+ self.allowable_set = allowable_set
+ self.question_configs = question_configs
+ self.num_workers = num_workers
+ self.qa_workers = qa_workers
+ self.save_path = save_path
+ self.upload_to_hub = upload_to_hub
+ self.hub_repo_id = hub_repo_id
+ self.hub_private = hub_private
+ self.num_samples = num_samples
+ self.use_original_filenames = use_original_filenames
+ self.filename_prefix = filename_prefix
+
+ # Validate configuration
+ self._validate()
+
+ def _validate(self):
+ """Validate the complete configuration."""
+ # Validate dataset name
+ supported_datasets = ["bdd", "nuimage", "waymo"]
+ if self.dataset_name not in supported_datasets:
+ raise ConfigurationError(f"Unsupported dataset: {self.dataset_name}")
+
+ # Validate split - support individual splits and combined splits like "train+val"
+ valid_individual_splits = ["train", "val", "test"]
+ if "+" in self.split:
+ # Handle combined splits like "train+val"
+ split_parts = [s.strip() for s in self.split.split("+")]
+ for part in split_parts:
+ if part not in valid_individual_splits:
+ raise ConfigurationError(f"Invalid split part: {part}. Valid splits: {valid_individual_splits}")
+ else:
+ # Handle individual splits
+ if self.split not in valid_individual_splits:
+ raise ConfigurationError(f"Invalid split: {self.split}. Valid splits: {valid_individual_splits}")
+
+ # Validate models
+ if not self.models:
+ logger.warning("No models specified, will use ground truth")
+
+ # Validate WBF configuration
+ if self.use_wbf:
+ if len(self.models) < 2:
+ raise ConfigurationError("WBF requires at least 2 models")
+
+ if self.wbf_config is None:
+ self.wbf_config = WBFConfig()
+
+ if self.wbf_config.model_weights is not None:
+ if len(self.wbf_config.model_weights) != len(self.models):
+ raise ConfigurationError(
+ f"Number of model weights ({len(self.wbf_config.model_weights)}) "
+ f"must match number of models ({len(self.models)})"
+ )
+
+ # Validate Hub configuration
+ if self.upload_to_hub and not self.hub_repo_id:
+ raise ConfigurationError("hub_repo_id is required when upload_to_hub=True")
+
+ # Validate allowable_set
+ if self.allowable_set is not None:
+ valid_coco_objects = set(coco_labels.values())
+ # Remove undefined as it's not a real COCO class
+ valid_coco_objects.discard("undefined")
+
+ invalid_objects = []
+ for obj in self.allowable_set:
+ if obj not in valid_coco_objects:
+ invalid_objects.append(obj)
+
+ if invalid_objects:
+ raise ConfigurationError(
+ f"Invalid COCO objects in allowable_set: {invalid_objects}. "
+ f"Valid objects: {sorted(valid_coco_objects)}"
+ )
+
+ def create_models(self):
+ """Create model instances from the configuration."""
+ return [model.create_model() for model in self.models]
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary representation."""
+ return {
+ "dataset_name": self.dataset_name,
+ "split": self.split,
+ "models": [model.to_dict() for model in self.models],
+ "use_wbf": self.use_wbf,
+ "wbf_config": self.wbf_config.to_dict() if self.wbf_config else None,
+ "confidence_threshold": self.confidence_threshold,
+ "batch_size": self.batch_size,
+ "device": self.device,
+ "allowable_set": self.allowable_set,
+ "save_path": self.save_path,
+ "upload_to_hub": self.upload_to_hub,
+ "hub_repo_id": self.hub_repo_id,
+ "hub_private": self.hub_private,
+ "question_configs": self.question_configs,
+ "num_workers": self.num_workers,
+ "qa_workers": self.qa_workers,
+ "num_samples": self.num_samples,
+ "use_original_filenames": self.use_original_filenames,
+ "filename_prefix": self.filename_prefix,
+ }
+
+
+def load_config_from_file(config_path: Union[str, Path]) -> DatasetGenerationConfig:
+ """
+ Load configuration from JSON file.
+
+ Args:
+ config_path: Path to the configuration file
+
+ Returns:
+ DatasetGenerationConfig instance
+
+ Raises:
+ ConfigurationError: If the file doesn't exist or is invalid
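+
+ Example (illustrative; the path is a placeholder):
+
+ config = load_config_from_file("my_config.json")
+ print(config.dataset_name, config.split)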
+ """
+ config_path = Path(config_path)
+ if not config_path.exists():
+ raise ConfigurationError(f"Configuration file not found: {config_path}")
+
+ try:
+ with open(config_path, "r") as f:
+ config_data = json.load(f)
+ except json.JSONDecodeError as e:
+ raise ConfigurationError(f"Invalid JSON in configuration file: {e}")
+
+ return load_config_from_dict(config_data)
+
+
+def load_config_from_dict(config_data: dict[str, Any]) -> DatasetGenerationConfig:
+ """
+ Load configuration from dictionary.
+
+ Args:
+ config_data: Configuration dictionary
+
+ Returns:
+ DatasetGenerationConfig instance
+
+ Raises:
+ ConfigurationError: If the configuration is invalid
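+
+ Example (illustrative minimal configuration; unspecified fields fall back
+ to their defaults):
+
+ config = load_config_from_dict({
+ "dataset_name": "bdd",
+ "split": "val",
+ "models": [{"backend": "ultralytics", "model_name": "yolov8x"}],
+ })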
+ """
+ try:
+ # Parse model configurations
+ models = []
+ for model_data in config_data.get("models", []):
+ models.append(
+ ModelConfig(
+ backend=model_data["backend"],
+ model_name=model_data["model_name"],
+ custom_config=model_data.get("custom_config"),
+ confidence_threshold=model_data.get("confidence_threshold", 0.2),
+ device=model_data.get("device"),
+ )
+ )
+
+ # Parse WBF configuration
+ wbf_config = None
+ if config_data.get("wbf_config"):
+ wbf_data = config_data["wbf_config"]
+ wbf_config = WBFConfig(
+ iou_threshold=wbf_data.get("iou_threshold", 0.55),
+ skip_box_threshold=wbf_data.get("skip_box_threshold", 0.0),
+ model_weights=wbf_data.get("model_weights"),
+ )
+
+ # Create main configuration
+ return DatasetGenerationConfig(
+ dataset_name=config_data["dataset_name"],
+ split=config_data["split"],
+ models=models,
+ use_wbf=config_data.get("use_wbf", False),
+ wbf_config=wbf_config,
+ confidence_threshold=config_data.get("confidence_threshold", 0.2),
+ batch_size=config_data.get("batch_size", 1),
+ device=config_data.get("device"),
+ allowable_set=config_data.get("allowable_set"),
+ question_configs=config_data.get("question_configs"),
+ num_workers=config_data.get("num_workers", 4),
+ qa_workers=config_data.get("qa_workers", 4),
+ save_path=config_data.get("save_path"),
+ upload_to_hub=config_data.get("upload_to_hub", False),
+ hub_repo_id=config_data.get("hub_repo_id"),
+ hub_private=config_data.get("hub_private", False),
+ num_samples=config_data.get("num_samples"),
+ use_original_filenames=config_data.get("use_original_filenames", True),
+ filename_prefix=config_data.get("filename_prefix", "img"),
+ )
+
+ except KeyError as e:
+ raise ConfigurationError(f"Missing required configuration key: {e}")
+ except Exception as e:
+ raise ConfigurationError(f"Error parsing configuration: {e}")
+
+
+def create_example_config() -> dict[str, Any]:
+ """Create an example configuration dictionary."""
+ return {
+ "dataset_name": "bdd",
+ "split": "val",
+ "models": [
+ {
+ "backend": "ultralytics",
+ "model_name": "yolov8x",
+ "confidence_threshold": 0.3,
+ },
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "confidence_threshold": 0.2,
+ },
+ {
+ "backend": "mmdetection",
+ "model_name": "co_detr",
+ "confidence_threshold": 0.25,
+ },
+ ],
+ "use_wbf": True,
+ "wbf_config": {
+ "iou_threshold": 0.55,
+ "skip_box_threshold": 0.0,
+ "model_weights": [1.0, 1.0, 1.0],
+ },
+ "confidence_threshold": 0.2,
+ "batch_size": 4,
+ "allowable_set": ["person", "car", "truck", "bus", "bicycle", "motorcycle"],
+ "save_path": "./my_dataset",
+ "upload_to_hub": False,
+ "hub_repo_id": "username/my-dataset",
+ "hub_private": True,
+ }
+
+
+def save_example_config(output_path: Union[str, Path]):
+ """Save an example configuration file."""
+ config = create_example_config()
+ output_path = Path(output_path)
+
+ with open(output_path, "w") as f:
+ json.dump(config, f, indent=2)
+
+ logger.info(f"Example configuration saved to {output_path}")
+
+
+def validate_config_file(config_path: Union[str, Path]) -> tuple[bool, Optional[str]]:
+ """
+ Validate a configuration file.
+
+ Args:
+ config_path: Path to the configuration file
+
+ Returns:
+ Tuple of (is_valid, error_message)
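+
+ Example (illustrative; the path is a placeholder):
+
+ is_valid, error = validate_config_file("my_config.json")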
+ """
+ try:
+ config = load_config_from_file(config_path)
+ return True, None
+ except ConfigurationError as e:
+ return False, str(e)
+ except Exception as e:
+ return False, f"Unexpected error: {e}"
+
diff --git a/graid/src/graid/data/export_bdd_to_yolo.py b/graid/src/graid/data/export_bdd_to_yolo.py
new file mode 100644
index 0000000..7c6b1e9
--- /dev/null
+++ b/graid/src/graid/data/export_bdd_to_yolo.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+from tqdm import tqdm
+
+from graid.data.ImageLoader import Bdd100kDataset
+from graid.interfaces.ObjectDetectionI import ObjectDetectionResultI
+
+
+def convert_bdd_to_yolo():
+ """
+ Converts BDD100K dataset annotations to YOLO format.
+
+ This script processes the 'train', 'val', and 'test' splits of the
+ BDD100K dataset. For each image containing objects, it generates a
+ corresponding '.txt' file in YOLO format.
+
+ The YOLO format consists of one line per object, with each line
+ containing the class ID and the normalized bounding box coordinates
+ (center_x, center_y, width, height).
+
+ The class IDs are derived from the original BDD100K categories,
+ mapped to a zero-indexed integer.
+
+ Generated label files are stored in a 'yolo_labels' directory,
+ organized by dataset split.
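+
+ Example label line (illustrative values):
+
+ 2 0.512000 0.634000 0.120000 0.245000
+
+ i.e. "<class_id> <x_center> <y_center> <width> <height>", with coordinates
+ normalized to the image width and height.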
+ """
+ root_output_dir = Path("data/bdd100k/yolo_labels")
+ print(f"Output directory: {root_output_dir.resolve()}")
+
+ # Get categories from the Bdd100kDataset class
+ category_map = {v: k for k, v in Bdd100kDataset._CATEGORIES.items()}
+ print("BDD100K Class to ID mapping (from Bdd100kDataset):")
+ for class_id, name in sorted(category_map.items()):
+ print(f" {class_id}: {name}")
+
+ for split in ["train", "val"]:
+ print(f"\nProcessing '{split}' split...")
+
+ dataset = Bdd100kDataset(
+ split=split,
+ use_original_categories=True,
+ use_time_filtered=False
+ )
+
+ output_dir = root_output_dir / split
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ labels_generated = 0
+
+ for i in tqdm(range(len(dataset)), desc=f"Exporting {split}"):
+ item = dataset[i]
+ image_name = item["name"]
+ # The 'labels' are a list of ObjectDetectionResultI objects
+ detections: list[ObjectDetectionResultI] = item["labels"]
+
+ if not detections:
+ continue
+
+ yolo_lines = []
+ for det in detections:
+ # The class ID is directly available in the detection object
+ class_id = det.cls
+
+ # as_xywhn() provides normalized [x_center, y_center, width, height]
+ # It returns a tensor, so we get the first (and only) row
+ xywhn = det.as_xywhn()[0]
+ x_center, y_center, width, height = xywhn
+
+ yolo_lines.append(
+ f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
+ )
+
+ if yolo_lines:
+ label_path = output_dir / f"{Path(image_name).stem}.txt"
+ with open(label_path, "w") as f:
+ f.write("\n".join(yolo_lines))
+ labels_generated += 1
+
+ if labels_generated == 0:
+ print(f"\nWARNING: No label files were generated for the '{split}' split.")
+ print("This could be because the dataset split is empty, contains no annotations,")
+ print("or all annotations were filtered out.")
+ print("This can cause errors during training/validation if the framework expects labels.")
+ else:
+ print(f"\nSuccessfully generated {labels_generated} label files for the '{split}' split.")
+
+ print("\nConversion to YOLO format complete.")
+
+
+if __name__ == "__main__":
+ convert_bdd_to_yolo()
\ No newline at end of file
diff --git a/graid/src/graid/data/generate_dataset.py b/graid/src/graid/data/generate_dataset.py
new file mode 100644
index 0000000..8aa6604
--- /dev/null
+++ b/graid/src/graid/data/generate_dataset.py
@@ -0,0 +1,1787 @@
+"""
+GRAID HuggingFace Dataset Generation
+
+This module provides comprehensive functionality for generating HuggingFace datasets
+from object detection data, supporting multiple model backends, ensemble methods,
+and flexible question-answer generation patterns.
+
+Key Features:
+ - Multi-backend support: Detectron2, MMDetection, Ultralytics
+ - Weighted Boxes Fusion (WBF) ensemble methods
+ - Parallel question-answer generation
+ - COCO-style annotations with embedded PIL images
+ - Unlabeled image support (model-generated detections)
+ - Memory-efficient dataset generation
+ - HuggingFace Hub integration
+
+Classes:
+ HuggingFaceDatasetBuilder: Main dataset generation engine
+ QABatchProcessor: Abstract strategy for QA processing
+ SequentialQAProcessor: Sequential QA generation strategy
+ ParallelQAProcessor: Parallel QA generation with ThreadPoolExecutor
+ QAProcessorFactory: Factory for creating QA processing strategies
+
+Functions:
+ generate_dataset: High-level API for dataset generation
+ list_available_questions: Query available question types
+ interactive_question_selection: Interactive question configuration
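+
+Example (illustrative sketch; the config path is a placeholder and the exact
+parameters of generate_dataset are defined later in this module):
+
+    from graid.data.config_support import load_config_from_file
+
+    config = load_config_from_file("my_config.json")
+    dataset_dict = generate_dataset(config)  # hypothetical call pattern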
+"""
+
+import json
+import logging
+import os
+import time
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union, Tuple
+
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from graid.utilities.common import get_default_device
+
+logger = logging.getLogger(__name__)
+
+
+class QABatchProcessor(ABC):
+ """
+ Abstract strategy for processing question-answer generation in batches.
+
+ This class defines the interface for different QA processing strategies,
+ allowing flexible switching between sequential and parallel processing
+ approaches based on performance requirements and resource constraints.
+
+ The strategy pattern enables:
+ - Sequential processing for memory-limited environments
+ - Parallel processing for high-throughput scenarios
+ - Easy extension with new processing strategies
+ """
+
+ @abstractmethod
+ def process_batch(
+ self, batch_data: List[Tuple[Image.Image, List[Any], str, int, int]]
+ ) -> List[Any]:
+ """
+ Process a batch of image data and generate question-answer pairs.
+
+ This method takes prepared batch data and applies question generation
+ algorithms to produce structured QA pairs with optional timing information.
+
+ Args:
+ batch_data: List of tuples containing:
+ - pil_image (PIL.Image.Image): Processed image
+ - detections (List[Detection]): Object detection results
+ - source_id (str): Unique identifier for the image
+ - base_image_index (int): Starting index for this batch
+ - j (int): Position within the batch
+
+ Returns:
+ List of QA results where each element is either:
+ - List[Dict[str, Any]]: QA pairs (when profiling disabled)
+ - Tuple[List[Dict[str, Any]], Dict[str, tuple[float, int]]]:
+ QA pairs with timing data (when profiling enabled)
+
+ Raises:
+ NotImplementedError: If called on abstract base class
+ """
+ pass
+
+
+class SequentialQAProcessor(QABatchProcessor):
+ """
+ Sequential question-answer processing strategy.
+
+ This implementation processes images one by one in a single thread,
+ providing predictable memory usage and easier debugging at the cost
+ of processing speed. Ideal for:
+ - Memory-constrained environments
+ - Debugging and development
+ - Small batch sizes
+ - Systems with limited CPU cores
+
+ Attributes:
+ qa_generator: Reference to the dataset builder instance
+ profile_questions: Whether to collect timing statistics
+ """
+
+ def __init__(self, qa_generator, profile_questions: bool):
+ """
+ Initialize sequential QA processor.
+
+ Args:
+ qa_generator: The HuggingFaceDatasetBuilder instance that contains
+ the question generation logic and configuration
+ profile_questions: Whether to enable timing profiling for
+ performance analysis
+ """
+ self.qa_generator = qa_generator
+ self.profile_questions = profile_questions
+ logger.debug(
+ "β Initialized SequentialQAProcessor with profiling=%s", profile_questions
+ )
+
+ def process_batch(
+ self, batch_data: List[Tuple[Image.Image, List[Any], str, int, int]]
+ ) -> List[Any]:
+ """
+ Process QA generation sequentially for all images in the batch.
+
+ Args:
+ batch_data: List of prepared image data tuples
+
+ Returns:
+ List of QA results maintaining input order
+ """
+ logger.debug("π Processing batch of %d images sequentially", len(batch_data))
+ results = []
+
+ for i, args in enumerate(batch_data):
+ pil_image, detections, source_id, base_image_index, j = args
+ image_index = base_image_index + j
+
+ # Set current example for filename inference
+ self.qa_generator._current_example = {"name": source_id}
+
+ try:
+ ret = self.qa_generator._qa_for_image(
+ pil_image, detections, source_id, image_index
+ )
+ results.append(ret)
+ logger.debug(
+ "β Processed image %d/%d: %s", i + 1, len(batch_data), source_id
+ )
+ except Exception as e:
+ logger.error("β Failed to process image %s: %s", source_id, e)
+ # Add empty result to maintain order
+ empty_result = ([], {}) if self.profile_questions else []
+ results.append(empty_result)
+
+ logger.debug(
+ "β
Sequential batch processing completed: %d results", len(results)
+ )
+ return results
+
+
+class ParallelQAProcessor(QABatchProcessor):
+ """
+ Parallel question-answer processing strategy using ThreadPoolExecutor.
+
+ This implementation processes multiple images concurrently using a thread pool,
+ providing significant speedup for I/O-bound question generation tasks.
+ Uses ThreadPoolExecutor.map() to maintain result ordering. Ideal for:
+ - High-throughput scenarios
+ - Systems with multiple CPU cores
+ - I/O-bound question generation
+ - Large batch processing
+
+ Note:
+ Maintains strict ordering through executor.map() to ensure
+ QA results correspond to input images correctly.
+
+ Attributes:
+ qa_generator: Reference to the dataset builder instance
+ qa_workers: Number of parallel worker threads
+ profile_questions: Whether to collect timing statistics
+ """
+
+ def __init__(self, qa_generator, qa_workers: int, profile_questions: bool):
+ """
+ Initialize parallel QA processor.
+
+ Args:
+ qa_generator: The HuggingFaceDatasetBuilder instance containing
+ the thread-safe question generation logic
+ qa_workers: Number of parallel worker threads to spawn.
+ Recommended: 2-4x CPU cores for I/O-bound tasks
+ profile_questions: Whether to enable timing profiling for
+ performance analysis
+ """
+ self.qa_generator = qa_generator
+ self.qa_workers = qa_workers
+ self.profile_questions = profile_questions
+ logger.debug(
+ "β Initialized ParallelQAProcessor with %d workers, profiling=%s",
+ qa_workers,
+ profile_questions,
+ )
+
+ def process_batch(
+ self, batch_data: List[Tuple[Image.Image, List[Any], str, int, int]]
+ ) -> List[Any]:
+ """
+ Process QA generation in parallel with strict order preservation.
+
+ Uses ThreadPoolExecutor.map() which maintains the order of results
+ corresponding to the input batch_data order, ensuring QA pairs
+ match their source images correctly.
+
+ Args:
+ batch_data: List of prepared image data tuples
+
+ Returns:
+ List of QA results in the same order as input batch_data
+ """
+ logger.debug(
+ "π Processing batch of %d images with %d parallel workers",
+ len(batch_data),
+ self.qa_workers,
+ )
+
+ with ThreadPoolExecutor(max_workers=self.qa_workers) as executor:
+ results = list(
+ executor.map(self.qa_generator._qa_for_image_threadsafe, batch_data)
+ )
+
+ logger.debug("β
Parallel batch processing completed: %d results", len(results))
+ return results
+
+
+class QAProcessorFactory:
+ """
+ Factory for creating appropriate QA processing strategies.
+
+ This factory implements the Strategy pattern by selecting the optimal
+ QA processing approach based on configuration parameters. The selection
+ logic considers performance requirements, resource constraints, and
+ system capabilities.
+
+ Strategy Selection Rules:
+ - qa_workers = 1: Sequential processing (safe, predictable)
+ - qa_workers > 1: Parallel processing (high throughput)
+ """
+
+ @staticmethod
+ def create(
+ qa_workers: int, qa_generator, profile_questions: bool
+ ) -> QABatchProcessor:
+ """
+ Create the appropriate QA processing strategy based on configuration.
+
+ Automatically selects between sequential and parallel processing
+ strategies based on the number of workers requested. This enables
+ transparent optimization without changing client code.
+
+ Args:
+ qa_workers: Number of QA worker threads to use:
+ - 1: Creates SequentialQAProcessor for single-threaded processing
+ - >1: Creates ParallelQAProcessor with specified worker count
+ qa_generator: The HuggingFaceDatasetBuilder instance that provides
+ the question generation logic and configuration
+ profile_questions: Whether to enable performance profiling and
+ timing collection for analysis
+
+ Returns:
+ QABatchProcessor: Configured strategy instance ready for processing
+
+ Example:
+ >>> # Single-threaded for debugging
+ >>> processor = QAProcessorFactory.create(1, builder, True)
+ >>>
+ >>> # Multi-threaded for production
+ >>> processor = QAProcessorFactory.create(8, builder, False)
+ """
+ if qa_workers > 1:
+ logger.info("π Creating ParallelQAProcessor with %d workers", qa_workers)
+ return ParallelQAProcessor(qa_generator, qa_workers, profile_questions)
+ else:
+ logger.info(
+ "π Creating SequentialQAProcessor for single-threaded processing"
+ )
+ return SequentialQAProcessor(qa_generator, profile_questions)
+
+
+class HuggingFaceDatasetBuilder:
+ """
+ Advanced HuggingFace dataset builder for object detection question-answering.
+
+ This class orchestrates the complete pipeline for generating high-quality VQA datasets
+ from object detection data. It supports multiple detection backends, ensemble methods,
+ parallel processing, and produces datasets compatible with modern vision-language models.
+
+ Key Capabilities:
+ π― Multi-Backend Support: Detectron2, MMDetection, Ultralytics models
+ π Ensemble Methods: Weighted Box Fusion (WBF) for improved accuracy
+ π Parallel Processing: Configurable worker threads for QA generation
+ π COCO Compatibility: Standard annotations with category strings
+ πΌοΈ PIL Integration: Embedded images ready for VLM workflows
+ π Flexible Storage: Original or generated filenames
+ π Hub Integration: Direct upload to HuggingFace Hub
+
+ Architecture:
+ The builder uses the Strategy pattern for QA processing, Factory pattern
+ for dataset loading, and incremental dataset construction to handle
+ large-scale data generation efficiently.
+
+ Workflow:
+ 1. Initialize models and configure processing parameters
+ 2. Load and transform source dataset (BDD100K, NuImages, Waymo, Custom)
+ 3. Apply object detection (ensemble via WBF or single model)
+ 4. Generate question-answer pairs using parallel/sequential strategies
+ 5. Build incremental HuggingFace datasets with embedded PIL images
+ 6. Optional: Upload to HuggingFace Hub with metadata
+
+ Performance Optimizations:
+ - Batch processing with configurable sizes
+ - Parallel QA generation with ThreadPoolExecutor
+ - Memory-efficient generator-based processing
+ - Confidence thresholds for quality control
+
+ Example:
+ >>> builder = HuggingFaceDatasetBuilder(
+ ... dataset_name="bdd",
+ ... split="val",
+ ... models=[yolo_model, detectron_model],
+ ... use_wbf=True,
+ ... qa_workers=8,
+ ... num_samples=1000
+ ... )
+ >>> dataset_dict = builder.build()
+ >>> print(f"Generated {len(dataset_dict['val'])} QA pairs")
+ """
+
+ def __init__(
+ self,
+ dataset_name: str,
+ split: str,
+ models: Optional[List[Any]] = None,
+ use_wbf: bool = False,
+ wbf_config: Optional[Dict[str, Any]] = None,
+ conf_threshold: float = 0.2,
+ batch_size: int = 1,
+ device: Optional[Union[str, torch.device]] = None,
+ allowable_set: Optional[List[str]] = None,
+ question_configs: Optional[List[Dict[str, Any]]] = None,
+ num_workers: int = 4,
+ qa_workers: int = 4,
+ num_samples: Optional[int] = None,
+ use_original_filenames: bool = True,
+ filename_prefix: str = "img",
+ save_path: str = "./graid-datasets",
+ ):
+ """
+ Initialize the HuggingFace dataset builder.
+
+ Args:
+ dataset_name: Name of the dataset ("bdd", "nuimage", "waymo")
+ split: Dataset split ("train", "val", "test")
+ models: List of model objects for inference (optional)
+ use_wbf: Whether to use Weighted Box Fusion ensemble
+ wbf_config: Configuration for WBF ensemble (optional)
+ conf_threshold: Confidence threshold for filtering detections
+ batch_size: Batch size for processing
+ device: Device to use for inference (optional)
+ allowable_set: List of allowed object classes (optional)
+ question_configs: List of question configuration dictionaries (optional)
+ num_workers: Number of data loading workers
+ qa_workers: Number of QA generation workers
+ num_samples: Maximum number of samples to process (0 or None = process all)
+ use_original_filenames: Whether to keep original filenames
+ filename_prefix: Prefix for generated filenames if not using originals
+ save_path: Directory in which to save the dataset (default: "./graid-datasets")
+ """
+ self.dataset_name = dataset_name
+ self.split = split
+ self.models = models or []
+ self.use_wbf = use_wbf
+ self.wbf_config = wbf_config or {}
+ self.conf_threshold = conf_threshold
+ self.batch_size = batch_size
+ self.device = device if device is not None else get_default_device()
+ self.allowable_set = allowable_set
+ self.num_workers = num_workers
+ self.qa_workers = qa_workers
+ self.num_samples = num_samples
+ self.save_path = Path(save_path)
+ self.use_original_filenames = use_original_filenames
+ self.filename_prefix = filename_prefix
+
+ # Question profiling (timings)
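+ # Enabled by setting the GRAID_PROFILE_QUESTIONS environment variable to any non-empty value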
+ self.profile_questions: bool = bool(os.getenv("GRAID_PROFILE_QUESTIONS"))
+ self.question_timings: Dict[str, Tuple[float, int]] = {}
+ self.question_counts: Dict[str, int] = {}
+
+ # Enhanced profiling for is_applicable vs apply efficiency
+ self.question_detailed_stats: Dict[str, Dict[str, Any]] = {}
+
+ # Validate allowable_set
+ if allowable_set is not None:
+ from graid.utilities.coco import validate_coco_objects
+
+ is_valid, error_msg = validate_coco_objects(allowable_set)
+ if not is_valid:
+ raise ValueError(f"Invalid allowable_set: {error_msg}")
+
+ # Initialize dataset transforms
+ self.transform = self._get_dataset_transform()
+
+ # Initialize questions
+ self.questions = self._initialize_questions(question_configs)
+
+ # Initialize dataset loader
+ self._init_dataset_loader()
+
+ # Note: No longer creating image directories - using embedded images in parquet
+
+ # Prepare WBF ensemble if needed
+ self.wbf_ensemble = None
+ if self.use_wbf and self.models:
+ self._prepare_wbf_ensemble()
+
+ def _get_dataset_transform(self):
+ """Get the appropriate transform for the dataset."""
+ from graid.utilities.common import (
+ yolo_bdd_transform,
+ yolo_nuscene_transform,
+ yolo_waymo_transform,
+ )
+
+ if self.dataset_name == "bdd":
+ return lambda i, l: yolo_bdd_transform(i, l, new_shape=(768, 1280))
+ elif self.dataset_name == "nuimage":
+ return lambda i, l: yolo_nuscene_transform(i, l, new_shape=(896, 1600))
+ elif self.dataset_name == "waymo":
+ return lambda i, l: yolo_waymo_transform(i, l, (1280, 1920))
+ else:
+ raise ValueError(f"Unsupported dataset: {self.dataset_name}")
+
+ def _initialize_questions(
+ self, question_configs: Optional[List[Dict[str, Any]]]
+ ) -> List[Any]:
+ """Initialize question objects from configuration."""
+ if question_configs is None:
+ # Use all available questions
+ from graid.questions.ObjectDetectionQ import ALL_QUESTION_CLASSES
+
+ return list(ALL_QUESTION_CLASSES.values())
+
+ questions = []
+ from graid.questions.ObjectDetectionQ import (
+ IsObjectCentered,
+ WidthVsHeight,
+ LargestAppearance,
+ RankLargestK,
+ MostAppearance,
+ LeastAppearance,
+ LeftOf,
+ RightOf,
+ LeftMost,
+ RightMost,
+ HowMany,
+ MostClusteredObjects,
+ WhichMore,
+ AreMore,
+ Quadrants,
+ LeftMostWidthVsHeight,
+ RightMostWidthVsHeight,
+ ObjectsInRow,
+ ObjectsInLine,
+ MoreThanThresholdHowMany,
+ LessThanThresholdHowMany,
+ MultiChoiceHowMany,
+ )
+
+ # Map question names to classes
+ question_class_map = {
+ "IsObjectCentered": IsObjectCentered,
+ "WidthVsHeight": WidthVsHeight,
+ "LargestAppearance": LargestAppearance,
+ "RankLargestK": RankLargestK,
+ "MostAppearance": MostAppearance,
+ "LeastAppearance": LeastAppearance,
+ "LeftOf": LeftOf,
+ "RightOf": RightOf,
+ "LeftMost": LeftMost,
+ "RightMost": RightMost,
+ "HowMany": HowMany,
+ "MostClusteredObjects": MostClusteredObjects,
+ "WhichMore": WhichMore,
+ "AreMore": AreMore,
+ "Quadrants": Quadrants,
+ "LeftMostWidthVsHeight": LeftMostWidthVsHeight,
+ "RightMostWidthVsHeight": RightMostWidthVsHeight,
+ "ObjectsInRow": ObjectsInRow,
+ "ObjectsInLine": ObjectsInLine,
+ "MoreThanThresholdHowMany": MoreThanThresholdHowMany,
+ "LessThanThresholdHowMany": LessThanThresholdHowMany,
+ "MultiChoiceHowMany": MultiChoiceHowMany,
+ }
+
+ for config in question_configs:
+ question_name = config.get("name")
+ question_params = config.get("params", {})
+
+ if question_name not in question_class_map:
+ logger.warning(f"Unknown question type: {question_name}")
+ continue
+
+ question_class = question_class_map[question_name]
+
+ # Handle questions that require parameters
+ if question_params:
+ try:
+ question_instance = question_class(**question_params)
+ except Exception as e:
+ logger.error(
+ f"Failed to initialize {question_name} with params {question_params}: {e}"
+ )
+ # Fall back to default initialization
+ question_instance = question_class()
+ else:
+ question_instance = question_class()
+
+ questions.append(question_instance)
+
+ if not questions:
+ raise ValueError("No valid questions configured")
+
+ return questions
+
+ def _init_dataset_loader(self):
+ """Initialize the appropriate dataset loader using the common factory."""
+ from graid.data.loaders import DatasetLoaderFactory
+
+ try:
+ self.dataset_loader = DatasetLoaderFactory.create(
+ dataset_name=self.dataset_name,
+ split=self.split,
+ transform=self.transform,
+ )
+ except Exception as e:
+ logger.error(f"Failed to initialize dataset loader: {e}")
+ raise
+
+ def _prepare_wbf_ensemble(self):
+ """Prepare WBF ensemble from individual models."""
+ # Import WBF classes locally
+ from graid.models.Detectron import Detectron_obj
+ from graid.models.MMDetection import MMdetection_obj
+ from graid.models.Ultralytics import RT_DETR, Yolo
+ from graid.models.WBF import WBF
+
+ # Group models by backend
+ detectron_models = []
+ mmdet_models = []
+ ultralytics_models = []
+
+ for model in self.models:
+ if isinstance(model, Detectron_obj):
+ detectron_models.append(model)
+ elif isinstance(model, MMdetection_obj):
+ mmdet_models.append(model)
+ elif isinstance(model, (Yolo, RT_DETR)):
+ ultralytics_models.append(model)
+
+ # Create WBF ensemble
+ self.wbf_ensemble = WBF(
+ detectron2_models=detectron_models if detectron_models else None,
+ mmdet_models=mmdet_models if mmdet_models else None,
+ ultralytics_models=ultralytics_models if ultralytics_models else None,
+ **self.wbf_config,
+ )
+
+ def _infer_source_name(self, example: Dict[str, Any]) -> Optional[str]:
+ """Extract source filename from dataset example."""
+ if isinstance(example, dict) and "name" in example:
+ return example["name"]
+ return None
+
+ def _generate_filename(self, index: int, source_name: Optional[str]) -> str:
+ """Generate filename based on configuration."""
+ if self.use_original_filenames and source_name:
+ return Path(source_name).name
+ return f"{self.filename_prefix}{index:06d}.jpg"
+
+ def _convert_image_to_pil(
+ self, image: Union[torch.Tensor, np.ndarray]
+ ) -> Image.Image:
+ """Convert tensor or numpy array to PIL Image."""
+ if isinstance(image, torch.Tensor):
+ if image.dim() == 3: # (C, H, W)
+ image = image.permute(1, 2, 0).cpu().numpy()
+ elif image.dim() == 4: # (B, C, H, W)
+ image = image[0].permute(1, 2, 0).cpu().numpy()
+
+ # Ensure proper data type and range
+ if not isinstance(image, np.ndarray):
+ image = np.array(image)
+
+ if image.dtype in [np.float32, np.float64]:
+ if image.max() > 1.0:
+ # Values already in [0, 255] range, just convert to uint8
+ image = image.astype(np.uint8)
+ else:
+ image = (image * 255).astype(np.uint8)
+ elif image.dtype != np.uint8:
+ image = image.astype(np.uint8)
+
+ return Image.fromarray(image)
+
+ def _build_coco_annotations(
+ self, detections: List[Any], image_width: int, image_height: int
+ ) -> List[Dict[str, Any]]:
+ """
+ Build COCO-style annotations from detections.
+
+ Args:
+ detections: List of detection objects
+ image_width: Image width in pixels
+ image_height: Image height in pixels
+
+ Returns:
+ List of COCO annotation dictionaries
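+
+ Example (illustrative output for a single "car" detection):
+ [{"bbox": [10.0, 20.0, 50.0, 80.0], "category_id": 1, "category": "car",
+ "iscrowd": 0, "area": 4000.0, "score": 0.9}]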
+ """
+ annotations = []
+
+ for detection in detections:
+ # Get bounding box in XYWH format
+ xywh = detection.as_xywh()[0]
+ x, y, w, h = float(xywh[0]), float(xywh[1]), float(xywh[2]), float(xywh[3])
+
+ # Build COCO annotation
+ annotation = {
+ "bbox": [x, y, w, h], # COCO format: [x, y, width, height]
+ "category_id": 1, # Default category ID
+ "category": detection.label, # Add category string
+ "iscrowd": 0,
+ "area": float(w * h),
+ "score": float(detection.score) if hasattr(detection, "score") else 1.0,
+ }
+ annotations.append(annotation)
+
+ return annotations
+
+ def _qa_for_image(
+ self,
+ pil_image: Image.Image,
+ detections: List[Any],
+ source_id: str,
+ image_index: int,
+ ) -> Union[
+ List[Dict[str, Any]], tuple[List[Dict[str, Any]], Dict[str, tuple[float, int]]]
+ ]:
+ """Generate question-answer pairs for a single image with embedded image bytes."""
+ qa_pairs = []
+ local_timings: Dict[str, tuple[float, int]] = {}
+
+ # Keep RGB and grayscale ('L') images as-is; convert any other mode to RGB for consistency
+ rgb_img = (
+ pil_image if pil_image.mode in ("RGB", "L") else pil_image.convert("RGB")
+ )
+
+ # SOLUTION: Embed image bytes directly instead of saving separate files
+ # This solves HuggingFace 10k file limit by storing images in parquet
+ # No compression - preserve original format or store as uncompressed PNG
+ import io
+
+ # Try to preserve original format from source_id extension
+ _, ext = os.path.splitext(source_id)
+ original_format = ext.upper().lstrip('.') if ext else 'PNG'
+
+ # Map common extensions to PIL formats
+ format_map = {'JPG': 'JPEG', 'JPEG': 'JPEG', 'PNG': 'PNG', 'BMP': 'BMP', 'TIFF': 'TIFF'}
+ pil_format = format_map.get(original_format, 'PNG') # Default to PNG if unknown
+
+ buffer = io.BytesIO()
+ if pil_format == 'JPEG':
+ # For JPEG, re-encode at maximum quality (quality=100) to minimize additional loss
+ rgb_img.save(buffer, format=pil_format, quality=100, optimize=False)
+ elif pil_format == 'PNG':
+ # For PNG, save without compression
+ rgb_img.save(buffer, format=pil_format, compress_level=0, optimize=False)
+ else:
+ # For other formats, save as-is
+ rgb_img.save(buffer, format=pil_format)
+
+ image_bytes = buffer.getvalue()
+
+ # Store image as bytes with original format info
+ image_reference = {"bytes": image_bytes, "path": None}
+
+ # Generate COCO annotations
+ annotations = self._build_coco_annotations(
+ detections, pil_image.width, pil_image.height
+ )
+
+ # Generate questions and answers with enhanced profiling
+ for question in self.questions:
+ qname = question.__class__.__name__
+
+ # Initialize detailed stats for this question if not exists
+ if self.profile_questions and qname not in self.question_detailed_stats:
+ self.question_detailed_stats[qname] = {
+ "is_applicable_time": (0.0, 0),
+ "is_applicable_true_count": 0,
+ "apply_time": (0.0, 0),
+ "apply_empty_results": 0,
+ "total_qa_generated": 0,
+ "question_text": getattr(question, "question", qname) # Full question string
+ }
+
+ # Profile is_applicable timing
+ if detections:
+ is_applicable_start = time.perf_counter() if self.profile_questions else None
+ is_applicable_result = question.is_applicable(pil_image, detections)
+
+ if self.profile_questions and is_applicable_start is not None:
+ is_applicable_time = time.perf_counter() - is_applicable_start
+ current_time, current_count = self.question_detailed_stats[qname]["is_applicable_time"]
+ self.question_detailed_stats[qname]["is_applicable_time"] = (
+ current_time + is_applicable_time, current_count + 1
+ )
+
+ if is_applicable_result:
+ if self.profile_questions:
+ self.question_detailed_stats[qname]["is_applicable_true_count"] += 1
+
+ # Profile apply timing
+ apply_start = time.perf_counter() if self.profile_questions else None
+ try:
+ qa_results = question.apply(pil_image, detections)
+
+ if self.profile_questions and apply_start is not None:
+ apply_time = time.perf_counter() - apply_start
+ current_time, current_count = self.question_detailed_stats[qname]["apply_time"]
+ self.question_detailed_stats[qname]["apply_time"] = (
+ current_time + apply_time, current_count + 1
+ )
+
+ # Track legacy timing for backward compatibility
+ dt = apply_time
+ t_total, t_cnt = local_timings.get(qname, (0.0, 0))
+ local_timings[qname] = (t_total + dt, t_cnt + 1)
+
+ # Check if apply returned empty results despite is_applicable=True
+ if not qa_results and self.profile_questions:
+ self.question_detailed_stats[qname]["apply_empty_results"] += 1
+
+ for qa_item in qa_results:
+ if not isinstance(qa_item, (tuple, list)) or len(qa_item) != 2:
+ logger.warning(
+ f"{question.__class__.__name__}.apply() returned malformed item: {qa_item!r}"
+ )
+ continue
+
+ question_text, answer_text = qa_item
+
+ # Build the final QA pair with embedded image bytes
+ qa_pair = {
+ "image": image_reference, # Embedded bytes dict format
+ "annotations": annotations,
+ "question": question_text,
+ "answer": answer_text,
+ "reasoning": None,
+ "question_type": question.__class__.__name__,
+ "source_id": source_id,
+ }
+
+ # Add source_filename if using generated filenames for reference
+ if not self.use_original_filenames:
+ source_name = (
+ self._infer_source_name({"name": source_id})
+ if hasattr(self, "_current_example")
+ else None
+ )
+ if source_name:
+ qa_pair["source_filename"] = source_name
+
+ qa_pairs.append(qa_pair)
+
+ # Track successful QA generation
+ if self.profile_questions:
+ self.question_detailed_stats[qname]["total_qa_generated"] += 1
+
+ except Exception as e:
+ logger.warning(
+ f"Question {question.__class__.__name__} failed on image {source_id}: {e}"
+ )
+ continue
+
+ if self.profile_questions:
+ return (qa_pairs, local_timings)
+ return qa_pairs
+
+ def _qa_for_image_threadsafe(
+ self, batch_args: tuple
+ ) -> Union[
+ List[Dict[str, Any]], tuple[List[Dict[str, Any]], Dict[str, tuple[float, int]]]
+ ]:
+ """Thread-safe wrapper for _qa_for_image using source_id for uniqueness."""
+ pil_image, detections, source_id, base_image_index, batch_j = batch_args
+
+ # Use source_id + batch_j for unique identification (no magic numbers)
+ unique_image_key = f"{source_id}_{batch_j}"
+
+ try:
+ return self._qa_for_image(
+ pil_image, detections, source_id, base_image_index + batch_j
+ )
+ except Exception as e:
+ logger.error(f"Error in threaded QA generation for {unique_image_key}: {e}")
+ # Return appropriate empty result based on profiling mode
+ return ([], {}) if self.profile_questions else []
+
+ def _cleanup_images(self):
+ """Clean up image files after successful dataset creation to avoid duplicate storage."""
+ if not self.save_path:
+ return
+
+ images_dir = self.save_path / self.split / "images"
+ if images_dir.exists():
+ import shutil
+
+ logger.info(
+ f"π§Ή Cleaning up image files in {images_dir} (images are embedded in Parquet)"
+ )
+ shutil.rmtree(images_dir)
+ logger.debug(f"β
Removed images directory: {images_dir}")
+
+ # Remove split directory if it's now empty
+ split_dir = self.save_path / self.split
+ if split_dir.exists() and not any(split_dir.iterdir()):
+ split_dir.rmdir()
+ logger.debug(f"β
Removed empty split directory: {split_dir}")
+
+ def _create_data_loader(self) -> DataLoader:
+ """Create and configure the PyTorch DataLoader."""
+ return DataLoader(
+ self.dataset_loader,
+ batch_size=self.batch_size,
+ shuffle=False,
+ collate_fn=lambda x: x,
+ num_workers=self.num_workers,
+ prefetch_factor=1,
+ persistent_workers=False,
+ )
+
+ def _should_stop_early(self, batch_idx: int, processed_images: int) -> bool:
+ """Check if processing should stop early due to limits."""
+ # Check max_batches environment variable
+ try:
+ max_batches_env = os.getenv("GRAID_MAX_BATCHES")
+ max_batches = int(max_batches_env) if max_batches_env else None
+ if max_batches is not None and batch_idx >= max_batches:
+ logger.info(
+ f"Stopping early after {batch_idx} batches (GRAID_MAX_BATCHES={max_batches})"
+ )
+ return True
+ except Exception:
+ pass
+
+ # Check num_samples limit
+ if (
+ self.num_samples is not None
+ and self.num_samples > 0
+ and processed_images >= int(self.num_samples)
+ ):
+ logger.info(
+ f"Reached num_samples={self.num_samples}. Stopping further processing."
+ )
+ return True
+
+ return False
+
+ def _calculate_total_batches(self, data_loader: DataLoader) -> Optional[int]:
+ """Calculate total number of batches considering early stopping."""
+ total_batches = len(data_loader)
+
+ # Adjust for num_samples limit
+ if self.num_samples is not None and self.num_samples > 0:
+ max_batches_for_samples = (
+ self.num_samples + self.batch_size - 1
+ ) // self.batch_size
+ total_batches = min(total_batches, max_batches_for_samples)
+
+ # Adjust for GRAID_MAX_BATCHES environment variable
+ try:
+ max_batches_env = os.getenv("GRAID_MAX_BATCHES")
+ if max_batches_env:
+ max_batches = int(max_batches_env)
+ total_batches = min(total_batches, max_batches)
+ except Exception:
+ pass
+
+ return total_batches
+
+ def _get_batch_predictions(
+ self, batch: List[Any]
+ ) -> Tuple[torch.Tensor, List[Any]]:
+ """Extract images and predictions from batch data."""
+ # Handle different dataset return formats
+ if isinstance(batch[0], tuple):
+ # Tuple format (BDD dataset)
+ batch_images = torch.stack([sample[0] for sample in batch])
+ ground_truth_labels = [sample[1] for sample in batch]
+ else:
+ # Dictionary format (NuImages/Waymo datasets)
+ batch_images = torch.stack([sample["image"] for sample in batch])
+ ground_truth_labels = [sample["labels"] for sample in batch]
+
+ # Get predictions from model(s)
+ if self.use_wbf and self.wbf_ensemble is not None:
+ batch_images = batch_images.to(self.device)
+ labels = self.wbf_ensemble.identify_for_image_batch(batch_images)
+ elif self.models:
+ batch_images = batch_images.to(self.device)
+ # Use first model if multiple models without WBF
+ model = self.models[0]
+ labels = model.identify_for_image_batch(batch_images)
+ else:
+ # Use ground truth
+ labels = ground_truth_labels
+
+ return batch_images, labels
+
+ def _prepare_batch_data(
+ self,
+ batch_idx: int,
+ batch: List[Any],
+ batch_images: torch.Tensor,
+ labels: List[Any],
+ ) -> List[Tuple[Image.Image, List[Any], str, int, int]]:
+ """Prepare batch data for QA processing."""
+ batch_data = []
+
+ # Prepare data for processing (parallel or sequential)
+ base_image_index = batch_idx * self.batch_size
+
+ for j, (image_tensor, detections) in enumerate(zip(batch_images, labels)):
+ # Convert to PIL Image
+ pil_image = self._convert_image_to_pil(image_tensor)
+
+ # Filter detections by confidence threshold
+ if detections:
+ detections = [d for d in detections if d.score >= self.conf_threshold]
+
+ # Filter detections by allowable set if specified
+ if detections and self.allowable_set:
+ filtered_detections = []
+ for detection in detections:
+ if detection.label in self.allowable_set:
+ filtered_detections.append(detection)
+ else:
+ logger.debug(
+ f"Filtered out detection of class '{detection.label}' (not in allowable set)"
+ )
+ detections = filtered_detections
+
+ # Extract source_id from batch sample
+ if isinstance(batch[j], dict) and "name" in batch[j]:
+ source_id = batch[j]["name"]
+ else:
+ source_id = f"{self.dataset_name}_{batch_idx}_{j}"
+
+ # Store current example for filename inference
+ self._current_example = (
+ batch[j] if isinstance(batch[j], dict) else {"name": source_id}
+ )
+
+ # Add this image to batch data for processing
+ batch_data.append((pil_image, detections, source_id, base_image_index, j))
+
+ return batch_data
+
+ def _process_qa_results(
+ self, batch_results_raw: List[Any]
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Tuple[float, int]]]:
+ """Process raw QA results and extract timings."""
+ batch_results: List[Dict[str, Any]] = []
+ batch_timings: Dict[str, Tuple[float, int]] = {}
+
+ # Process results and collect timings
+ for ret in batch_results_raw:
+ if self.profile_questions and isinstance(ret, tuple) and len(ret) == 2:
+ qa_pairs, local_timings = ret
+ if isinstance(qa_pairs, list):
+ batch_results.extend(qa_pairs)
+ if isinstance(local_timings, dict):
+ for k, (t, n) in local_timings.items():
+ T, N = batch_timings.get(k, (0.0, 0))
+ batch_timings[k] = (T + t, N + n)
+ elif isinstance(ret, list):
+ batch_results.extend(ret)
+ else:
+ logger.warning(
+ f"Unexpected return type from QA processing: {type(ret)}"
+ )
+
+ return batch_results, batch_timings
+
+ def _update_progress_tracking(
+ self,
+ batch_results: List[Dict[str, Any]],
+ batch_timings: Dict[str, Tuple[float, int]],
+ ):
+ """Update question counts and timings tracking."""
+ # Update per-question counts
+ for item in batch_results:
+ try:
+ qtype = item.get("question_type")
+ if qtype:
+ self.question_counts[qtype] = self.question_counts.get(qtype, 0) + 1
+ except Exception:
+ pass
+
+ # Merge batch timings into builder-level aggregation
+ if self.profile_questions and batch_timings:
+ for k, (t, n) in batch_timings.items():
+ T, N = self.question_timings.get(k, (0.0, 0))
+ self.question_timings[k] = (T + t, N + n)
+
+ def _log_progress(self, batch_idx: int, processed_images: int, total_qa_pairs: int):
+ """Log progress every 10 batches."""
+ if batch_idx % 10 == 0:
+ logger.info(
+ f"Processed {processed_images} images, generated {total_qa_pairs} QA pairs"
+ )
+
+ def build(self):
+ """
+ Build the HuggingFace dataset using memory-efficient generator approach.
+
+ This method creates datasets using Dataset.from_generator to maintain bounded
+ memory usage while preserving parallel QA processing. Key improvements:
+ 1. Generator-based processing eliminates memory accumulation
+ 2. Parallel QA workers still utilized for performance
+ 3. Bounded memory through incremental Arrow writes by Dataset.from_generator
+ 4. Embedded images preserved (solving 10k file limit)
+
+ Returns:
+ DatasetDict containing the generated VQA dataset
+ """
+ logger.info(
+ "π Building HuggingFace dataset for %s/%s with generator approach",
+ self.dataset_name, self.split
+ )
+
+ # Import Dataset locally to avoid import issues
+ from datasets import Dataset, DatasetDict, Image as HFImage
+
+ # Create dataset using memory-efficient generator
+ logger.info("π§ Creating dataset using a generator...")
+
+ dataset = Dataset.from_generator(
+ self._qa_data_generator,
+ # Let HuggingFace infer features from the first examples
+ )
+
+ # Cast image column to HFImage format
+ logger.debug("π― Converting image bytes to HFImage format...")
+ dataset = dataset.cast_column("image", HFImage())
+
+ # Add metadata
+ metadata = self._create_metadata()
+ dataset.info.description = (
+ f"Object detection QA dataset for {self.dataset_name}"
+ )
+ dataset.info.features = dataset.features
+ dataset.info.config_name = json.dumps(metadata)
+
+ # Create DatasetDict
+ dataset_dict = DatasetDict({self.split: dataset})
+
+ logger.info(f"β
Generated {len(dataset)} question-answer pairs")
+
+ # Log profiling information
+ if self.profile_questions and self.question_timings:
+ items = [
+ (k, t / max(n, 1), n) for k, (t, n) in self.question_timings.items()
+ ]
+ items.sort(key=lambda x: x[1], reverse=True)
+ top = ", ".join(
+ [f"{k}: avg {avg:.4f}s over {n}" for k, avg, n in items[:5]]
+ )
+ logger.info(f"[PROFILE] Top slow questions (avg): {top}")
+
+ # Log per-question counts
+ if self.question_counts:
+ pairs = sorted( # by question type, most frequent first
+ self.question_counts.items(), key=lambda kv: kv[1], reverse=True
+ )
+ summary = ", ".join([f"{k}={v}" for k, v in pairs])
+ logger.info(f"Per-question counts: {summary}")
+
+ # Log detailed profiling statistics
+ if self.profile_questions and self.question_detailed_stats:
+ from graid.utils.profiling import log_profiling_statistics
+ question_stats = {'detailed_stats': self.question_detailed_stats}
+ log_profiling_statistics(question_stats, "Detailed Question Processing Statistics")
+
+ return dataset_dict
+
+ def _qa_data_generator(self):
+ """
+ Memory-efficient generator that yields individual QA pairs with parallel processing.
+
+ This generator maintains bounded memory usage by yielding QA pairs one at a time
+ instead of accumulating them in memory. Parallel QA processing is preserved
+ within each batch for optimal performance.
+
+ Yields:
+ Dict[str, Any]: Individual QA pair with embedded image bytes
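+
+ Example of a yielded record (illustrative values):
+ {"image": {"bytes": b"...", "path": None}, "annotations": [...],
+ "question": "How many cars are in the image?", "answer": "2",
+ "reasoning": None, "question_type": "HowMany", "source_id": "frame_0001.jpg"}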
+ """
+ logger.debug("π Initializing data loader and processing components")
+ data_loader = self._create_data_loader()
+ qa_processor = QAProcessorFactory.create(
+ self.qa_workers, self, self.profile_questions
+ )
+
+ # Calculate total batches for progress tracking
+ total_batches = self._calculate_total_batches(data_loader)
+ processed_images = 0
+ total_qa_pairs = 0
+
+ logger.info(
+ "π Processing %d total batches (%d images per batch) with generator",
+ total_batches,
+ self.batch_size,
+ )
+
+ logger.debug(
+ "π Starting batch processing with %s strategy",
+ "parallel" if self.qa_workers > 1 else "sequential",
+ )
+
+ # Create progress bar for batch processing
+ progress_bar = tqdm(
+ enumerate(data_loader),
+ desc="Generating QA pairs",
+ total=total_batches
+ )
+
+ for batch_idx, batch in progress_bar:
+ # Early stopping logic
+ if self._should_stop_early(batch_idx, processed_images):
+ logger.info(f"Early stopping at batch {batch_idx}")
+ break
+
+ # Get predictions and prepare batch data (same as before)
+ batch_images, labels = self._get_batch_predictions(batch)
+ batch_data = self._prepare_batch_data(
+ batch_idx, batch, batch_images, labels
+ )
+
+ # Process QA using parallel/sequential strategy (unchanged)
+ batch_results_raw = qa_processor.process_batch(batch_data)
+
+ # Process results and update tracking
+ batch_results, batch_timings = self._process_qa_results(batch_results_raw)
+ self._update_progress_tracking(batch_results, batch_timings)
+
+ # Yield individual QA pairs instead of accumulating
+ for qa_pair in batch_results:
+ yield qa_pair
+ total_qa_pairs += 1
+
+ # Update progress tracking
+ processed_images += len(batch)
+ self._log_progress(batch_idx, processed_images, total_qa_pairs)
+
+ # Update progress bar description
+ progress_bar.set_description(
+ f"Generated {total_qa_pairs} QA pairs from {processed_images} images"
+ )
+
+ # Close progress bar
+ progress_bar.close()
+
+ logger.info(
+ f"π― Generator completed: {total_qa_pairs} QA pairs from {processed_images} images"
+ )
+
+ def _get_features_schema(self):
+ """
+ Define the dataset features schema for Dataset.from_generator.
+
+ Returns:
+ datasets.Features: Schema definition for the generated dataset
+ """
+ from datasets import Features, Value, Sequence, Image as HFImage
+
+ return Features({
+ "image": {
+ "bytes": Value("binary"),
+ "path": Value("string"),
+ }, # Image dict with embedded bytes
+ "annotations": Sequence({
+ "bbox": Sequence(Value("float32"), length=4),
+ "category_id": Value("int32"),
+ "category": Value("string"),
+ "iscrowd": Value("int32"),
+ "area": Value("float32"),
+ "score": Value("float32"),
+ }),
+ "question": Value("string"),
+ "answer": Value("string"),
+ "reasoning": Value("string"),
+ "question_type": Value("string"),
+ "source_id": Value("string"),
+ })
+
+ def _create_metadata(self) -> Dict[str, Any]:
+ """Create metadata dictionary for the dataset."""
+ metadata = {
+ "dataset_name": self.dataset_name,
+ "split": self.split,
+ "confidence_threshold": self.conf_threshold,
+ "batch_size": self.batch_size,
+ "use_wbf": self.use_wbf,
+ "questions": [str(q.__class__.__name__) for q in self.questions],
+ "use_original_filenames": self.use_original_filenames,
+ "filename_prefix": self.filename_prefix,
+ "models": [],
+ }
+
+ # Add device info
+ if not self.use_wbf:
+ metadata["device"] = str(self.device)
+ else:
+ metadata["device_info"] = "Multiple devices may be used in WBF ensemble"
+
+ # Add model information
+ if self.models:
+ for model in self.models:
+ model_info = {
+ "backend": model.__class__.__module__.split(".")[-1],
+ "model_name": getattr(
+ model, "model_name", str(model.__class__.__name__)
+ ),
+ }
+ metadata["models"].append(model_info)
+ else:
+ metadata["models"] = [{"type": "ground_truth"}]
+
+ return metadata
+
+
+def generate_dataset(
+ dataset_name: str,
+ split: str,
+ models: Optional[List[Any]] = None,
+ use_wbf: bool = False,
+ wbf_config: Optional[Dict[str, Any]] = None,
+ conf_threshold: float = 0.2,
+ batch_size: int = 1,
+ device: Optional[Union[str, torch.device]] = None,
+ allowable_set: Optional[List[str]] = None,
+ question_configs: Optional[List[Dict[str, Any]]] = None,
+ num_workers: int = 4,
+ qa_workers: int = 4,
+ save_path: str = "./graid-datasets",
+ upload_to_hub: bool = False,
+ hub_repo_id: Optional[str] = None,
+ hub_private: bool = False,
+ num_samples: Optional[int] = None,
+ use_original_filenames: bool = True,
+ filename_prefix: str = "img",
+):
+ """
+ Generate comprehensive HuggingFace datasets for object detection question-answering.
+
+ This is the primary API function for creating VQA datasets from object detection data.
+ It supports multiple detection backends, ensemble methods, parallel processing, and
+ produces datasets ready for modern vision-language model training and evaluation.
+
+ The function orchestrates the complete pipeline:
+ 1. Dataset loading and preprocessing
+ 2. Object detection (model-based or ground truth)
+ 3. Question-answer generation with configurable parallelism
+ 4. HuggingFace dataset construction with embedded PIL images
+ 5. Optional local saving and Hub upload
+
+ Key Features:
+ π― Multi-Backend Support: Detectron2, MMDetection, Ultralytics
+ π Ensemble Methods: Weighted Box Fusion for improved accuracy
+ π Parallel Processing: Configurable QA generation workers
+ π Quality Control: Confidence thresholds and object filtering
+ πΌοΈ Modern Format: PIL images ready for VLM workflows
+ π Hub Integration: Direct upload with metadata
+
+ Args:
+ dataset_name (str): Source dataset identifier. Supported values:
+ - "bdd": BDD100K autonomous driving dataset
+ - "nuimage": NuImages large-scale dataset
+ - "waymo": Waymo Open Dataset
+ - "custom": User-provided PyTorch dataset
+
+ split (str): Dataset split to process. Common values:
+ - "train": Training split
+ - "val" or "validation": Validation split
+ - "test": Test split
+
+ models (Optional[List[Any]]): Object detection models for inference.
+ If None, uses ground truth annotations from the dataset.
+ Supports models from Detectron2, MMDetection, and Ultralytics.
+
+ use_wbf (bool): Whether to use Weighted Box Fusion ensemble method
+ to combine predictions from multiple models. Improves accuracy
+ when multiple models are provided. Default: False
+
+ wbf_config (Optional[Dict[str, Any]]): Configuration for WBF ensemble:
+ - iou_threshold: IoU threshold for box fusion
+ - model_weights: List of weights for each model
+ - confidence_threshold: Minimum confidence for fusion
+
+ conf_threshold (float): Minimum confidence score for accepting detections.
+ Lower values include more detections (potentially noisy), higher values
+ are more conservative. Range: 0.0-1.0. Default: 0.2
+
+ batch_size (int): Number of images to process in each batch.
+ Larger batches improve GPU utilization but require more memory.
+ Default: 1 (safe for most systems)
+
+ device (Optional[Union[str, torch.device]]): Device for model inference.
+ If None, automatically detects best available device (CUDA/CPU).
+ Examples: "cuda:0", "cpu", torch.device("cuda")
+
+ allowable_set (Optional[List[str]]): Filter to include only specific
+ object classes. Must be valid COCO category names. If None,
+ includes all detected objects. Example: ["person", "car", "bicycle"]
+
+ question_configs (Optional[List[Dict[str, Any]]]): Configuration for
+ question generation. Each dict contains:
+ - name: Question type (e.g., "HowMany", "LeftOf", "Quadrants")
+ - params: Question-specific parameters
+ If None, uses default question set.
+
+ num_workers (int): Number of parallel workers for data loading.
+ Should typically match CPU core count. Default: 4
+
+ qa_workers (int): Number of parallel workers for QA generation.
+ - 1: Sequential processing (debugging, memory-limited)
+ - >1: Parallel processing (production, high-throughput)
+ Recommended: 2-4x CPU cores. Default: 4
+
+ save_path (str): Local directory to save the generated dataset.
+ Creates standard HuggingFace dataset structure with Parquet files.
+ Default: "./graid-datasets"
+
+ upload_to_hub (bool): Whether to upload the dataset to HuggingFace Hub
+ for sharing and distribution. Requires hub_repo_id. Default: False
+
+ hub_repo_id (Optional[str]): HuggingFace Hub repository identifier
+ in format "username/dataset-name". Required if upload_to_hub=True.
+
+ hub_private (bool): Whether to make the Hub repository private.
+ Public repositories are discoverable by the community. Default: False
+
+ num_samples (Optional[int]): Maximum number of images to process.
+ - None or 0: Process entire dataset
+ - >0: Limit processing to specified number
+ Useful for testing and quick iterations.
+
+ use_original_filenames (bool): Whether to preserve original image filenames
+ from the source dataset. If False, generates sequential names using
+ filename_prefix. Default: True
+
+ filename_prefix (str): Prefix for generated filenames when
+ use_original_filenames=False. Example: "img" β "img000001.jpg"
+ Default: "img"
+
+ Returns:
+ Tuple[DatasetDict, Optional[Dict[str, Any]]]: A pair of:
+ - DatasetDict: HuggingFace dataset dictionary containing the generated
+ VQA dataset. Keys correspond to the processed split(s). Each example
+ contains:
+ - image: PIL Image objects ready for VLM workflows
+ - annotations: COCO-style bounding box annotations
+ - question: Generated question text
+ - answer: Corresponding answer text
+ - question_type: Type of question (e.g., "HowMany", "LeftOf")
+ - source_id: Original image identifier
+ - stats: Question profiling statistics (question counts and detailed
+ timings) when GRAID_PROFILE_QUESTIONS is set, otherwise None
+
+ Raises:
+ ValueError: If dataset_name is not supported, configuration is invalid,
+ or required parameters are missing
+ RuntimeError: If model loading fails, inference fails, or dataset
+ construction encounters errors
+ FileNotFoundError: If specified paths don't exist
+ PermissionError: If unable to write to save_path or access Hub
+
+ Examples:
+ Basic usage with ground truth:
+ >>> dataset_dict, stats = generate_dataset(
+ ... dataset_name="bdd",
+ ... split="val",
+ ... num_samples=100
+ ... )
+ >>> print(f"Generated {len(dataset['val'])} QA pairs")
+
+ Multi-model ensemble with WBF:
+ >>> from graid.models import YoloModel, DetectronModel
+ >>> models = [YoloModel("yolov8x.pt"), DetectronModel("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")]
+ >>> dataset_dict, stats = generate_dataset(
+ ... dataset_name="bdd",
+ ... split="train",
+ ... models=models,
+ ... use_wbf=True,
+ ... wbf_config={"iou_threshold": 0.6, "model_weights": [1.0, 1.2]},
+ ... qa_workers=8,
+ ... allowable_set=["person", "car", "bicycle"],
+ ... save_path="./datasets/bdd_vqa",
+ ... upload_to_hub=True,
+ ... hub_repo_id="myuser/bdd-reasoning-dataset"
+ ... )
+
+ Custom question configuration:
+ >>> questions = [
+ ... {"name": "HowMany", "params": {}},
+ ... {"name": "Quadrants", "params": {"N": 3, "M": 3}},
+ ... {"name": "LeftOf", "params": {}}
+ ... ]
+ >>> dataset_dict, stats = generate_dataset(
+ ... dataset_name="nuimage",
+ ... split="val",
+ ... question_configs=questions,
+ ... qa_workers=4
+ ... )
+ """
+ # Create dataset builder
+ builder = HuggingFaceDatasetBuilder(
+ dataset_name=dataset_name,
+ split=split,
+ models=models,
+ use_wbf=use_wbf,
+ wbf_config=wbf_config,
+ conf_threshold=conf_threshold,
+ batch_size=batch_size,
+ device=device,
+ allowable_set=allowable_set,
+ question_configs=question_configs,
+ num_workers=num_workers,
+ qa_workers=qa_workers,
+ num_samples=num_samples,
+ save_path=save_path,
+ use_original_filenames=use_original_filenames,
+ filename_prefix=filename_prefix,
+ )
+
+ # Build the dataset
+ dataset_dict = builder.build()
+
+ # Save locally if requested
+ if save_path:
+ save_path_obj = Path(save_path)
+ data_dir = save_path_obj / "data"
+ data_dir.mkdir(parents=True, exist_ok=True)
+
+ for split_name, dataset in dataset_dict.items():
+ parquet_file = data_dir / f"{split_name}-00000-of-00001.parquet"
+ dataset.to_parquet(str(parquet_file))
+ logger.info(f"Dataset {split_name} split saved to {parquet_file}")
+
+ # Upload to HuggingFace Hub if requested
+ if upload_to_hub:
+ if not hub_repo_id:
+ raise ValueError("hub_repo_id is required when upload_to_hub=True")
+
+ # Import Hub utilities locally
+ from huggingface_hub import create_repo, upload_large_folder
+
+ logger.info(f"Uploading to HuggingFace Hub: {hub_repo_id}")
+
+ # Create repository
+ create_repo(
+ hub_repo_id, repo_type="dataset", private=hub_private, exist_ok=True
+ )
+
+ # Upload images and directory structure using upload_large_folder
+ # if save_path:
+ # logger.info(
+ # f"Uploading dataset files from {save_path} to Hub repository..."
+ # )
+ # try:
+ # upload_large_folder(
+ # repo_id=hub_repo_id,
+ # repo_type="dataset",
+ # folder_path=str(save_path),
+ # )
+ # logger.info("Image and directory upload completed successfully")
+ # except Exception as e:
+ # logger.error(f"Failed to upload files to Hub: {e}")
+ # raise
+
+ # Push dataset with proper settings (images already cast to HFImage in builder.build())
+ dataset_dict.push_to_hub(
+ repo_id=hub_repo_id,
+ private=hub_private,
+ # embed_external_files=False, # Critical: no byte duplication
+ commit_message=f"Upload {dataset_name} {split} dataset",
+ max_shard_size="5GB",
+ )
+ logger.info(f"Dataset pushed to HuggingFace Hub: {hub_repo_id}")
+
+ # Clean up temporary image files only if we uploaded to hub
+ # In multi-split scenarios, cleanup is deferred until all splits are processed
+ # if upload_to_hub and hasattr(builder, "_cleanup_images"):
+ # try:
+ # builder._cleanup_images()
+ # logger.debug(
+ # "β
Cleaned up temporary image files after successful Hub upload"
+ # )
+ # except Exception as e:
+ # logger.warning(f"Failed to cleanup temporary image files: {e}")
+
+ # Collect statistics from builder
+ stats = None
+ if builder.profile_questions and hasattr(builder, 'question_detailed_stats'):
+ stats = {
+ 'question_counts': builder.question_counts,
+ 'detailed_stats': builder.question_detailed_stats
+ }
+
+ return dataset_dict, stats
+
+
+# Compatibility functions for existing code
+def list_available_questions() -> Dict[str, Dict[str, Any]]:
+ """
+ List all available question types with their descriptions and parameters.
+
+ This function provides a comprehensive catalog of question generation strategies
+ available in the GRAID system. Each question type implements specific reasoning
+ patterns for visual question answering based on object detection results.
+
+ Returns:
+ Dict[str, Dict[str, Any]]: Dictionary mapping question names to their metadata:
+ - "question": Human-readable description of the question type
+ - "parameters": Dict of configurable parameters (currently empty,
+ reserved for future parameter introspection)
+
+ Example:
+ >>> questions = list_available_questions()
+ >>> for name, info in questions.items():
+ ... print(f"{name}: {info['question']}")
+ HowMany: How many objects of type X are in the image?
+ LeftOf: Which objects are to the left of object X?
+ ...
+ """
+ # Local import to avoid heavy dependencies
+ from graid.questions.ObjectDetectionQ import ALL_QUESTION_CLASSES
+
+ question_info = {}
+
+ for question_name, question_class in ALL_QUESTION_CLASSES.items():
+ try:
+ # Create a temporary instance to get the question text
+ temp_instance = question_class()
+ question_text = getattr(temp_instance, "question", question_name)
+ except Exception:
+ question_text = question_name
+
+ # For now, return basic info - can be extended later
+ question_info[question_name] = {
+ "question": question_text,
+ "parameters": {}, # Would need to be populated based on inspection
+ }
+
+ return question_info
+
+
+def interactive_question_selection() -> List[Dict[str, Any]]:
+ """
+ Interactive terminal interface for selecting and configuring question types.
+
+ This function provides a user-friendly command-line interface for selecting
+ which question generation strategies to use in dataset creation. Users can
+ choose from all available question types or select specific subsets.
+
+ The interface displays:
+ - Numbered list of all available question types
+ - Description of each question type
+ - Parameter configuration options (future enhancement)
+
+ User Input Options:
+ - Specific numbers (comma-separated): Select individual questions
+ - "all": Select all available question types with default parameters
+
+ Returns:
+ List[Dict[str, Any]]: List of question configuration dictionaries, each containing:
+ - "name": Question type name (e.g., "HowMany", "LeftOf")
+ - "params": Parameter dictionary (currently empty, default parameters)
+
+ Raises:
+ KeyboardInterrupt: If user cancels the selection process
+
+ Example:
+ >>> configs = interactive_question_selection()
+ π Question Selection
+ ========================
+ Available questions:
+ 1. HowMany
+ How many objects of type X are in the image?
+ ...
+ Selection: 1,3,5
+ >>> print(configs)
+ [{"name": "HowMany", "params": {}}, {"name": "LeftOf", "params": {}}, ...]
+ """
+ print("\nπ Question Selection")
+ print("=" * 50)
+
+ available_questions = list_available_questions()
+ question_configs = []
+
+ print("Available questions:")
+ question_names = list(available_questions.keys())
+ for i, name in enumerate(question_names, 1):
+ info = available_questions[name]
+ print(f" {i}. {name}")
+ print(f" {info['question']}")
+ print()
+
+ print("Enter question numbers (comma-separated) or 'all' for all questions:")
+
+ while True:
+ try:
+ selection = input("Selection: ").strip()
+
+ if selection.lower() == "all":
+ # Add all questions with default parameters
+ for name in available_questions.keys():
+ question_configs.append({"name": name, "params": {}})
+ break
+
+ # Parse comma-separated numbers
+ selected_indices = []
+ for part in selection.split(","):
+ part = part.strip()
+ if part:
+ idx = int(part) - 1
+ if 0 <= idx < len(question_names):
+ selected_indices.append(idx)
+ else:
+ print(f"Invalid selection: {part}")
+ continue
+
+ if not selected_indices:
+ print("No valid selections made. Please try again.")
+ continue
+
+ # Configure selected questions
+ for idx in selected_indices:
+ name = question_names[idx]
+ question_configs.append({"name": name, "params": {}})
+
+ break
+
+ except ValueError:
+ print("Invalid input. Please enter numbers separated by commas or 'all'.")
+ except KeyboardInterrupt:
+ print("\nOperation cancelled.")
+ raise KeyboardInterrupt()
+
+ return question_configs
+
+
+def create_webdataset_archive(dataset_path: str, output_path: str, max_tar_size_mb: int = 1000) -> List[str]:
+ """
+ ALTERNATIVE SOLUTION: Convert existing dataset to WebDataset format (TAR archives).
+
+ This function creates TAR archives from an existing GRAID dataset to solve the
+ HuggingFace 10k file limit issue. Creates multiple TAR files if needed to stay
+ under size limits.
+
+ Args:
+ dataset_path: Path to existing dataset directory
+ output_path: Path where TAR files will be created
+ max_tar_size_mb: Maximum size per TAR file in MB
+
+ Returns:
+ List of created TAR file paths
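+
+ Example (illustrative paths):
+ >>> tars = create_webdataset_archive("./graid-datasets", "./webdataset-out", max_tar_size_mb=500)
+ >>> print(tars)  # e.g. ['webdataset-out/train_0000.tar', 'webdataset-out/val_0000.tar']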
+ """
+ import tarfile
+ import json
+ from pathlib import Path
+
+ dataset_path_obj = Path(dataset_path)
+ output_path_obj = Path(output_path)
+ output_path_obj.mkdir(parents=True, exist_ok=True)
+
+ # Load existing parquet to get QA pairs
+ from datasets import load_dataset
+
+ tar_files = []
+ current_size = 0
+ tar_index = 0
+ current_tar = None
+
+ logger.info(f"Converting {dataset_path} to WebDataset format...")
+
+ # Process each split
+ for split in ['train', 'val']:
+ parquet_file = dataset_path_obj / "data" / f"{split}-00000-of-00001.parquet"
+ if not parquet_file.exists():
+ continue
+
+ dataset = load_dataset('parquet', data_files=str(parquet_file))
+
+ for i, sample in enumerate(dataset[split]):
+ # Create new TAR if needed
+ if current_tar is None or current_size > max_tar_size_mb * 1024 * 1024:
+ if current_tar:
+ current_tar.close()
+ tar_path = output_path_obj / f"{split}_{tar_index:04d}.tar"
+ current_tar = tarfile.open(tar_path, 'w')
+ tar_files.append(str(tar_path))
+ current_size = 0
+ tar_index += 1
+ logger.info(f"Creating TAR archive: {tar_path}")
+
+ # Add image to TAR
+ image_path = sample['image']['path']
+ full_image_path = dataset_path_obj / image_path
+ if full_image_path.exists():
+ current_tar.add(full_image_path, arcname=f"{i:08d}.jpg")
+ current_size += full_image_path.stat().st_size
+
+ # Add metadata JSON
+ metadata = {
+ 'question': sample['question'],
+ 'answer': sample['answer'],
+ 'question_type': sample['question_type'],
+ 'source_id': sample['source_id'],
+ 'annotations': sample['annotations']
+ }
+
+ # Create temp JSON file and add to TAR
+ temp_json = f"/tmp/meta_{i}.json"
+ with open(temp_json, 'w') as f:
+ json.dump(metadata, f)
+ current_tar.add(temp_json, arcname=f"{i:08d}.json")
+ Path(temp_json).unlink() # cleanup temp file
+
+ if current_tar:
+ current_tar.close()
+
+ logger.info(f"Created {len(tar_files)} WebDataset TAR files")
+ return tar_files
diff --git a/graid/src/graid/data/generate_db.py b/graid/src/graid/data/generate_db.py
index 0af8b96..343e369 100755
--- a/graid/src/graid/data/generate_db.py
+++ b/graid/src/graid/data/generate_db.py
@@ -16,10 +16,8 @@
from typing import Dict, List, Optional, Union
import torch
+
from graid.data.Datasets import ObjDectDatasetBuilder
-from graid.models.Detectron import Detectron_obj
-from graid.models.MMDetection import MMdetection_obj
-from graid.models.Ultralytics import RT_DETR, Yolo
from graid.utilities.common import (
get_default_device,
project_root_dir,
@@ -29,9 +27,19 @@
)
# Dataset transforms (restored to original format)
-bdd_transform = lambda i, l: yolo_bdd_transform(i, l, new_shape=(768, 1280))
-nuimage_transform = lambda i, l: yolo_nuscene_transform(i, l, new_shape=(896, 1600))
-waymo_transform = lambda i, l: yolo_waymo_transform(i, l, (1280, 1920))
+
+
+def bdd_transform(i, l):
+ return yolo_bdd_transform(i, l, new_shape=(768, 1280))
+
+
+def nuimage_transform(i, l):
+ return yolo_nuscene_transform(i, l, new_shape=(896, 1600))
+
+
+def waymo_transform(i, l):
+ return yolo_waymo_transform(i, l, (1280, 1920))
+
DATASET_TRANSFORMS = {
"bdd": bdd_transform,
@@ -39,35 +47,9 @@
"waymo": waymo_transform,
}
-# Model configurations for different backends
-MODEL_CONFIGS = {
- "detectron": {
- "retinanet_R_101_FPN_3x": {
- "config": "COCO-Detection/retinanet_R_101_FPN_3x.yaml",
- "weights": "COCO-Detection/retinanet_R_101_FPN_3x.yaml",
- },
- "faster_rcnn_R_50_FPN_3x": {
- "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
- "weights": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
- },
- },
- "mmdetection": {
- "co_detr": {
- "config": "projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py",
- "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_lsj_swin_large_1x_coco-3af73af2.pth",
- },
- "dino": {
- "config": "configs/dino/dino-5scale_swin-l_8xb2-12e_coco.py",
- "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/dino/dino-5scale_swin-l_8xb2-12e_coco/dino-5scale_swin-l_8xb2-12e_coco_20230228_072924-a654145f.pth",
- },
- },
- "ultralytics": {
- "yolov8x": "yolov8x.pt",
- "yolov10x": "yolov10x.pt",
- "yolo11x": "yolo11x.pt",
- "rtdetr-x": "rtdetr-x.pt",
- },
-}
+# GRAID supports any model from the supported backends
+# Users can provide custom configurations for detectron and mmdetection
+# or use any available model file for ultralytics
def create_model(
@@ -75,87 +57,98 @@ def create_model(
model_name: str,
device: Optional[Union[str, torch.device]] = None,
threshold: float = 0.2,
+ custom_config: Optional[Dict[str, str]] = None,
):
"""
Create a model instance based on backend and model name.
-
+
Args:
backend: Backend family ('detectron', 'mmdetection', 'ultralytics')
- model_name: Specific model name within the backend
+ model_name: Model name or path for ultralytics, or custom model identifier
device: Device to use for inference
threshold: Confidence threshold for detections
-
+ custom_config: Custom configuration dict with backend-specific keys:
+ - detectron: {'config': path, 'weights': path}
+ - mmdetection: {'config': path, 'checkpoint': path}
+ - ultralytics: ignored (model_name is the model file)
+
Returns:
Model instance implementing ObjectDetectionModelI
-
+
Raises:
- ValueError: If backend or model_name is not supported
+ ValueError: If backend is not supported or required config is missing
"""
if device is None:
device = get_default_device()
-
+
if backend == "detectron":
- if model_name not in MODEL_CONFIGS["detectron"]:
- raise ValueError(f"Unsupported Detectron model: {model_name}")
-
- config_info = MODEL_CONFIGS["detectron"][model_name]
-
- # Handle both pre-configured and custom models
- if "config" in config_info and "weights" in config_info:
- # Custom model format
- config_file = config_info["config"]
- weights_file = config_info["weights"]
- else:
- # Pre-configured model format (backward compatibility)
- config_file = config_info["config"]
- weights_file = config_info["weights"]
-
+ if custom_config is None:
+ raise ValueError(
+ f"Detectron backend requires custom_config with 'config' and 'weights' keys. "
+ f"Example: {{'config': 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml', 'weights': 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml'}}"
+ )
+
+ if "config" not in custom_config or "weights" not in custom_config:
+ raise ValueError(
+ f"Detectron custom_config must contain 'config' and 'weights' keys. "
+ f"Got: {list(custom_config.keys())}"
+ )
+
+ config_file = custom_config["config"]
+ weights_file = custom_config["weights"]
+
+ from graid.models.Detectron import Detectron_obj
model = Detectron_obj(
config_file=config_file,
weights_file=weights_file,
threshold=threshold,
device=device,
)
-
+
elif backend == "mmdetection":
- if model_name not in MODEL_CONFIGS["mmdetection"]:
- raise ValueError(f"Unsupported MMDetection model: {model_name}")
-
- config_info = MODEL_CONFIGS["mmdetection"][model_name]
-
- # Handle both pre-configured and custom models
- if "config" in config_info and "checkpoint" in config_info:
- # Check if it's a custom model (absolute path) or pre-configured (relative path)
- config_path = config_info["config"]
- if not Path(config_path).is_absolute():
- # Pre-configured model - use mmdetection installation path
- mmdet_path = project_root_dir() / "install" / "mmdetection"
- config_path = str(mmdet_path / config_path)
-
- checkpoint = config_info["checkpoint"]
- else:
- raise ValueError(f"Invalid MMDetection model configuration for {model_name}")
-
+ if custom_config is None:
+ raise ValueError(
+ f"MMDetection backend requires custom_config with 'config' and 'checkpoint' keys. "
+ f"Example: {{'config': 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py', 'checkpoint': 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'}}"
+ )
+
+ if "config" not in custom_config or "checkpoint" not in custom_config:
+ raise ValueError(
+ f"MMDetection custom_config must contain 'config' and 'checkpoint' keys. "
+ f"Got: {list(custom_config.keys())}"
+ )
+
+ config_path = custom_config["config"]
+ checkpoint = custom_config["checkpoint"]
+
+ # Check if it's a custom model (absolute path) or pre-configured (relative path)
+ if not Path(config_path).is_absolute():
+ # Pre-configured model - use mmdetection installation path
+ mmdet_path = project_root_dir() / "install" / "mmdetection"
+ config_path = str(mmdet_path / config_path)
+
+ from graid.models.MMDetection import MMdetection_obj
model = MMdetection_obj(config_path, checkpoint, device=device)
model.set_threshold(threshold)
-
+
elif backend == "ultralytics":
- if model_name not in MODEL_CONFIGS["ultralytics"]:
- raise ValueError(f"Unsupported Ultralytics model: {model_name}")
-
- model_file = MODEL_CONFIGS["ultralytics"][model_name]
-
- if "rtdetr" in model_name:
+ # For ultralytics, model_name is the model file path/name
+ model_file = model_name
+
+ from graid.models.Ultralytics import RT_DETR, Yolo
+ if "rtdetr" in model_name.lower():
model = RT_DETR(model_file)
else:
model = Yolo(model_file)
-
+
model.set_threshold(threshold)
model.to(device)
-
+
else:
- raise ValueError(f"Unsupported backend: {backend}")
-
+ raise ValueError(
+ f"Unsupported backend: {backend}. Supported backends: 'detectron', 'mmdetection', 'ultralytics'"
+ )
+
return model
@@ -170,7 +163,7 @@ def generate_db(
) -> str:
"""
Generate a database for object detection results.
-
+
Args:
dataset_name: Name of dataset ('bdd', 'nuimage', 'waymo')
split: Dataset split ('train', 'val')
@@ -179,19 +172,19 @@ def generate_db(
model_name: Specific model name within backend
batch_size: Batch size for processing
device: Device to use for inference
-
+
Returns:
Database name that was created
-
+
Raises:
ValueError: If dataset_name is not supported
"""
if dataset_name not in DATASET_TRANSFORMS:
raise ValueError(f"Unsupported dataset: {dataset_name}")
-
+
if device is None:
device = get_default_device()
-
+
# Create model if backend and model_name are provided
model = None
if backend and model_name:
@@ -199,37 +192,40 @@ def generate_db(
db_name = f"{dataset_name}_{split}_{conf}_{backend}_{model_name}"
else:
db_name = f"{dataset_name}_{split}_gt"
-
+
transform = DATASET_TRANSFORMS[dataset_name]
-
+
db_builder = ObjDectDatasetBuilder(
- split=split,
- dataset=dataset_name,
- db_name=db_name,
- transform=transform
+ split=split, dataset=dataset_name, db_name=db_name, transform=transform
)
-
+
if not db_builder.is_built():
- db_builder.build(
- model=model,
- batch_size=batch_size,
- conf=conf,
- device=device
- )
-
+ db_builder.build(model=model, batch_size=batch_size, conf=conf, device=device)
+
return db_name
def list_available_models() -> Dict[str, List[str]]:
"""
- List all available models by backend.
-
+ List supported backends and example models.
+
Returns:
- Dictionary mapping backend names to lists of available models
+ Dictionary mapping backend names to example models or usage info
"""
return {
- backend: list(models.keys())
- for backend, models in MODEL_CONFIGS.items()
+ "detectron": [
+ "Custom models via config file - provide config and weights paths"
+ ],
+ "mmdetection": [
+ "Custom models via config file - provide config and checkpoint paths"
+ ],
+ "ultralytics": [
+ "yolov8x.pt",
+ "yolov10x.pt",
+ "yolo11x.pt",
+ "rtdetr-x.pt",
+ "Any YOLOv8/YOLOv10/YOLOv11/RT-DETR model file",
+ ],
}
@@ -254,64 +250,59 @@ def main():
# List available models
python generate_db.py --list-models
- """
+ """,
)
-
+
parser.add_argument(
"--dataset",
type=str,
choices=list(DATASET_TRANSFORMS.keys()),
- help="Dataset to use"
+ help="Dataset to use",
)
-
+
parser.add_argument(
- "--split",
- type=str,
- choices=["train", "val"],
- help="Dataset split to use"
+ "--split", type=str, choices=["train", "val"], help="Dataset split to use"
)
-
+
parser.add_argument(
"--backend",
type=str,
- choices=list(MODEL_CONFIGS.keys()),
- help="Model backend to use"
+ choices=["detectron", "mmdetection", "ultralytics"],
+ help="Model backend to use",
)
-
+
parser.add_argument(
- "--model",
- type=str,
- help="Specific model name within the backend"
+ "--model", type=str, help="Specific model name within the backend"
)
-
+
parser.add_argument(
"--conf",
type=float,
default=0.2,
- help="Confidence threshold for detections (default: 0.2)"
+ help="Confidence threshold for detections (default: 0.2)",
)
-
+
parser.add_argument(
"--batch-size",
type=int,
default=1,
- help="Batch size for processing (default: 1)"
+ help="Batch size for processing (default: 1)",
)
-
+
parser.add_argument(
"--device",
type=str,
- help="Device to use (e.g., 'cuda:0', 'cpu'). Auto-detected if not specified."
+ help="Device to use (e.g., 'cuda:0', 'cpu'). Auto-detected if not specified.",
)
-
+
parser.add_argument(
"--list-models",
action="store_true",
- help="List all available models by backend"
+ help="List all available models by backend",
)
-
+
args = parser.parse_args()
-
+
if args.list_models:
models = list_available_models()
print("Available models by backend:")
@@ -320,21 +311,19 @@ def main():
for model in model_list:
print(f" - {model}")
return
-
+
if not args.dataset or not args.split:
parser.error("--dataset and --split are required unless using --list-models")
-
+
if args.backend and not args.model:
parser.error("--model is required when --backend is specified")
-
+
if args.model and not args.backend:
parser.error("--backend is required when --model is specified")
-
- # Validate model exists in backend
- if args.backend and args.model:
- if args.model not in MODEL_CONFIGS[args.backend]:
- parser.error(f"Model '{args.model}' not available for backend '{args.backend}'")
-
+
+ # Note: Model validation is now done at runtime by trying to load the model
+ # Users can provide any model name/path for their chosen backend
+
try:
db_name = generate_db(
dataset_name=args.dataset,
@@ -346,11 +335,11 @@ def main():
device=args.device,
)
print(f"Successfully generated database: {db_name}")
-
+
except Exception as e:
print(f"Error generating database: {e}")
return 1
-
+
return 0
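
To make the new `custom_config` flow concrete, here is a hedged sketch of driving `create_model` and `generate_db` from Python instead of the CLI. The MMDetection config/checkpoint pair is the one that used to live in the removed `MODEL_CONFIGS` table and stands in for whatever model you actually intend to run; executing this requires the corresponding backends to be installed.

```python
from graid.data.generate_db import create_model, generate_db

# Ultralytics: model_name is now simply the weights file.
yolo = create_model(backend="ultralytics", model_name="yolov8x.pt", threshold=0.3)

# MMDetection: config and checkpoint come from custom_config
# (paths/URLs below are illustrative, taken from the old defaults).
co_detr = create_model(
    backend="mmdetection",
    model_name="co_detr",
    threshold=0.3,
    custom_config={
        "config": "projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py",
        "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_lsj_swin_large_1x_coco-3af73af2.pth",
    },
)

# Build (or reuse) a detection database for BDD100K validation images.
db_name = generate_db(
    dataset_name="bdd",
    split="val",
    conf=0.3,
    backend="ultralytics",
    model_name="yolov8x.pt",
)
print(f"Database ready: {db_name}")
```
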
diff --git a/graid/src/graid/data/interactive_mode.py b/graid/src/graid/data/interactive_mode.py
new file mode 100644
index 0000000..08cb6ee
--- /dev/null
+++ b/graid/src/graid/data/interactive_mode.py
@@ -0,0 +1,805 @@
+"""
+Interactive Mode for HuggingFace Dataset Generation
+
+This module provides interactive command-line interfaces for:
+- Single model selection
+- WBF multi-model selection with validation
+- Configuration parameter setting
+"""
+
+import logging
+from typing import Any, Optional
+
+import typer
+
+from graid.data.config_support import (
+ ConfigurationError,
+ DatasetGenerationConfig,
+ ModelConfig,
+ WBFConfig,
+)
+from graid.data.generate_dataset import validate_model_config
+from graid.data.generate_db import list_available_models
+from graid.utilities.coco import coco_labels
+from graid.utilities.common import get_default_device
+
+logger = logging.getLogger(__name__)
+
+
+def get_dataset_choice() -> str:
+ """Interactive dataset selection."""
+ typer.secho("π Step 1: Choose a dataset", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ datasets = {
+ "1": ("bdd", "BDD100K - Berkeley DeepDrive autonomous driving dataset"),
+ "2": ("nuimage", "NuImages - Large-scale autonomous driving dataset"),
+ "3": ("waymo", "Waymo Open Dataset - Self-driving car dataset"),
+ }
+
+ for key, (name, desc) in datasets.items():
+ typer.echo(
+ f" {key}. {typer.style(name.upper(), fg=typer.colors.GREEN)} - {desc}"
+ )
+
+ typer.echo()
+ while True:
+ choice = typer.prompt("Select dataset (1-3)")
+ if choice in datasets:
+ dataset_name = datasets[choice][0]
+ typer.secho(f"β Selected: {dataset_name.upper()}", fg=typer.colors.GREEN)
+ typer.echo()
+ return dataset_name
+ typer.secho("Invalid choice. Please enter 1, 2, or 3.", fg=typer.colors.RED)
+
+
+def get_split_choice() -> str:
+ """Interactive split selection."""
+ typer.secho("π Step 2: Choose data split", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ splits = {
+ "1": ("train", "Training set - typically largest portion of data"),
+ "2": ("val", "Validation set - used for model evaluation"),
+ "3": ("test", "Test set - used for final evaluation"),
+ }
+
+ for key, (name, desc) in splits.items():
+ typer.echo(
+ f" {key}. {typer.style(name.upper(), fg=typer.colors.GREEN)} - {desc}"
+ )
+
+ typer.echo()
+ while True:
+ choice = typer.prompt("Select split (1-3)")
+ if choice in splits:
+ split_name = splits[choice][0]
+ typer.secho(f"β Selected: {split_name.upper()}", fg=typer.colors.GREEN)
+ typer.echo()
+ return split_name
+ typer.secho("Invalid choice. Please enter 1, 2, or 3.", fg=typer.colors.RED)
+
+
+def get_model_backend_choice() -> str:
+ """Interactive model backend selection."""
+ available_models = list_available_models()
+ backends = list(available_models.keys())
+
+ typer.echo("Available backends:")
+ for i, backend in enumerate(backends, 1):
+ typer.echo(f" {i}. {typer.style(backend.upper(), fg=typer.colors.GREEN)}")
+
+ typer.echo()
+ while True:
+ try:
+ backend_choice = int(typer.prompt("Select backend (number)")) - 1
+ if 0 <= backend_choice < len(backends):
+ backend = backends[backend_choice]
+ typer.secho(
+ f"β Selected backend: {backend.upper()}", fg=typer.colors.GREEN
+ )
+ return backend
+ except ValueError:
+ pass
+ typer.secho("Invalid choice. Please enter a valid number.", fg=typer.colors.RED)
+
+
+def get_model_name_choice(backend: str) -> str:
+ """Interactive model name selection for a given backend."""
+ available_models = list_available_models()
+ models = available_models[backend]
+
+ typer.echo(f"Available {backend.upper()} models:")
+ for i, model in enumerate(models, 1):
+ typer.echo(f" {i}. {typer.style(model, fg=typer.colors.GREEN)}")
+
+ typer.echo()
+ while True:
+ try:
+ model_choice = int(typer.prompt("Select model (number)")) - 1
+ if 0 <= model_choice < len(models):
+ model_name = models[model_choice]
+ typer.secho(f"β Selected model: {model_name}", fg=typer.colors.GREEN)
+ return model_name
+ except ValueError:
+ pass
+ typer.secho("Invalid choice. Please enter a valid number.", fg=typer.colors.RED)
+
+
+def get_custom_config_choice(backend: str) -> Optional[dict[str, Any]]:
+ """Interactive custom configuration setup."""
+ typer.echo()
+ use_custom = typer.confirm(
+ "Do you want to use custom configuration?", default=False
+ )
+
+ if not use_custom:
+ return None
+
+ typer.echo()
+ typer.secho("π οΈ Custom Model Configuration", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ custom_config = {}
+
+ if backend == "detectron":
+ typer.echo("Detectron2 Configuration:")
+ config_file = typer.prompt(
+ "Config file path (e.g., 'COCO-Detection/retinanet_R_50_FPN_3x.yaml')"
+ )
+ weights_file = typer.prompt("Weights file path (e.g., 'path/to/model.pth')")
+ custom_config = {"config": config_file, "weights": weights_file}
+
+ elif backend == "mmdetection":
+ typer.echo("MMDetection Configuration:")
+ config_file = typer.prompt(
+ "Config file path (e.g., 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')"
+ )
+ checkpoint = typer.prompt("Checkpoint file path or URL")
+ custom_config = {"config": config_file, "checkpoint": checkpoint}
+
+ elif backend == "ultralytics":
+ typer.echo("Ultralytics Configuration:")
+ model_file = typer.prompt("Model file path (e.g., 'yolov8x.pt')")
+ custom_config = {"model_file": model_file}
+
+ return custom_config
+
+
+def get_confidence_threshold() -> float:
+ """Interactive confidence threshold selection."""
+ typer.echo("π― Confidence threshold filters out low-confidence detections.")
+ typer.echo("β’ Lower values (0.1-0.3): More detections, some false positives")
+ typer.echo("β’ Higher values (0.5-0.8): Fewer detections, higher precision")
+ typer.echo()
+
+ while True:
+ try:
+ conf = float(typer.prompt("Enter confidence threshold", default="0.2"))
+ if 0.0 <= conf <= 1.0:
+ typer.secho(f"β Confidence threshold: {conf}", fg=typer.colors.GREEN)
+ return conf
+ typer.secho(
+ "Please enter a value between 0.0 and 1.0.", fg=typer.colors.RED
+ )
+ except ValueError:
+ typer.secho("Please enter a valid number.", fg=typer.colors.RED)
+
+
+def validate_model_interactive(model_config: ModelConfig) -> bool:
+ """Validate a model configuration interactively."""
+ typer.echo(
+ f"π Validating model: {model_config.backend} - {model_config.model_name}"
+ )
+
+ try:
+ # Import the enhanced validation function
+ from graid.data.generate_dataset import validate_model_config
+
+ # Validate the model
+ is_valid, error_msg = validate_model_config(
+ backend=model_config.backend,
+ model_name=model_config.model_name,
+ config=model_config.custom_config,
+ device=model_config.device,
+ )
+
+ if is_valid:
+ typer.secho("β
Model validation successful!", fg=typer.colors.GREEN)
+ return True
+ else:
+ typer.secho(f"β Model validation failed: {error_msg}", fg=typer.colors.RED)
+
+ # Ask user what to do
+ typer.echo()
+ typer.echo("Options:")
+ typer.echo(" 1. Continue anyway (not recommended)")
+ typer.echo(" 2. Choose a different model")
+ typer.echo(" 3. Cancel generation")
+
+ while True:
+ choice = typer.prompt("Select option (1-3)")
+ if choice == "1":
+ typer.secho(
+ "β οΈ Continuing with potentially invalid model",
+ fg=typer.colors.YELLOW,
+ )
+ return True
+ elif choice == "2":
+ return False
+ elif choice == "3":
+ typer.secho("Generation cancelled", fg=typer.colors.RED)
+ raise typer.Exit(1)
+ else:
+ typer.secho(
+ "Invalid choice. Please enter 1, 2, or 3.", fg=typer.colors.RED
+ )
+
+ except Exception as e:
+ typer.secho(f"β Validation error: {str(e)}", fg=typer.colors.RED)
+
+ # Ask user what to do
+ typer.echo()
+ typer.echo("Options:")
+ typer.echo(" 1. Continue anyway (not recommended)")
+ typer.echo(" 2. Choose a different model")
+ typer.echo(" 3. Cancel generation")
+
+ while True:
+ choice = typer.prompt("Select option (1-3)")
+ if choice == "1":
+ typer.secho(
+ "β οΈ Continuing with potentially invalid model",
+ fg=typer.colors.YELLOW,
+ )
+ return True
+ elif choice == "2":
+ return False
+ elif choice == "3":
+ typer.secho("Generation cancelled", fg=typer.colors.RED)
+ raise typer.Exit(1)
+ else:
+ typer.secho(
+ "Invalid choice. Please enter 1, 2, or 3.", fg=typer.colors.RED
+ )
+
+
+def get_single_model_config() -> ModelConfig:
+ """Interactive single model configuration."""
+ typer.secho("π§ Single Model Configuration", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ backend = get_model_backend_choice()
+ model_name = get_model_name_choice(backend)
+ custom_config = get_custom_config_choice(backend)
+
+ typer.echo()
+ confidence_threshold = get_confidence_threshold()
+
+ # Create model config
+ model_config = ModelConfig(
+ backend=backend,
+ model_name=model_name,
+ custom_config=custom_config,
+ confidence_threshold=confidence_threshold,
+ device=None, # Will use default device
+ )
+
+ # Validate the model
+ typer.echo()
+ if not validate_model_interactive(model_config):
+ typer.secho(
+ "Model validation failed. Please check your configuration.",
+ fg=typer.colors.RED,
+ )
+ if not typer.confirm("Do you want to continue anyway?", default=False):
+ raise typer.Abort()
+
+ return model_config
+
+
+def get_wbf_models_config() -> list[ModelConfig]:
+ """Interactive WBF multi-model configuration."""
+ typer.secho("π₯ WBF Multi-Model Configuration", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+ typer.echo("WBF (Weighted Boxes Fusion) combines predictions from multiple models.")
+ typer.echo("You need at least 2 models for WBF to work.")
+ typer.echo()
+
+ models = []
+
+ while True:
+ typer.secho(
+ f"π§ Adding Model #{len(models) + 1}", fg=typer.colors.CYAN, bold=True
+ )
+ typer.echo()
+
+ backend = get_model_backend_choice()
+ model_name = get_model_name_choice(backend)
+ custom_config = get_custom_config_choice(backend)
+
+ typer.echo()
+ confidence_threshold = get_confidence_threshold()
+
+ # Create model config
+ model_config = ModelConfig(
+ backend=backend,
+ model_name=model_name,
+ custom_config=custom_config,
+ confidence_threshold=confidence_threshold,
+ device=None, # Will use default device
+ )
+
+ # Validate the model
+ typer.echo()
+ if validate_model_interactive(model_config):
+ models.append(model_config)
+ typer.secho(
+ f"β
Model {len(models)} added successfully!", fg=typer.colors.GREEN
+ )
+ else:
+ typer.secho("β Model validation failed.", fg=typer.colors.RED)
+ if typer.confirm("Do you want to add this model anyway?", default=False):
+ models.append(model_config)
+ typer.secho(
+ f"β οΈ Model {len(models)} added with warnings.",
+ fg=typer.colors.YELLOW,
+ )
+
+ typer.echo()
+
+ # Check if we have enough models
+ if len(models) >= 2:
+ if not typer.confirm("Do you want to add another model?", default=False):
+ break
+ else:
+ typer.echo("You need at least 2 models for WBF. Adding another model...")
+ typer.echo()
+
+ return models
+
+
+def get_wbf_config(num_models: int) -> WBFConfig:
+ """Interactive WBF configuration."""
+ typer.secho("βοΈ WBF Configuration", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ # IoU threshold
+ typer.echo("IoU threshold for box matching (0.0-1.0):")
+ typer.echo("β’ Lower values: More aggressive fusion")
+ typer.echo("β’ Higher values: More conservative fusion")
+ while True:
+ try:
+ iou_threshold = float(typer.prompt("IoU threshold", default="0.55"))
+ if 0.0 <= iou_threshold <= 1.0:
+ break
+ typer.secho(
+ "Please enter a value between 0.0 and 1.0.", fg=typer.colors.RED
+ )
+ except ValueError:
+ typer.secho("Please enter a valid number.", fg=typer.colors.RED)
+
+ # Skip box threshold
+ typer.echo()
+ typer.echo("Skip box threshold (0.0-1.0):")
+ typer.echo("β’ Boxes with scores below this threshold will be ignored")
+ while True:
+ try:
+ skip_threshold = float(typer.prompt("Skip box threshold", default="0.0"))
+ if 0.0 <= skip_threshold <= 1.0:
+ break
+ typer.secho(
+ "Please enter a value between 0.0 and 1.0.", fg=typer.colors.RED
+ )
+ except ValueError:
+ typer.secho("Please enter a valid number.", fg=typer.colors.RED)
+
+ # Model weights
+ typer.echo()
+ use_weights = typer.confirm(
+ f"Do you want to specify custom weights for the {num_models} models?",
+ default=False,
+ )
+
+ model_weights = None
+ if use_weights:
+ typer.echo("Enter weights for each model (positive numbers):")
+ model_weights = []
+ for i in range(num_models):
+ while True:
+ try:
+ weight = float(
+ typer.prompt(f"Weight for model {i+1}", default="1.0")
+ )
+ if weight > 0:
+ model_weights.append(weight)
+ break
+ typer.secho("Weight must be positive.", fg=typer.colors.RED)
+ except ValueError:
+ typer.secho("Please enter a valid number.", fg=typer.colors.RED)
+
+ return WBFConfig(
+ iou_threshold=iou_threshold,
+ skip_box_threshold=skip_threshold,
+ model_weights=model_weights,
+ )
+
+
+def get_generation_settings() -> dict[str, Any]:
+ """Interactive generation settings."""
+ typer.secho("βοΈ Generation Settings", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ # Batch size
+ while True:
+ try:
+ batch_size = int(typer.prompt("Batch size", default="1"))
+ if batch_size > 0:
+ break
+ typer.secho("Batch size must be positive.", fg=typer.colors.RED)
+ except ValueError:
+ typer.secho("Please enter a valid number.", fg=typer.colors.RED)
+
+ # Save path
+ save_path = typer.prompt("Save path (optional, press Enter to skip)", default="")
+ if not save_path:
+ save_path = None
+
+ # HuggingFace Hub settings
+ typer.echo()
+ upload_to_hub = typer.confirm("Upload to HuggingFace Hub?", default=False)
+
+ hub_repo_id = None
+ hub_private = False
+ if upload_to_hub:
+ hub_repo_id = typer.prompt("Hub repository ID (e.g., 'username/dataset-name')")
+ hub_private = typer.confirm("Make repository private?", default=False)
+
+ return {
+ "batch_size": batch_size,
+ "save_path": save_path,
+ "upload_to_hub": upload_to_hub,
+ "hub_repo_id": hub_repo_id,
+ "hub_private": hub_private,
+ }
+
+
+def get_allowable_set_choice() -> Optional[list[str]]:
+ """Interactive allowable set selection."""
+ typer.secho(
+ "π― Step: Configure object filtering (optional)",
+ fg=typer.colors.BLUE,
+ bold=True,
+ )
+ typer.echo()
+
+ typer.echo(
+ "Allowable set filters detections to only include specified COCO objects."
+ )
+ typer.echo(
+ "This is useful when you know your images are biased toward certain object types."
+ )
+ typer.echo("If you leave this empty, all detected objects will be included.")
+ typer.echo()
+
+ # Get valid COCO objects
+ valid_coco_objects = set(coco_labels.values())
+ # Remove undefined as it's not a real COCO class
+ valid_coco_objects.discard("undefined")
+ valid_objects_sorted = sorted(valid_coco_objects)
+
+ use_allowable_set = typer.confirm(
+ "Do you want to filter detections to specific object types?", default=False
+ )
+
+ if not use_allowable_set:
+ typer.secho(
+ "β No filtering - all detected objects will be included",
+ fg=typer.colors.GREEN,
+ )
+ return None
+
+ typer.echo()
+ typer.echo("π Choose how to specify the allowable objects:")
+ typer.echo(" 1. Select from common autonomous driving objects")
+ typer.echo(" 2. Select from all COCO objects")
+ typer.echo(" 3. Enter objects manually")
+ typer.echo()
+
+ while True:
+ choice = typer.prompt("Select option (1-3)")
+
+ if choice == "1":
+ return get_common_av_objects()
+ elif choice == "2":
+ return get_all_coco_objects_interactive(valid_objects_sorted)
+ elif choice == "3":
+ return get_manual_objects_input(valid_objects_sorted)
+ else:
+ typer.secho("Invalid choice. Please enter 1, 2, or 3.", fg=typer.colors.RED)
+
+
+def get_common_av_objects() -> list[str]:
+ """Get common autonomous vehicle objects."""
+ typer.echo()
+ typer.secho("π Common Autonomous Vehicle Objects", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ av_objects = [
+ "person",
+ "bicycle",
+ "car",
+ "motorcycle",
+ "bus",
+ "train",
+ "truck",
+ "traffic light",
+ "stop sign",
+ "fire hydrant",
+ "parking meter",
+ "bench",
+ ]
+
+ typer.echo("Available objects:")
+ for i, obj in enumerate(av_objects, 1):
+ typer.echo(f" {i:2d}. {obj}")
+
+ typer.echo()
+ typer.echo(
+ "Enter the numbers of objects to include (comma-separated, e.g., '1,2,3-5,7'):"
+ )
+ typer.echo("Use ranges like '3-5' for objects 3, 4, and 5")
+ typer.echo("Or enter 'all' to include all objects")
+
+ while True:
+ selection = typer.prompt("Selection").strip()
+
+ if selection.lower() == "all":
+ selected_objects = av_objects.copy()
+ break
+
+ try:
+ selected_objects = []
+ for part in selection.split(","):
+ part = part.strip()
+ if "-" in part:
+ start, end = map(int, part.split("-"))
+ for i in range(start, end + 1):
+ if 1 <= i <= len(av_objects):
+ selected_objects.append(av_objects[i - 1])
+ else:
+ i = int(part)
+ if 1 <= i <= len(av_objects):
+ selected_objects.append(av_objects[i - 1])
+
+ # Remove duplicates while preserving order
+ selected_objects = list(dict.fromkeys(selected_objects))
+
+ if not selected_objects:
+ typer.secho(
+ "No valid objects selected. Please try again.", fg=typer.colors.RED
+ )
+ continue
+
+ break
+
+ except ValueError:
+ typer.secho(
+ "Invalid input format. Please use numbers, ranges, or 'all'.",
+ fg=typer.colors.RED,
+ )
+
+ typer.echo()
+ typer.secho(f"β Selected {len(selected_objects)} objects:", fg=typer.colors.GREEN)
+ for obj in selected_objects:
+ typer.echo(f" β’ {obj}")
+
+ return selected_objects
+
+
+def get_all_coco_objects_interactive(valid_objects: list[str]) -> list[str]:
+ """Interactive selection from all COCO objects."""
+ typer.echo()
+ typer.secho("π All COCO Objects", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ typer.echo(f"Total COCO objects: {len(valid_objects)}")
+ typer.echo("This is a long list - consider using manual entry instead.")
+ typer.echo()
+
+ # Show objects in groups of 10
+ for i in range(0, len(valid_objects), 10):
+ group = valid_objects[i : i + 10]
+ typer.echo(f"Objects {i+1}-{min(i+10, len(valid_objects))}:")
+ for j, obj in enumerate(group, i + 1):
+ typer.echo(f" {j:2d}. {obj}")
+ typer.echo()
+
+ typer.echo(
+ "Enter the numbers of objects to include (comma-separated, e.g., '1,2,3-5,7'):"
+ )
+ typer.echo("Use ranges like '3-5' for objects 3, 4, and 5")
+
+ while True:
+ selection = typer.prompt("Selection").strip()
+
+ try:
+ selected_objects = []
+ for part in selection.split(","):
+ part = part.strip()
+ if "-" in part:
+ start, end = map(int, part.split("-"))
+ for i in range(start, end + 1):
+ if 1 <= i <= len(valid_objects):
+ selected_objects.append(valid_objects[i - 1])
+ else:
+ i = int(part)
+ if 1 <= i <= len(valid_objects):
+ selected_objects.append(valid_objects[i - 1])
+
+ # Remove duplicates while preserving order
+ selected_objects = list(dict.fromkeys(selected_objects))
+
+ if not selected_objects:
+ typer.secho(
+ "No valid objects selected. Please try again.", fg=typer.colors.RED
+ )
+ continue
+
+ break
+
+ except ValueError:
+ typer.secho(
+ "Invalid input format. Please use numbers and ranges.",
+ fg=typer.colors.RED,
+ )
+
+ typer.echo()
+ typer.secho(f"β Selected {len(selected_objects)} objects:", fg=typer.colors.GREEN)
+ for obj in selected_objects:
+ typer.echo(f" β’ {obj}")
+
+ return selected_objects
+
+
+def get_manual_objects_input(valid_objects: list[str]) -> list[str]:
+ """Manual object input with validation."""
+ typer.echo()
+ typer.secho("βοΈ Manual Object Entry", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ typer.echo("Enter object names separated by commas (e.g., 'person, car, truck'):")
+ typer.echo(
+ "Valid COCO object names include: person, car, truck, bus, bicycle, etc."
+ )
+ typer.echo()
+
+ while True:
+ input_str = typer.prompt("Objects").strip()
+
+ if not input_str:
+ typer.secho("Please enter at least one object.", fg=typer.colors.RED)
+ continue
+
+ # Parse and validate objects
+ objects = [obj.strip() for obj in input_str.split(",")]
+ objects = [obj for obj in objects if obj] # Remove empty strings
+
+ valid_objects_set = set(valid_objects)
+ invalid_objects = [obj for obj in objects if obj not in valid_objects_set]
+
+ if invalid_objects:
+ typer.secho(f"Invalid objects: {invalid_objects}", fg=typer.colors.RED)
+ typer.echo("Please check spelling and use valid COCO object names.")
+ typer.echo(
+ "Common objects: person, car, truck, bus, bicycle, motorcycle, airplane, boat, train"
+ )
+ continue
+
+ # Remove duplicates while preserving order
+ objects = list(dict.fromkeys(objects))
+
+ typer.echo()
+ typer.secho(f"β Selected {len(objects)} objects:", fg=typer.colors.GREEN)
+ for obj in objects:
+ typer.echo(f" β’ {obj}")
+
+ return objects
+
+
+def create_interactive_config() -> DatasetGenerationConfig:
+ """Create a complete dataset generation configuration interactively."""
+ typer.secho(
+ "π GRAID HuggingFace Dataset Generation", fg=typer.colors.CYAN, bold=True
+ )
+ typer.echo()
+ typer.echo(
+ "This interactive wizard will help you create a dataset generation configuration."
+ )
+ typer.echo()
+
+ # Step 1: Dataset and split
+ dataset_name = get_dataset_choice()
+ split = get_split_choice()
+
+ # Step 2: Model selection mode
+ typer.secho(
+ "π§ Step 3: Choose model configuration", fg=typer.colors.BLUE, bold=True
+ )
+ typer.echo()
+
+ mode_options = {
+ "1": ("single", "Single Model - Use one model for predictions"),
+ "2": (
+ "wbf",
+ "WBF Multi-Model - Use multiple models with Weighted Boxes Fusion",
+ ),
+ "3": ("ground_truth", "Ground Truth - Use original dataset annotations"),
+ }
+
+ for key, (mode, desc) in mode_options.items():
+ typer.echo(
+ f" {key}. {typer.style(mode.upper(), fg=typer.colors.GREEN)} - {desc}"
+ )
+
+ typer.echo()
+ while True:
+ choice = typer.prompt("Select mode (1-3)")
+ if choice in mode_options:
+ mode = mode_options[choice][0]
+ typer.secho(f"β Selected: {mode.upper()}", fg=typer.colors.GREEN)
+ break
+ typer.secho("Invalid choice. Please enter 1, 2, or 3.", fg=typer.colors.RED)
+
+ # Step 3: Model configuration
+ typer.echo()
+ models = []
+ use_wbf = False
+ wbf_config = None
+
+ if mode == "single":
+ models = [get_single_model_config()]
+ elif mode == "wbf":
+ models = get_wbf_models_config()
+ use_wbf = True
+ typer.echo()
+ wbf_config = get_wbf_config(len(models))
+ # For ground_truth, models remains empty
+
+ # Step 4: Generation settings
+ typer.echo()
+ settings = get_generation_settings()
+
+ # Step 5: Allowable set selection
+ allowable_set = get_allowable_set_choice()
+
+ # Create configuration
+ try:
+ config = DatasetGenerationConfig(
+ dataset_name=dataset_name,
+ split=split,
+ models=models,
+ use_wbf=use_wbf,
+ wbf_config=wbf_config,
+ confidence_threshold=0.2, # Default, can be overridden per model
+ batch_size=settings["batch_size"],
+ device=None, # Will use default device
+ save_path=settings["save_path"],
+ upload_to_hub=settings["upload_to_hub"],
+ hub_repo_id=settings["hub_repo_id"],
+ hub_private=settings["hub_private"],
+ allowable_set=allowable_set,
+ )
+
+ typer.echo()
+ typer.secho(
+ "β
Configuration created successfully!", fg=typer.colors.GREEN, bold=True
+ )
+ return config
+
+ except ConfigurationError as e:
+ typer.secho(f"β Configuration error: {e}", fg=typer.colors.RED)
+ raise typer.Exit(1)
+ except Exception as e:
+ typer.secho(f"β Unexpected error: {e}", fg=typer.colors.RED)
+ raise typer.Exit(1)
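
A minimal sketch of how the wizard might be driven from a script. It assumes only that `DatasetGenerationConfig` exposes the fields the wizard itself populates (`dataset_name`, `split`, `models`, `use_wbf`, and the Hub settings), and it just prints a summary rather than launching generation.

```python
from graid.data.interactive_mode import create_interactive_config


def main() -> None:
    # Walks the user through dataset, split, model, filtering, and upload choices.
    config = create_interactive_config()

    # Echo the key decisions before any heavy lifting happens.
    print(f"Dataset: {config.dataset_name} ({config.split})")
    print(f"Models configured: {len(config.models)} (WBF: {config.use_wbf})")
    if config.upload_to_hub:
        print(f"Will upload to: {config.hub_repo_id} (private={config.hub_private})")


if __name__ == "__main__":
    main()
```
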
diff --git a/graid/src/graid/data/loaders.py b/graid/src/graid/data/loaders.py
new file mode 100644
index 0000000..9c2d483
--- /dev/null
+++ b/graid/src/graid/data/loaders.py
@@ -0,0 +1,150 @@
+"""
+Common Dataset Loader Factory for GRAID
+
+Provides a centralized way to create dataset loaders across the entire codebase.
+Eliminates duplicate dataset initialization logic scattered throughout different modules.
+"""
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any
+
+
+class DatasetLoaderCreator(ABC):
+ """Abstract base class for dataset loader creators."""
+
+ @staticmethod
+ @abstractmethod
+ def create(split: str, transform: Any, **kwargs) -> Any:
+ """Create a dataset loader instance."""
+ pass
+
+
+class BddLoaderCreator(DatasetLoaderCreator):
+ """Creator for BDD100K dataset loaders."""
+
+ @staticmethod
+ def create(split: str, transform: Any, **kwargs) -> Any:
+ """Create BDD100K dataset loader."""
+ from graid.data.ImageLoader import Bdd100kDataset
+
+ pkl_root = Path("data") / f"bdd_{split}"
+ rebuild_needed = not (pkl_root / "0.pkl").exists()
+
+ # Allow override of rebuild behavior
+ rebuild = kwargs.get('rebuild', rebuild_needed)
+ use_time_filtered = kwargs.get('use_time_filtered', False)
+
+ return Bdd100kDataset(
+ split=split,
+ transform=transform,
+ use_time_filtered=use_time_filtered,
+ rebuild=rebuild,
+ )
+
+
+class NuImagesLoaderCreator(DatasetLoaderCreator):
+ """Creator for NuImages dataset loaders."""
+
+ @staticmethod
+ def create(split: str, transform: Any, **kwargs) -> Any:
+ """Create NuImages dataset loader."""
+ from graid.data.ImageLoader import NuImagesDataset
+
+ # Allow override of size parameter
+ size = kwargs.get('size', 'all')
+
+ return NuImagesDataset(
+ split=split,
+ size=size,
+ transform=transform
+ )
+
+
+class WaymoLoaderCreator(DatasetLoaderCreator):
+ """Creator for Waymo dataset loaders."""
+
+ @staticmethod
+ def create(split: str, transform: Any, **kwargs) -> Any:
+ """Create Waymo dataset loader."""
+ from graid.data.ImageLoader import WaymoDataset
+
+ # Convert split name for Waymo's naming convention
+ split_name = "validation" if split == "val" else split + "ing"
+
+ return WaymoDataset(
+ split=split_name,
+ transform=transform
+ )
+
+
+class DatasetLoaderFactory:
+ """
+ Centralized factory for creating dataset loaders.
+
+ This factory can be used throughout the GRAID codebase to eliminate
+ duplicate dataset initialization logic.
+
+ Example usage:
+ transform = get_some_transform()
+ loader = DatasetLoaderFactory.create("bdd", "train", transform)
+ """
+
+ # Registry of available dataset creators
+ _CREATORS: dict[str, DatasetLoaderCreator] = {
+ "bdd": BddLoaderCreator,
+ "nuimage": NuImagesLoaderCreator,
+ "waymo": WaymoLoaderCreator
+ }
+
+ @classmethod
+ def create(cls, dataset_name: str, split: str, transform: Any, **kwargs) -> Any:
+ """
+ Create a dataset loader for the specified dataset.
+
+ Args:
+ dataset_name: Name of the dataset ("bdd", "nuimage", "waymo")
+ split: Dataset split ("train", "val", "test")
+ transform: Transform function to apply to images
+ **kwargs: Additional arguments passed to the specific creator
+
+ Returns:
+ Dataset loader instance
+
+ Raises:
+ ValueError: If dataset_name is not supported
+ """
+ if dataset_name not in cls._CREATORS:
+ available = list(cls._CREATORS.keys())
+ raise ValueError(f"Unsupported dataset: {dataset_name}. Available: {available}")
+
+ creator = cls._CREATORS[dataset_name]
+ return creator.create(split, transform, **kwargs)
+
+ @classmethod
+ def register_creator(cls, dataset_name: str, creator: DatasetLoaderCreator):
+ """
+ Register a new dataset creator.
+
+ Args:
+ dataset_name: Name to register the dataset under
+ creator: Creator class implementing DatasetLoaderCreator interface
+ """
+ cls._CREATORS[dataset_name] = creator
+
+ @classmethod
+ def get_supported_datasets(cls) -> list[str]:
+ """Get list of supported dataset names."""
+ return list(cls._CREATORS.keys())
+
+
+# Convenience function for backward compatibility
+def create_dataset_loader(dataset_name: str, split: str, transform: Any, **kwargs) -> Any:
+ """
+ Convenience function to create dataset loaders.
+
+ This is a simple wrapper around DatasetLoaderFactory.create() for
+ easier migration from existing code.
+ """
+ return DatasetLoaderFactory.create(dataset_name, split, transform, **kwargs)
+
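
Because the factory exposes `register_creator`, adding a new dataset does not require touching `loaders.py`. The sketch below registers a hypothetical `mydata` dataset with a stand-in creator so the snippet stays self-contained; a real creator would construct your actual dataset class instead of returning a dict.

```python
from typing import Any

from graid.data.loaders import DatasetLoaderCreator, DatasetLoaderFactory


class MyDatasetLoaderCreator(DatasetLoaderCreator):
    """Hypothetical creator for an in-house dataset (stand-in implementation)."""

    @staticmethod
    def create(split: str, transform: Any, **kwargs) -> Any:
        # A real implementation would build and return your Dataset here.
        return {"dataset": "mydata", "split": split, "transform": transform, **kwargs}


# Register once, then request it exactly like the built-in datasets.
DatasetLoaderFactory.register_creator("mydata", MyDatasetLoaderCreator)
loader = DatasetLoaderFactory.create("mydata", "train", transform=None)
print(DatasetLoaderFactory.get_supported_datasets())  # now includes "mydata"
```
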
diff --git a/graid/src/graid/evaluator/eval_vlms.py b/graid/src/graid/evaluator/eval_vlms.py
index 8d0fff4..c423fda 100644
--- a/graid/src/graid/evaluator/eval_vlms.py
+++ b/graid/src/graid/evaluator/eval_vlms.py
@@ -42,11 +42,10 @@
import re
import sqlite3
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional
import numpy as np
import pandas as pd
-from PIL import Image
from sqlitedict import SqliteDict
from tqdm import tqdm
@@ -183,7 +182,7 @@
}
-def list_available_vlms() -> Dict[str, List[str]]:
+def list_available_vlms() -> dict[str, list[str]]:
"""
List all available VLM types and their models.
@@ -199,7 +198,7 @@ def list_available_vlms() -> Dict[str, List[str]]:
return result
-def list_available_metrics() -> List[str]:
+def list_available_metrics() -> list[str]:
"""
List all available evaluation metrics.
@@ -209,7 +208,7 @@ def list_available_metrics() -> List[str]:
return list(METRIC_CONFIGS.keys())
-def list_available_prompts() -> List[str]:
+def list_available_prompts() -> list[str]:
"""
List all available prompt types.
@@ -328,7 +327,7 @@ def create_prompt(
def validate_configuration(
vlm_type: str, metric_type: str, prompt_type: str
-) -> Tuple[bool, Optional[str]]:
+) -> tuple[bool, Optional[str]]:
"""
Validate that the combination of VLM, metric, and prompt is compatible.
@@ -382,7 +381,7 @@ def _create_output_directory(
return results_dir
-def _load_database_tables(db_path: str) -> Dict[str, pd.DataFrame]:
+def _load_database_tables(db_path: str) -> dict[str, pd.DataFrame]:
"""Load all tables from SQLite database."""
conn = sqlite3.connect(db_path)
@@ -401,8 +400,8 @@ def _load_database_tables(db_path: str) -> Dict[str, pd.DataFrame]:
def _filter_and_sample_data(
- dataframes: Dict[str, pd.DataFrame], sample_size: int
-) -> Dict[str, Tuple[pd.DataFrame, int]]:
+ dataframes: dict[str, pd.DataFrame], sample_size: int
+) -> dict[str, tuple[pd.DataFrame, int]]:
"""Filter and sample data from database tables."""
sampled_dataframes = {}
print("Filtering rows...")
@@ -553,16 +552,49 @@ def evaluate_vlm(
)
continue
+ # Get the dataframe for the current table
+ df_to_process, _ = sampled_dataframes[table_name]
+
+ # Limit to min_available_samples
+ df_to_process = df_to_process.head(min_available_samples)
+
# Process new samples if needed
print(f"Processing table {table_idx}: {table_name}")
+
+ # Lists to store data for this table
+ questions, answers, preds, correctness = [], [], [], []
+
+ for _, row in tqdm(df_to_process.iterrows(), total=len(df_to_process)):
+ d = row.to_dict()
+ image_path, v = d["key"], json.loads(d["value"])
+
+ # Construct full image path
+ image_path = str(db_base_path / image_path)
+
+ qa_list = v.get("qa_list", [])
+ if not qa_list or qa_list == "Question not applicable":
+ continue
+
+ qa_pair = random.choice(qa_list) if isinstance(qa_list[0], list) else qa_list
+ q, a = qa_pair[0], qa_pair[1]
+
+ # Generate prompt and unique cache key
+ annotated_image, messages = my_prompt.generate_prompt(image_path, q)
+ cache_key = f"{vlm_type}_{prompt}_{image_path}_{str(messages)}"
+
+ if cache_key in vlm_cache:
+ pred = vlm_cache[cache_key]
+ else:
+ pred, _ = my_vlm.generate_answer(annotated_image, messages)
+ vlm_cache[cache_key] = pred
+
+ correct = my_metric.evaluate(pred, a)
+
+ questions.append(q)
+ answers.append(a)
+ preds.append(pred)
+ correctness.append(correct)
- # Implementation of evaluation logic would go here
- # This is a simplified version - the full implementation would include
- # image processing, question answering, and metric evaluation
-
- # For now, we'll create a placeholder that would be replaced
- # with the actual evaluation logic from the original file
- correctness = [0.5] * min_available_samples # Placeholder
all_correctness.extend(correctness)
# Save results
@@ -571,8 +603,11 @@ def evaluate_vlm(
log_file.write(f"VLM: {vlm_type}\n")
log_file.write(f"Metric: {metric}\n")
log_file.write(f"Prompt: {prompt}\n")
- log_file.write(f"Sample Size: {min_available_samples}\n")
+ log_file.write(f"Sample Size: {len(correctness)}\n")
log_file.write(f"Correctness: \n{correctness}\n")
+ log_file.write(f"Questions: \n{questions}\n")
+ log_file.write(f"Answers: \n{answers}\n")
+ log_file.write(f"Predictions: \n{preds}\n")
# Calculate and return final accuracy
if len(all_correctness) == 0:
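
The evaluation loop above keys its cache on the VLM, prompt style, image path, and serialized messages, so repeated runs skip already-answered samples. Below is a stripped-down sketch of that pattern, using `sqlitedict` as the file-backed store; the cache path and the `get_answer` callable are placeholders.

```python
from sqlitedict import SqliteDict


def cached_answer(vlm_type: str, prompt_name: str, image_path: str,
                  messages: list, get_answer) -> str:
    """Return a cached prediction if available, otherwise compute and store it."""
    cache_key = f"{vlm_type}_{prompt_name}_{image_path}_{str(messages)}"
    with SqliteDict("vlm_cache.sqlite", autocommit=True) as cache:  # illustrative path
        if cache_key in cache:
            return cache[cache_key]
        prediction = get_answer(image_path, messages)
        cache[cache_key] = prediction
        return prediction
```
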
diff --git a/graid/src/graid/evaluator/metrics.py b/graid/src/graid/evaluator/metrics.py
index eb2a359..ef3de26 100644
--- a/graid/src/graid/evaluator/metrics.py
+++ b/graid/src/graid/evaluator/metrics.py
@@ -29,7 +29,8 @@ def evaluate(self, pred, gt) -> float:
pred_as_json = None
try:
pred_as_json = json.loads(pred)
- except:
+ except (json.JSONDecodeError, TypeError):
+ # Keep pred_as_json as None if JSON parsing fails
pass
try:
if pred_as_json and "answer" in pred_as_json:
@@ -42,7 +43,8 @@ def evaluate(self, pred, gt) -> float:
pred = match.group(1).strip()
else:
pred = pred.strip()
- except:
+ except (AttributeError, TypeError, ValueError):
+ # Return 0.0 if string processing fails
return 0.0
return 1.0 if str(pred).lower() == gt.strip().lower() else 0.0
@@ -59,7 +61,8 @@ def evaluate(self, pred, gt) -> float:
pred_as_json = None
try:
pred_as_json = json.loads(pred)
- except:
+ except (json.JSONDecodeError, TypeError):
+ # Keep pred_as_json as None if JSON parsing fails
pass
try:
if pred_as_json and "answer" in pred_as_json:
@@ -72,7 +75,8 @@ def evaluate(self, pred, gt) -> float:
pred = match.group(1).strip()
else:
pred = pred.strip()
- except:
+ except (AttributeError, TypeError, ValueError):
+ # Return 0.0 if string processing fails
return 0.0
return 1.0 if gt.strip().lower() in pred.strip().lower() else 0.0
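
Both metrics share the same normalization step before comparison: try to parse the prediction as JSON and read its "answer" field, otherwise fall back to extracting a delimited span, otherwise use the stripped raw string. The standalone paraphrase below mirrors that flow; the backtick-delimiter regex is our guess at the extraction the prompts ask for and may differ from the module's actual pattern.

```python
import json
import re


def normalize_prediction(pred: str) -> str:
    """Best-effort extraction of the answer text from a raw VLM response."""
    try:
        parsed = json.loads(pred)
        if isinstance(parsed, dict) and "answer" in parsed:
            return str(parsed["answer"]).strip()
    except (json.JSONDecodeError, TypeError):
        pass
    match = re.search(r"`{3}(.*?)`{3}", pred, re.DOTALL)  # assumed delimiter format
    if match:
        return match.group(1).strip()
    return pred.strip()


# Exact-match comparison, as in the stricter metric:
# normalize_prediction('{"answer": "Yes"}').lower() == "yes"  -> True
```
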
diff --git a/graid/src/graid/evaluator/prompts.py b/graid/src/graid/evaluator/prompts.py
index bf734ca..c859348 100644
--- a/graid/src/graid/evaluator/prompts.py
+++ b/graid/src/graid/evaluator/prompts.py
@@ -1,10 +1,10 @@
-import os
from textwrap import dedent
import cv2
import numpy as np
import supervision as sv
import torch
+from numpy.typing import NDArray
from graid.utilities.common import get_default_device
@@ -29,13 +29,16 @@ def __init__(self, using_cd=False):
)
def generate_prompt(self, image, question):
- prompt = f"""\
- Answer the following question related to the image. If this question involves object naming, you may only identify objects from the COCO dataset (80 labels).{self.ans_format_str}
-
- Here's the question: {question}.
+ system_prompt = f"""\
+ Answer the following question related to the image. If this question involves object naming, you may only identify objects that are specified from the question or if none are specified, you may only identify objects from the COCO dataset (80 labels).{self.ans_format_str}
"""
- return image, dedent(prompt)
+ messages = [
+ {"role": "system", "content": dedent(system_prompt)},
+ {"role": "user", "content": question},
+ ]
+
+ return image, messages
def __str__(self):
return "ZeroShotPrompt"
@@ -45,12 +48,14 @@ class ZeroShotPrompt_batch(PromptingStrategy):
"""Zero-shot prompting method."""
def generate_prompt(self, image, question):
- prompt = f"""\
- Answer the following questions related to the image. Provide your answers to each question, separated by commas. Here are the questions:
- {question}
+ system_prompt = """\
+ Answer the following questions related to the image. Provide your answers to each question, separated by commas.
"""
-
- return image, dedent(prompt)
+ messages = [
+ {"role": "system", "content": dedent(system_prompt)},
+ {"role": "user", "content": question},
+ ]
+ return image, messages
def __str__(self):
return "ZeroShotPrompt_batch"
@@ -60,62 +65,80 @@ class CoT(PromptingStrategy):
"""CoT prompting method."""
def generate_prompt(self, image, question):
- prompt = f"""\
- Look at the image carefully and think through each question step by step. Use the process below to guide your reasoning and arrive at the correct answer. Here are some examples of how to answer the question:
-
- Question: Are there any motorcyclists to the right of any pedestrians?
-
- Steps:
- 1. I see three pedestrians walking on the left sidewalk, roughly in the left third of the image.
- 2. I also see a single motorcyclist riding away from the camera, positioned nearer the center of the road and center of the camera frame but clearly to the right of those pedestrians.
- 3. Comparing their horizontal positions, the motorcyclist's xβcoordinate is larger (further to the right) than either pedestrian's.
-
- Conclusion: The motorcyclist is to the right of the pedestrians.
- Final_Answer: Yes.
-
-
- Question: What group of objects are most clustered together?
-
- Steps:
- 1. Scanning for COCO categories only, I identify the following objects
- 2. Person:
- I spot three pedestrians on the left sidewalk: one nearest the foreground, one a few meters behind, and a third just past the white box truck.
- They are spaced roughly 2β3 m apart along the sidewalk.
-
- 3. Motorcycle
- A single motorcyclist is riding down the center of the road, about midway up the frame.
- Only one instance, so no clustering.
-
- 4. Truck
- A single white box truck is parked on the left curb beyond the first two pedestrians.
- Again only one, so no cluster.
+ system_prompt = dedent(
+ """\
+ Look at the image carefully and think through each question step by step. Use the provided examples to guide your reasoning and arrive at the correct answer. Answer in the same step-by-step reasoning format as the examples.
+ """
+ )
- 5. Car
- At least six cars parked behind the french on the right and at least four cars in the distance near the center of the image
- Both clusters of cars, especially the parked ones behind the fence occupy a small contiguous area, tightly packed together.
+ example_1_q = "Are there any motorcyclists to the right of any pedestrians?"
+ example_1_a = dedent(
+ """\
+ Steps:
+ 1. I see three pedestrians walking on the left sidewalk, roughly in the left third of the image.
+ 2. I also see a single motorcyclist riding away from the camera, positioned nearer the center of the road and center of the camera frame but clearly to the right of those pedestrians.
+ 3. Comparing their horizontal positions, the motorcyclist's x-coordinate is larger (further to the right) than either pedestrian's.
+ Conclusion: The motorcyclist is to the right of the pedestrians.
+ Final_Answer: Yes.
+ """
+ )
- Conclusion: We can compare the densities of the groups we found.
- The three people, while grouped, are separated by a few meters each.
- The six-plus cars are parked immediately adjacent in a compact line.
+ example_2_q = "What group of objects are most clustered together?"
+ example_2_a = dedent(
+ """\
+ Steps:
+ 1. Scanning for COCO categories only, I identify the following objects
+ 2. Person:
+ I spot three pedestrians on the left sidewalk: one nearest the foreground, one a few meters behind, and a third just past the white box truck.
+ They are spaced roughly 2-3 m apart along the sidewalk.
- Final_Answer: The cars are the most clustered together.
+ 3. Motorcycle
+ A single motorcyclist is riding down the center of the road, about midway up the frame.
+ Only one instance, so no clustering.
+ 4. Truck
+ A single white box truck is parked on the left curb beyond the first two pedestrians.
+ Again only one, so no cluster.
- Question: Does the leftmost object in the image appear to be wider than it is tall?
+ 5. Car
+ At least six cars are parked behind the fence on the right, and at least four cars are in the distance near the center of the image.
+ Both clusters of cars, especially the parked ones behind the fence, occupy a small contiguous area, tightly packed together.
- Steps:
- 1. Among the COCO categories present, the object farthest to the left is the bench under the busβstop canopy.
- 2. That bench's bounding area is much broader horizontally than it is tall vertically.
- Conclusion: The bench is wider than it is tall.
- Final_Answer: Yes.
+ Conclusion: We can compare the densities of the groups we found.
+ The three people, while grouped, are separated by a few meters each.
+ The six-plus cars are parked immediately adjacent in a compact line.
- Now that you have seen the examples, answer only the following question in the same step-by-step reasoning format as the examples:
- Question: {question}
+ Final_Answer: The cars are the most clustered together.
+ """
+ )
+ example_3_q = (
+ "Does the leftmost object in the image appear to be wider than it is tall?"
+ )
+ example_3_a = dedent(
+ """\
+ Steps:
+ 1. Among the COCO categories present, the object farthest to the left is the bench under the bus-stop canopy.
+ 2. That bench's bounding area is much broader horizontally than it is tall vertically.
+
+ Conclusion: The bench is wider than it is tall.
+ Final_Answer: Yes.
"""
- return image, dedent(prompt)
+ )
+
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": example_1_q},
+ {"role": "assistant", "content": example_1_a},
+ {"role": "user", "content": example_2_q},
+ {"role": "assistant", "content": example_2_a},
+ {"role": "user", "content": example_3_q},
+ {"role": "assistant", "content": example_3_a},
+ {"role": "user", "content": question},
+ ]
+ return image, messages
def __str__(self):
return "CoT"
@@ -135,12 +158,18 @@ def generate_prompt(self, image, question):
if not self.examples:
raise ValueError("Few-shot examples are required but not provided.")
- prompt = "Here are some examples:\n"
- for i, (inp, out) in enumerate(self.examples):
- prompt += f"Example {i+1}:\nInput: {inp}\nOutput: {out}\n\n"
+ messages = [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant that provides answers to questions.",
+ }
+ ]
+ for inp, out in self.examples:
+ messages.append({"role": "user", "content": inp})
+ messages.append({"role": "assistant", "content": out})
- prompt += f"Now, answer the following question:\n{question}"
- return prompt
+ messages.append({"role": "user", "content": question})
+ return image, messages
def __str__(self):
return "FewShotPrompt"
@@ -148,32 +177,24 @@ def __str__(self):
class SetOfMarkPrompt(PromptingStrategy):
def __init__(self, gpu=1):
- from segment_anything import (
- SamAutomaticMaskGenerator,
- SamPredictor,
- sam_model_registry,
- )
-
- CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
- print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))
-
- DEVICE = get_default_device()
- MODEL_TYPE = "vit_h"
+ from graid.utilities.sam_utils import SAMMaskGenerator
- sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(
- device=f"cuda:{gpu}"
- )
- self.mask_generator = SamAutomaticMaskGenerator(sam)
+ self.mask_generator = SAMMaskGenerator(gpu=gpu)
self.MIN_AREA_PERCENTAGE = 0.005
self.MAX_AREA_PERCENTAGE = 0.05
def generate_prompt(self, image, question):
- prompt = f"""Answer the following question related to the image. If this question involves object naming, you may only identify objects from the COCO dataset (80 labels). Make sure to wrap the answer in triple backticks. "```"
- Here's the question: {question}.
+ system_prompt = f"""Answer the following question related to the image. If this question involves object naming, you may only identify objects that are specified from the question or if none are specified, you may only identify objects from the COCO dataset (80 labels). Make sure to wrap the answer in triple backticks. "```"
"""
+ messages = [
+ {"role": "system", "content": dedent(system_prompt)},
+ {"role": "user", "content": question},
+ ]
if isinstance(image, str):
image_bgr = cv2.imread(image)
+ if image_bgr is None:
+ raise ValueError(f"Could not read image from path: {image}")
elif isinstance(image, torch.Tensor):
image_bgr = image.mul(255).permute(1, 2, 0).numpy().astype(np.uint8)
else:
@@ -193,7 +214,7 @@ def generate_prompt(self, image, question):
max_area_mask = (detections.area / image_area) < self.MAX_AREA_PERCENTAGE
detections = detections[min_area_mask & max_area_mask]
- def Find_Center(mask: np.ndarray) -> tuple[int, int]:
+ def Find_Center(mask: NDArray[np.uint8]) -> tuple[int, int]:
mask_8u = mask.astype(np.uint8)
# Distance transform
@@ -233,7 +254,7 @@ def Mark_Allocation(masks: list[np.ndarray]) -> list[tuple[int, int]]:
all_masks = [detections[i].mask for i in range(len(detections))]
if not all_masks:
- return image, prompt
+ return image, messages
centers = Mark_Allocation(all_masks)
@@ -247,7 +268,7 @@ def Mark_Allocation(masks: list[np.ndarray]) -> list[tuple[int, int]]:
annotated_image = image_bgr.copy()
annotated_image = mask_annotator.annotate(
- scene=annotated_image, detections=detections
+ scene=annotated_image, detections=sorted_detections
)
for idx, (x, y) in enumerate(centers, start=1):
@@ -263,7 +284,25 @@ def Mark_Allocation(masks: list[np.ndarray]) -> list[tuple[int, int]]:
cv2.LINE_AA,
)
- return annotated_image, prompt
+ return annotated_image, messages
def __str__(self):
return "SetOfMarkPrompt"
+
+
+class PassthroughPrompt(PromptingStrategy):
+ """A minimal prompt strategy that leaves the image unaltered and forwards
+ the question verbatim.
+
+ Useful when no special instructions or visual annotations are required.
+ """
+
+ def generate_prompt(self, image, question): # noqa: D401, ANN001
+ # Simply echo back the inputs as a message list for consistency
+ messages = [
+ {"role": "user", "content": question}
+ ]
+ return image, messages
+
+ def __str__(self):
+ return "PassthroughPrompt"
diff --git a/graid/src/graid/evaluator/vlms.py b/graid/src/graid/evaluator/vlms.py
index e85bf85..1e1033f 100644
--- a/graid/src/graid/evaluator/vlms.py
+++ b/graid/src/graid/evaluator/vlms.py
@@ -3,12 +3,12 @@
import os
import re
import time
+from abc import ABC, abstractmethod
from enum import Enum
-from typing import Any, Dict, List, Literal, Type, cast
+from typing import Any, Literal, cast
import cv2
import numpy as np
-import requests
import torch
from dotenv import load_dotenv
from google import genai
@@ -19,10 +19,29 @@
from torchvision import transforms
from graid.utilities.coco import coco_labels
-from graid.utilities.common import project_root_dir
-class GPT:
+class VLM(ABC):
+ """Abstract Base Class for Vision Language Models."""
+
+ @abstractmethod
+ def generate_answer(
+ self, image, messages: list[dict[str, str]]
+ ) -> tuple[Any, list[dict[str, str]]]:
+ """
+ Generates an answer from the VLM.
+
+ Args:
+ image: The input image (potentially annotated).
+ messages: The list of messages for the conversation.
+
+ Returns:
+ A tuple containing the VLM's response and the messages passed.
+ """
+ raise NotImplementedError
+
+
+class GPT(VLM):
def __init__(self, model_name="gpt-4o", port=None):
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
@@ -46,49 +65,77 @@ def encode_image(self, image):
success, buffer = cv2.imencode(".jpg", image)
return base64.b64encode(buffer).decode("utf-8")
+ def _convert_messages_for_openai(self, messages, base64_image):
+ """Convert message list to OpenAI API format with image attached to last user message."""
+ converted_messages = []
+
+ # Find the last user message to attach the image
+ last_user_idx = None
+ for i in range(len(messages) - 1, -1, -1):
+ if messages[i]["role"] == "user":
+ last_user_idx = i
+ break
+
+ for i, msg in enumerate(messages):
+ if msg["role"] in ["system", "assistant"]:
+ # Pass through system and assistant messages as-is
+ converted_messages.append({
+ "role": msg["role"],
+ "content": msg["content"]
+ })
+ elif msg["role"] == "user":
+ if i == last_user_idx:
+ # Attach image to the last user message
+ converted_messages.append({
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{base64_image}",
+ "detail": "high",
+ },
+ },
+ {
+ "type": "text",
+ "text": msg["content"],
+ },
+ ],
+ })
+ else:
+ # Regular user message without image
+ converted_messages.append({
+ "role": "user",
+ "content": msg["content"]
+ })
+
+ return converted_messages
+
@retry(
wait=wait_exponential(multiplier=1, min=2, max=10),
stop=stop_after_attempt(5),
)
- def generate_answer(self, image, questions: str, prompting_style):
+ def generate_answer(self, image, messages: list[dict[str, str]]):
# reference: https://platform.openai.com/docs/guides/vision
-
- image, prompt = prompting_style.generate_prompt(image, questions)
-
base64_image = self.encode_image(image)
+
+ converted_messages = self._convert_messages_for_openai(messages, base64_image)
completion = self.client.chat.completions.create(
model=self.model_name,
- messages=[
- {
- "role": "user",
- "content": [
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{base64_image}",
- "detail": "high",
- },
- },
- {
- "type": "text",
- "text": prompt,
- },
- ],
- }
- ],
+ messages=converted_messages,
temperature=0.0,
)
responses = completion.choices[0].message.content
- return responses, prompt
+ return responses, messages
def __str__(self):
return self.model_name
-class Gemini:
+class Gemini(VLM):
def __init__(self, model_name="gemini-1.5-pro", location="us-central1"):
self.client = genai.Client(
vertexai=True,
@@ -108,21 +155,60 @@ def encode_image(self, image):
pil_image = transform(image)
return pil_image
+ def _prepare_gemini_request(self, messages, image):
+ """Prepare (system_instruction, contents) tuple for Gemini client.
+
+ Gemini accepts:
+ • Optional system_instruction via `config.system_instruction`
+ • `contents` – list that can mix strings & PIL.Image
+ We consolidate conversational turns into a single prompt string so we
+ only need **one** text block + the image.
+ """
+ system_instruction: str | None = None
+ text_parts: list[str] = []
+
+ # Traverse messages in order and build text parts / extract system
+ for msg in messages:
+ role, content = msg["role"], msg["content"]
+ if role == "system":
+ # Keep the very first system prompt (merge if multiple)
+ system_instruction = (
+ content
+ if system_instruction is None
+ else f"{system_instruction}\n{content}"
+ )
+ elif role == "user":
+ text_parts.append(f"User: {content}")
+ elif role == "assistant":
+ text_parts.append(f"Assistant: {content}")
+
+ combined_prompt = "\n\n".join(text_parts) if text_parts else ""
+
+ # According to the docs, we can pass [text, image] (text first) or vice versa.
+ contents = [combined_prompt, image]
+ return system_instruction, contents
+
@retry(
wait=wait_exponential(multiplier=1, min=2, max=10),
stop=stop_after_attempt(5),
)
- def generate_answer(self, image, questions: str, prompting_style):
- image, prompt = prompting_style.generate_prompt(image, questions)
+ def generate_answer(self, image, messages: list[dict[str, str]]):
+ from google.genai import types
+
image = self.encode_image(image)
+ system_instruction, contents = self._prepare_gemini_request(messages, image)
+
response = None
for _ in range(3):
try:
- response = self.client.models.generate_content(
- model=self.model,
- contents=[image, prompt],
- )
+ params = {"model": self.model, "contents": contents}
+ if system_instruction:
+ params["config"] = types.GenerateContentConfig(
+ system_instruction=system_instruction
+ )
+
+ response = self.client.models.generate_content(**params)
break
except Exception as e:
print(e)
@@ -130,13 +216,13 @@ def generate_answer(self, image, questions: str, prompting_style):
if response is None:
raise Exception("Failed to generate content after multiple attempts")
- return response.text, prompt
+ return response.text, messages
def __str__(self) -> str:
return self.model
-class Llama:
+class Llama(VLM):
def __init__(
self, model_name="meta-llama/Llama-3.2-90B-Vision-Instruct", use_vllm=False
):
@@ -155,28 +241,6 @@ def __init__(
# api_key=self.token,
)
else:
- # import os
- # os.environ["GOOGLE_APPLICATION_CREDENTIALS"]="token.txt"
-
- # with open("token.txt", "r") as token_file:
- # self.token = token_file.read().strip()
-
- # # google_url = f"https://{MAAS_ENDPOINT}/v1beta1/projects/{PROJECT_ID}/locations/{REGION}/endpoints/openapi"
-
- # print("Using Google Vertex hosted Llama")
- # self.client = genai.Client(
- # vertexai=True,
- # project=PROJECT_ID,
- # location=REGION,
- # )
- # self.model = "meta/llama-3.2-90b-vision-instruct-maas"
-
- # from google.auth import default, transport
-
- # # Get credentials
- # credentials, _ = default()
- # auth_request = transport.requests.Request()
- # credentials.refresh(auth_request)
from google.auth import default
from google.auth.transport.requests import Request
@@ -185,9 +249,6 @@ def __init__(
)
credentials.refresh(Request())
- # with open("token.txt", "r") as token_file:
- # self.token = token_file.read().strip()
-
google_url = f"https://{MAAS_ENDPOINT}/v1beta1/projects/{PROJECT_ID}/locations/{REGION}/endpoints/openapi"
print("Using Google Vertex hosted Llama")
@@ -212,31 +273,59 @@ def encode_image(self, image):
with open(image, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
+ def _convert_messages_for_openai(self, messages, base64_image):
+ """Convert message list to OpenAI API format with image attached to last user message."""
+ converted_messages = []
+
+ # Find the last user message to attach the image
+ last_user_idx = None
+ for i in range(len(messages) - 1, -1, -1):
+ if messages[i]["role"] == "user":
+ last_user_idx = i
+ break
+
+ for i, msg in enumerate(messages):
+ if msg["role"] in ["system", "assistant"]:
+ # Pass through system and assistant messages as-is
+ converted_messages.append({
+ "role": msg["role"],
+ "content": msg["content"]
+ })
+ elif msg["role"] == "user":
+ if i == last_user_idx:
+ # Attach image to the last user message
+ converted_messages.append({
+ "role": "user",
+ "content": [
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
+ {"type": "text", "text": msg["content"]},
+ ],
+ })
+ else:
+ # Regular user message without image
+ converted_messages.append({
+ "role": "user",
+ "content": msg["content"]
+ })
+
+ return converted_messages
+
@retry(
wait=wait_exponential(multiplier=1, min=2, max=10),
# stop=stop_after_attempt(5),
)
- def generate_answer(self, image, questions: str, prompting_style):
- image, prompt = prompting_style.generate_prompt(image, questions)
+ def generate_answer(self, image, messages: list[dict[str, str]]):
base64_image = self.encode_image(image)
- image_gcs_url = f"data:image/jpeg;base64,{base64_image}"
+ converted_messages = self._convert_messages_for_openai(messages, base64_image)
response = self.client.chat.completions.create(
model=self.model,
- messages=[
- {
- "role": "user",
- "content": [
- {"type": "image_url", "image_url": {"url": image_gcs_url}},
- {"type": "text", "text": prompt},
- ],
- }
- ],
+ messages=converted_messages,
temperature=0.0,
)
print(response)
- return response.choices[0].message.content, prompt
+ return response.choices[0].message.content, messages
def __str__(self) -> str:
return "Llama"
@@ -402,7 +491,7 @@ class MostClusteredObjectsAnswer(Answer):
answer: CocoLabelEnum
-QUESTION_CLASS_MAP: Dict[str, Type[Answer]] = {
+QUESTION_CLASS_MAP: dict[str, type[Answer]] = {
r"centered in the image": IsObjectCenteredAnswer,
r"width of the .* larger than the height": WidthVsHeightAnswer,
r"In what quadrant does .* appear": QuadrantsAnswer,
@@ -424,7 +513,7 @@ class MostClusteredObjectsAnswer(Answer):
}
-def get_answer_class_from_question(question: str) -> Type[Answer]:
+def get_answer_class_from_question(question: str) -> type[Answer]:
for pattern, cls in QUESTION_CLASS_MAP.items():
if re.search(pattern, question, flags=re.IGNORECASE):
return cls
@@ -438,7 +527,7 @@ class Step(BaseModel):
class Reasoning(BaseModel):
- steps: List[Step]
+ steps: list[Step]
conclusion: str = Field(
description="A concluding statement summarizing or linking the steps"
)
@@ -451,32 +540,17 @@ class GPT_CD(GPT):
def __init__(self, model_name="gpt-4o", port=None):
super().__init__(model_name)
- def generate_answer(self, image, questions: str, prompting_style):
- image, prompt = prompting_style.generate_prompt(image, questions)
+ def generate_answer(self, image, messages: list[dict[str, str]]):
base64_image = self.encode_image(image)
- answer_cls = get_answer_class_from_question(questions)
+ question = messages[-1]["content"]
+ answer_cls = get_answer_class_from_question(question)
+
+ converted_messages = self._convert_messages_for_openai(messages, base64_image)
completion = self.client.beta.chat.completions.parse(
model=self.model_name,
- messages=[
- {
- "role": "user",
- "content": [
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{base64_image}",
- "detail": "high",
- },
- },
- {
- "type": "text",
- "text": prompt,
- },
- ],
- }
- ],
+ messages=converted_messages,
response_format=answer_cls,
temperature=0.0,
)
@@ -487,7 +561,7 @@ def generate_answer(self, image, questions: str, prompting_style):
else:
final_answer = message.refusal
- return final_answer, prompt
+ return final_answer, messages
def __str__(self):
return self.model_name + "_CD"
@@ -497,27 +571,19 @@ class Llama_CD(Llama):
def __init__(self, model_name="meta-llama/Llama-3.2-90B-Vision-Instruct"):
super().__init__(model_name, use_vllm=False)
- def generate_answer(self, image, questions: str, prompting_style):
- image, prompt = prompting_style.generate_prompt(image, questions)
+ def generate_answer(self, image, messages: list[dict[str, str]]):
base64_image = self.encode_image(image)
- image_gcs_url = f"data:image/jpeg;base64,{base64_image}"
-
- answer_cls = get_answer_class_from_question(questions)
+ question = messages[-1]["content"]
+ answer_cls = get_answer_class_from_question(question)
# There doesn't seem to be a good way of dynamically setting the final answer type
# to be the answer_cls so we will include it in the prompt
+ converted_messages = self._convert_messages_for_openai(messages, base64_image)
+
response = self.client.beta.chat.completions.parse(
model=self.model,
- messages=[
- {
- "role": "user",
- "content": [
- {"type": "image_url", "image_url": {"url": image_gcs_url}},
- {"type": "text", "text": prompt},
- ],
- },
- ],
+ messages=converted_messages,
temperature=0.0,
response_format=answer_cls,
)
@@ -528,7 +594,7 @@ def generate_answer(self, image, questions: str, prompting_style):
else:
final_answer = message.refusal
- return final_answer, prompt
+ return final_answer, messages
def __str__(self):
return "Llama_CD"
@@ -538,31 +604,35 @@ class Gemini_CD(Gemini):
def __init__(self, model_name="gemini-1.5-pro", location="us-central1"):
super().__init__(model_name, location)
- def generate_answer(self, image, questions: str, prompting_style):
- image, prompt = prompting_style.generate_prompt(image, questions)
+ def generate_answer(self, image, messages: list[dict[str, str]]):
+ from google.genai import types
+
image = self.encode_image(image)
- response_format = get_answer_class_from_question(questions)
+ question = messages[-1]["content"]
+ response_format = get_answer_class_from_question(question)
+
+ system_instruction, contents = self._prepare_gemini_request(messages, image)
+
+ config_kwargs = {
+ "response_mime_type": "application/json",
+ "response_schema": response_format,
+ "temperature": 0.0,
+ "topK": 1,
+ }
+ if system_instruction:
+ config_kwargs["system_instruction"] = system_instruction
response = self.client.models.generate_content(
model=self.model,
- contents=[
- image,
- prompt,
- # f"The final_answer should be of type: {response_format.model_json_schema()}",
- ],
- config={
- "response_mime_type": "application/json",
- "response_schema": response_format,
- "temperature": 0.0,
- "topK": 1,
- },
+ contents=contents,
+ config=types.GenerateContentConfig(**config_kwargs),
)
answers: Answer = cast(Answer, response.parsed)
final_answer = answers.answer
- return final_answer, prompt
+ return final_answer, messages
def __str__(self):
return self.model + "_CD"
@@ -576,34 +646,21 @@ def __init__(self, model_name="gpt-4o", port=None):
wait=wait_exponential(multiplier=1, min=2, max=10),
stop=stop_after_attempt(5),
)
- def generate_answer(self, image, questions: str, prompting_style):
- image, prompt = prompting_style.generate_prompt(image, questions)
+ def generate_answer(self, image, messages: list[dict[str, str]]):
base64_image = self.encode_image(image)
+ converted_messages = self._convert_messages_for_openai(messages, base64_image)
+
+ question = messages[-1]["content"]
+ # Add the additional system message for CoT_CD
+ converted_messages.append({
+ "role": "system",
+ "content": f"The final_answer should be of type: {get_answer_class_from_question(question).model_json_schema()}",
+ })
+
completion = self.client.beta.chat.completions.parse(
model=self.model_name,
- messages=[
- {
- "role": "user",
- "content": [
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{base64_image}",
- "detail": "high",
- },
- },
- {
- "type": "text",
- "text": prompt,
- },
- ],
- },
- {
- "role": "system",
- "content": f"The final_answer should be of type: {get_answer_class_from_question(questions).model_json_schema()}",
- },
- ],
+ messages=converted_messages,
response_format=Reasoning,
temperature=0.0,
)
@@ -615,7 +672,7 @@ def generate_answer(self, image, questions: str, prompting_style):
else:
output = message.refusal
- return output, prompt
+ return output, messages
def __str__(self):
return self.model_name + "_CoT_CD"
@@ -625,31 +682,25 @@ class Llama_CoT_CD(Llama):
def __init__(self, model_name="meta-llama/Llama-3.2-90B-Vision-Instruct"):
super().__init__(model_name, use_vllm=False)
- def generate_answer(self, image, questions: str, prompting_style):
- image, prompt = prompting_style.generate_prompt(image, questions)
+ def generate_answer(self, image, messages: list[dict[str, str]]):
base64_image = self.encode_image(image)
- image_gcs_url = f"data:image/jpeg;base64,{base64_image}"
-
- answer_cls = get_answer_class_from_question(questions)
+ question = messages[-1]["content"]
+ answer_cls = get_answer_class_from_question(question)
# There doesn't seem to be a good way of dynamically setting the final answer type
# to be the answer_cls so we will include it in the prompt
+ converted_messages = self._convert_messages_for_openai(messages, base64_image)
+
+ # Add the additional system message for CoT_CD
+ converted_messages.append({
+ "role": "system",
+ "content": f"The final_answer should be of type: {answer_cls.model_json_schema()}",
+ })
+
response = self.client.beta.chat.completions.parse(
model=self.model,
- messages=[
- {
- "role": "user",
- "content": [
- {"type": "image_url", "image_url": {"url": image_gcs_url}},
- {"type": "text", "text": prompt},
- ],
- },
- {
- "role": "system",
- "content": f"The final_answer should be of type: {answer_cls.model_json_schema()}",
- },
- ],
+ messages=converted_messages,
response_format=Reasoning,
temperature=0.0,
)
@@ -661,7 +712,7 @@ def generate_answer(self, image, questions: str, prompting_style):
else:
final_answer = message.refusal
- return final_answer, prompt
+ return final_answer, messages
def __str__(self):
return "Llama_CoT_CD"
@@ -671,36 +722,43 @@ class Gemini_CoT_CD(Gemini):
def __init__(self, model_name="gemini-1.5-pro", location="us-central1"):
super().__init__(model_name, location)
- def generate_answer(self, image, questions: str, prompting_style):
- image, prompt = prompting_style.generate_prompt(image, questions)
+ def generate_answer(self, image, messages: list[dict[str, str]]):
+ from google.genai import types
+
image = self.encode_image(image)
- response_format = get_answer_class_from_question(questions)
+ question = messages[-1]["content"]
+ response_format = get_answer_class_from_question(question)
+
+ system_instruction, contents = self._prepare_gemini_request(messages, image)
+
+ # Add the schema information for CoT_CD as an extra text element
+ contents.append(f"The final_answer should be of type: {response_format.model_json_schema()}")
+
+ config_kwargs = {
+ "response_mime_type": "application/json",
+ "response_schema": Reasoning,
+ "temperature": 0.0,
+ "topK": 1,
+ }
+ if system_instruction:
+ config_kwargs["system_instruction"] = system_instruction
response = self.client.models.generate_content(
model=self.model,
- contents=[
- image,
- prompt,
- f"The final_answer should be of type: {response_format.model_json_schema()}",
- ],
- config={
- "response_mime_type": "application/json",
- "response_schema": Reasoning,
- "temperature": 0.0,
- "topK": 1,
- },
+ contents=contents,
+ config=types.GenerateContentConfig(**config_kwargs),
)
reasoning_response: Reasoning = cast(Reasoning, response.parsed)
- return reasoning_response.model_dump_json(), prompt
+ return reasoning_response.model_dump_json(), messages
def __str__(self):
return self.model + "_CoT_CD"
-class Claude:
+class Claude(VLM):
def __init__(self, model_name="claude-3-7-sonnet-20250219"):
import anthropic
@@ -721,33 +779,66 @@ def encode_image(self, image):
success, buffer = cv2.imencode(".jpg", image)
return base64.b64encode(buffer).decode("utf-8")
- def generate_answer(self, image, questions: str, prompting_style):
- image, prompt = prompting_style.generate_prompt(image, questions)
+ def _convert_messages_for_claude(self, messages, base64_image):
+ """Convert message list to Claude API format.
+
+ Returns (claude_messages, system_prompt) where system_prompt may be None.
+ """
+ claude_messages = []
+ system_prompt: str | None = None
+
+ # Find the last user message to attach the image
+ last_user_idx = None
+ for i in range(len(messages) - 1, -1, -1):
+ if messages[i]["role"] == "user":
+ last_user_idx = i
+ break
+
+ for i, msg in enumerate(messages):
+ role, content = msg["role"], msg["content"]
+ if role == "system":
+ system_prompt = content if system_prompt is None else f"{system_prompt}\n{content}"
+ continue # system prompt handled separately
+
+ if role in ("user", "assistant"):
+ if role == "user" and i == last_user_idx:
+ # Attach image to the last user message
+ claude_messages.append(
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": base64_image,
+ },
+ },
+ {"type": "text", "text": content},
+ ],
+ }
+ )
+ else:
+ # Regular text-only message
+ claude_messages.append({"role": role, "content": content})
+
+ return claude_messages, system_prompt
+
+ def generate_answer(self, image, messages: list[dict[str, str]]):
base64_image = self.encode_image(image)
+ claude_messages, system_prompt = self._convert_messages_for_claude(messages, base64_image)
+
response = self.client.messages.create(
model=self.model,
max_tokens=1024,
temperature=0.0,
- messages=[
- {
- "role": "user",
- "content": [
- {
- "type": "image",
- "source": {
- "type": "base64",
- "media_type": "image/jpeg",
- "data": base64_image,
- },
- },
- {"type": "text", "text": prompt},
- ],
- }
- ],
+ system=system_prompt if system_prompt else None,
+ messages=claude_messages,
)
- return response.content[0].text, prompt
+ return response.content[0].text, messages
def __str__(self) -> str:
return self.model
@@ -757,42 +848,31 @@ class Claude_CD(Claude):
def __init__(self, model_name="claude-3-7-sonnet-20250219"):
super().__init__(model_name)
- def generate_answer(self, image, questions: str, prompting_style):
+ def generate_answer(self, image, messages: list[dict[str, str]]):
import anthropic
from pydantic import create_model
# Get the answer class based on the question
- answer_class = get_answer_class_from_question(questions)
+ question = messages[-1]["content"]
+ answer_class = get_answer_class_from_question(question)
if answer_class is None:
- return super().generate_answer(image, questions, prompting_style)
+ # Fallback to non-CD generation if no class matches
+ return super().generate_answer(image, messages)
- image, prompt = prompting_style.generate_prompt(image, questions)
base64_image = self.encode_image(image)
+ claude_messages, system_prompt = self._convert_messages_for_claude(messages, base64_image)
+
# Use anthropic.messages.create with response_model parameter for constrained decoding
response = self.client.messages.create(
model=self.model,
temperature=0.0,
- messages=[
- {
- "role": "user",
- "content": [
- {
- "type": "image",
- "source": {
- "type": "base64",
- "media_type": "image/jpeg",
- "data": base64_image,
- },
- },
- {"type": "text", "text": prompt},
- ],
- }
- ],
+ system=system_prompt if system_prompt else None,
+ messages=claude_messages,
response_model=answer_class,
)
- return response.answer, prompt
+ return response.answer, messages
def __str__(self):
return self.model + "_CD"
@@ -802,41 +882,35 @@ class Claude_CoT_CD(Claude):
def __init__(self, model_name="claude-3-7-sonnet-20250219"):
super().__init__(model_name)
- def generate_answer(self, image, questions: str, prompting_style):
+ def generate_answer(self, image, messages: list[dict[str, str]]):
# Get the answer class based on the question
- answer_class = get_answer_class_from_question(questions)
+ question = messages[-1]["content"]
+ answer_class = get_answer_class_from_question(question)
- image, prompt = prompting_style.generate_prompt(image, questions)
base64_image = self.encode_image(image)
- # Use anthropic.messages.create with response_model parameter for constrained decoding
+ claude_messages, system_prompt = self._convert_messages_for_claude(messages, base64_image)
+
+ # Add schema information to the last user message for Claude
+ if claude_messages and claude_messages[-1]["role"] == "user":
+ schema_text = f"\n\nThe final_answer should be of type: {answer_class.model_json_schema()}"
+ if isinstance(claude_messages[-1]["content"], list):
+ for content_item in claude_messages[-1]["content"]:
+ if content_item["type"] == "text":
+ content_item["text"] += schema_text
+ break
+ else:
+ claude_messages[-1]["content"] += schema_text
+
response = self.client.messages.create(
model=self.model,
temperature=0.0,
- messages=[
- {
- "role": "system",
- "content": f"The final_answer should be of type: {answer_class.model_json_schema()}",
- },
- {
- "role": "user",
- "content": [
- {
- "type": "image",
- "source": {
- "type": "base64",
- "media_type": "image/jpeg",
- "data": base64_image,
- },
- },
- {"type": "text", "text": prompt},
- ],
- },
- ],
+ system=system_prompt if system_prompt else None,
+ messages=claude_messages,
response_model=Reasoning,
)
- return response.model_dump_json(), prompt
+ return response.model_dump_json(), messages
def __str__(self):
return self.model + "_CoT_CD"
diff --git a/graid/src/graid/experiments/finetune_detr_on_bdd.py b/graid/src/graid/experiments/finetune_detr_on_bdd.py
new file mode 100644
index 0000000..9ba0ea9
--- /dev/null
+++ b/graid/src/graid/experiments/finetune_detr_on_bdd.py
@@ -0,0 +1,944 @@
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+from typing import Any, Union
+
+import numpy as np
+import torch
+import wandb
+from PIL import Image, ImageDraw
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import DetrForObjectDetection, DetrImageProcessor
+from transformers import ConditionalDetrForObjectDetection
+
+# Metric for COCO-style mAP
+from torchmetrics.detection import MeanAveragePrecision
+# LR scheduler
+from torch.optim.lr_scheduler import StepLR
+from collections import Counter
+from numpy.typing import NDArray
+
+# Imports for custom loss
+import torch.nn.functional as F
+from torchvision.ops.boxes import generalized_box_iou, box_convert
+import types # For monkey-patching
+
+
+from graid.data.ImageLoader import Bdd100kDataset
+from graid.interfaces.ObjectDetectionI import ObjectDetectionResultI
+
+
+def _sigmoid_focal_loss_with_class_weight(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ class_weights: torch.Tensor,
+ num_boxes: int,
+ alpha: float = 0.25,
+ gamma: float = 2.0,
+) -> torch.Tensor:
+ """
+ Weighted version of focal loss, inspired by the implementation in the `transformers` library.
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ class_weights: A float tensor of shape (num_classes,).
+ num_boxes: The number of boxes in the batch.
+ Returns:
+ Loss tensor
+ """
+ prob = inputs.sigmoid()
+ ce_loss = F.binary_cross_entropy_with_logits(
+ inputs, targets, reduction="none")
+
+ # Modulating factor
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ # Alpha factor
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ # Apply per-class weights
+ # targets is one-hot over classes with shape (batch, num_queries, num_classes),
+ # so the argmax over the last dimension recovers each query's target class.
+ if class_weights is not None:
+ # weight_map has shape (batch, num_queries): one weight per query
+ weight_map = class_weights[targets.argmax(dim=-1)]
+ # Broadcast each query's class weight across the class dimension
+ loss = loss * weight_map.unsqueeze(-1)
+
+ return loss.mean(1).sum() / num_boxes
+
+
+def apply_custom_losses(
+ model: Union[DetrForObjectDetection, ConditionalDetrForObjectDetection],
+ area_loss_power: float,
+ class_weights: torch.Tensor | None = None,
+):
+ """
+ Apply custom area-weighted and class-weighted losses to a DETR model
+ by monkey-patching its loss functions.
+ """
+ is_conditional = isinstance(model, ConditionalDetrForObjectDetection)
+
+ # ----------------------------------------------------------------------
+ # 1. Area-weighted box/GIoU loss (compatible with both models)
+ # ----------------------------------------------------------------------
+ if area_loss_power > 0 and hasattr(model, 'loss'):
+
+ def loss_boxes_area_weighted(loss_self, outputs, targets, indices, num_boxes):
+ """Area-weighted replica of DETR loss_boxes."""
+ assert "pred_boxes" in outputs
+ idx = loss_self._get_src_permutation_idx(indices)
+ src_boxes = outputs["pred_boxes"][idx]
+ tgt_boxes = torch.cat([t["boxes"][i]
+ for t, (_, i) in zip(targets, indices)], dim=0)
+ areas = tgt_boxes[:, 2] * tgt_boxes[:, 3]
+ w = areas.clamp(min=1e-6) ** area_loss_power
+ ptr = 0
+ for (_, tgt_idx) in indices:
+ if len(tgt_idx):
+ segment = w[ptr:ptr+len(tgt_idx)]
+ w[ptr:ptr+len(tgt_idx)] = segment / segment.mean()
+ ptr += len(tgt_idx)
+ loss_bbox = F.l1_loss(src_boxes, tgt_boxes, reduction="none")
+ losses = {}
+ losses["loss_bbox"] = (loss_bbox * w.unsqueeze(1)).sum() / num_boxes
+ giou = generalized_box_iou(
+ # torchvision's box_convert expects the short "cxcywh" format string for (cx, cy, w, h) boxes
+ box_convert(src_boxes, "cxcywh", "xyxy"),
+ box_convert(tgt_boxes, "cxcywh", "xyxy"),
+ )
+ loss_giou = 1 - torch.diag(giou)
+ losses["loss_giou"] = (loss_giou * w).sum() / num_boxes
+ return losses
+
+ model.loss.loss_boxes = types.MethodType(
+ loss_boxes_area_weighted, model.loss)
+ print(f"β Enabled per-box area weighting (power={area_loss_power})")
+
+ # ----------------------------------------------------------------------
+ # 2. Class-weighted classification loss
+ # ----------------------------------------------------------------------
+ if class_weights is not None and hasattr(model, 'loss'):
+ if is_conditional:
+ # Conditional DETR uses Focal Loss. We need to patch loss_labels.
+ def loss_labels_class_weighted(loss_self, outputs, targets, indices, num_boxes, log=True):
+ """Class-weighted version of Conditional DETR's label loss."""
+ assert "pred_logits" in outputs
+ src_logits = outputs["pred_logits"]
+ idx = loss_self._get_src_permutation_idx(indices)
+ target_classes_o = torch.cat(
+ [t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
+ target_classes = torch.full(
+ src_logits.shape[:2], loss_self.num_classes, dtype=torch.int64, device=src_logits.device
+ )
+ target_classes[idx] = target_classes_o
+ target_classes_onehot = torch.zeros(
+ [src_logits.shape[0], src_logits.shape[1],
+ src_logits.shape[2] + 1],
+ dtype=src_logits.dtype,
+ layout=src_logits.layout,
+ device=src_logits.device,
+ )
+ target_classes_onehot.scatter_(
+ 2, target_classes.unsqueeze(-1), 1)
+ target_classes_onehot = target_classes_onehot[:, :, :-1]
+
+ loss_ce = _sigmoid_focal_loss_with_class_weight(
+ src_logits,
+ target_classes_onehot,
+ class_weights,
+ num_boxes,
+ alpha=loss_self.focal_loss_alpha,
+ gamma=loss_self.focal_loss_gamma,
+ ) * src_logits.shape[1]
+
+ losses = {"loss_ce": loss_ce}
+ if log:
+ losses["class_error"] = 100 * \
+ (target_classes_o !=
+ src_logits[idx].argmax(-1)).float().mean()
+ return losses
+
+ model.loss.loss_labels = types.MethodType(
+ loss_labels_class_weighted, model.loss)
+ print("β Enabled class-weighting for Conditional DETR (Focal Loss)")
+
+ else:
+ # Standard DETR uses cross-entropy and has a 'weight' parameter
+ model.class_weight = class_weights.to(model.device)
+ model.config.class_weight = class_weights.tolist()
+ print("β Enabled class-weighting for standard DETR (Cross-Entropy)")
+
+
+# ---------------------------------------------------------------------------
+# Class-imbalance utilities
+# ---------------------------------------------------------------------------
+
+
+def compute_median_freq_weights(dataset, num_classes: int, workers: int = 16) -> torch.Tensor:
+ """Compute median-frequency balancing weights for foreground classes.
+
+ For each class i: w_i = median(freq) / freq_i. Missing classes get weight 0.
+ Returned tensor shape: (num_classes,)
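+
+ Example: counts of [100, 10, 1] over three classes have median 10, giving
+ weights [0.1, 1.0, 10.0], so rarer classes contribute more to the loss.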
+ """
+ cache_dir = Path('.cache')
+ cache_dir.mkdir(exist_ok=True)
+ cache_file = cache_dir / 'bdd_class_counts.json'
+
+ if cache_file.exists():
+ counter_data = json.loads(cache_file.read_text())
+ counter = Counter({int(k): v for k, v in counter_data.items()})
+ else:
+ counter: Counter[int] = Counter()
+ dl = DataLoader(dataset, batch_size=1, shuffle=False,
+ num_workers=workers, collate_fn=lambda x: x)
+ for batch in dl:
+ item = batch[0]
+ for det in item["labels"]:
+ counter[int(det.cls)] += 1
+ cache_file.write_text(json.dumps(counter))
+
+ counts = torch.tensor([counter.get(i, 0)
+ for i in range(num_classes)], dtype=torch.float)
+
+ weights = torch.zeros_like(counts)
+ nonzero = counts > 0
+ if nonzero.any():
+ median = torch.median(counts[nonzero])
+ weights[nonzero] = median / counts[nonzero]
+
+ return weights
+
+
+# ---------------------------------------------------------------------------
+# Class mapping utilities
+# ---------------------------------------------------------------------------
+
+def get_original_detr_mapping():
+ """Get the original DETR model's class mappings."""
+ temp_model = DetrForObjectDetection.from_pretrained(
+ "facebook/detr-resnet-50", revision="no_timm"
+ )
+ return temp_model.config.id2label, temp_model.config.label2id
+
+
+def create_bdd_to_detr_mapping():
+ """Create mapping from BDD100K COCO class IDs to DETR model class IDs."""
+ _, original_label2id = get_original_detr_mapping()
+
+ # BDD100K uses these COCO class IDs (from dataset analysis)
+ bdd_coco_classes = {
+ 0: "person",
+ 1: "bicycle",
+ 2: "car",
+ 3: "motorcycle",
+ 5: "bus",
+ 6: "train",
+ 7: "truck",
+ 9: "traffic light",
+ 11: "stop sign"
+ }
+
+ # Create mapping from our COCO class ID to DETR class ID
+ our_id_to_detr_id = {}
+ for our_id, class_name in bdd_coco_classes.items():
+ if class_name in original_label2id:
+ detr_id = original_label2id[class_name]
+ our_id_to_detr_id[our_id] = detr_id
+ else:
+ print(
+ f"Warning: {class_name} not found in DETR model, mapping to 0 (N/A)")
+ our_id_to_detr_id[our_id] = 0
+
+ print("BDD100K COCO -> DETR class mapping:")
+ for our_id, detr_id in our_id_to_detr_id.items():
+ class_name = bdd_coco_classes[our_id]
+ print(f" {our_id} ({class_name}) -> {detr_id}")
+
+ return our_id_to_detr_id
+
+
+def create_bdd_direct_mapping():
+ """Create direct BDD100K class mapping (no COCO intermediate step)."""
+ # BDD100K has 12 original classes (0-11)
+ bdd_classes = {
+ 0: "pedestrian",
+ 1: "person",
+ 2: "rider",
+ 3: "car",
+ 4: "truck",
+ 5: "bus",
+ 6: "train",
+ 7: "motorcycle",
+ 8: "bicycle",
+ 9: "traffic light",
+ 10: "traffic sign",
+ 11: "sidewalk"
+ }
+
+ # Direct mapping (identity function)
+ direct_mapping = {i: i for i in range(12)}
+
+ print("BDD100K Direct class mapping:")
+ for bdd_id, mapped_id in direct_mapping.items():
+ class_name = bdd_classes.get(bdd_id, f"unknown_{bdd_id}")
+ print(f" {bdd_id} ({class_name}) -> {mapped_id}")
+
+ return direct_mapping, bdd_classes
+
+
+_COLOURS = [
+ (255, 56, 56), (255, 157, 151), (255, 112, 31), (255, 178, 29),
+ (207, 210, 49), (72, 249, 10), (146, 249, 10), (10, 249, 72),
+ (10, 249, 146), (10, 249, 249), (10, 146, 249), (10, 72, 249),
+ (72, 10, 249), (146, 10, 249), (249, 10, 249), (249, 10, 146),
+ (249, 10, 72), (249, 10, 10),
+]
+
+
+def _draw_boxes(
+ image: Image.Image,
+ boxes: list[list[float]],
+ scores: list[float],
+ labels: list[int],
+ model: Union[DetrForObjectDetection, ConditionalDetrForObjectDetection],
+ bdd_classes: dict[int, str] | None = None,
+) -> NDArray[np.uint8]:
+ """Overlay bounding-boxes on an image and return the result."""
+ draw = ImageDraw.Draw(image)
+ for i, (box, score, label_id) in enumerate(zip(boxes, scores, labels)):
+ colour = _COLOURS[label_id % len(_COLOURS)]
+ # The processor's post_process returns (x1, y1, x2, y2) already
+ x1, y1, x2, y2 = map(int, box)
+
+ # Use BDD class names if available, otherwise use model's id2label
+ if bdd_classes:
+ label_name = bdd_classes.get(label_id, f"CLS_{label_id}")
+ else:
+ label_name = model.config.id2label.get(label_id, f"CLS_{label_id}")
+
+ caption = f"{label_name}: {score:.2%}"
+
+ draw.rectangle([x1, y1, x2, y2], outline=colour, width=3)
+ text_w = draw.textlength(caption)
+
+ # Draw text background
+ draw.rectangle([x1, y1, x1 + text_w + 4, y1 + 15], fill=colour)
+ draw.text((x1 + 2, y1), caption, fill=(0, 0, 0))
+
+ return np.array(image)
+
+
+# ---------------------------------------------------------------------------
+# Data loading and collation
+# ---------------------------------------------------------------------------
+
+class DetrDataCollator:
+ """Collator to prepare data for DETR, adapted from BDD100K format."""
+
+ def __init__(self, processor: DetrImageProcessor, class_mapping: dict[int, int]):
+ self.processor = processor
+ self.class_mapping = class_mapping
+
+ def __call__(self, batch: list[dict[str, Any]]) -> Any:
+ # Extract images and annotations from the batch
+ images = [item["image"] for item in batch]
+ annotations = []
+ for idx, item in enumerate(batch):
+ img_annots = []
+ labels: list[ObjectDetectionResultI] = item["labels"]
+ for label in labels:
+ # Map our class ID to target class ID
+ our_cls = int(label.cls) # Ensure it's an int
+ target_cls = self.class_mapping.get(our_cls, 0) # fallback to 0
+
+ # Convert box from (x1, y1, x2, y2) to COCO format (x_min, y_min, width, height)
+ xyxy = label.as_xyxy()[0].tolist()
+ x1, y1, x2, y2 = xyxy
+ coco_bbox = [x1, y1, x2 - x1, y2 - y1]
+
+ img_annots.append({
+ "bbox": coco_bbox,
+ "category_id": target_cls, # Use mapped class ID
+ "area": label.get_area().item(),
+ })
+ # Use stable integer image_id within the batch; the value is only
+ # used to group annotations that belong to the same image.
+ annotations.append({"image_id": idx, "annotations": img_annots})
+
+ # Process batch with DETR processor
+ processed_batch = self.processor(
+ images=images,
+ annotations=annotations,
+ return_tensors="pt"
+ )
+ return processed_batch
+
+
+# ---------------------------------------------------------------------------
+# Training and Evaluation
+# ---------------------------------------------------------------------------
+
+def train_one_epoch(
+ model: Union[DetrForObjectDetection, ConditionalDetrForObjectDetection],
+ dataloader: DataLoader[Any],
+ optimizer: torch.optim.Optimizer,
+ device: torch.device,
+ epoch: int,
+):
+ """Train the model for one epoch."""
+ model.train()
+ total_loss = 0
+ progress = tqdm(dataloader, desc=f"Epoch {epoch+1} [TRAIN]")
+ for batch in progress:
+ # Move batch to device
+ inputs: dict[str, Any] = {k: v.to(device) for k, v in batch.items()
+ if isinstance(v, torch.Tensor)}
+ # The processor formats labels into a list of dicts, must be handled separately
+ inputs["labels"] = [{k: v.to(device) for k, v in t.items()}
+ for t in batch["labels"]]
+
+ # Forward pass
+ outputs = model(**inputs)
+
+ # Loss: From DETR config
+ # "class_cost": 1, # Classification weight in Hungarian matching
+ # "bbox_cost": 5, # L1 bbox weight in Hungarian matching
+ # "giou_cost": 2, # GIoU weight in Hungarian matching
+ # "bbox_loss_coefficient": 5, # L1 bbox weight in final loss
+ # "giou_loss_coefficient": 2, # GIoU weight in final loss
+ # "eos_coefficient": 0.1, # "No-object" class weight
+ # Total Loss = Classification Loss + λ₁ × L1 Loss + λ₂ × GIoU Loss
+ loss = outputs.loss
+ loss_dict = outputs.loss_dict
+
+ # Backward pass
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ total_loss += loss.item()
+ progress.set_postfix(loss=loss.item())
+ wandb.log({
+ "train_loss_step": loss.item(),
+ **{f"train_{k}": v.item() for k, v in loss_dict.items()}
+ })
+
+ avg_loss = total_loss / len(dataloader)
+ print(f"Epoch {epoch+1} Train Loss: {avg_loss:.4f}")
+ return avg_loss
+
+
+@torch.no_grad()
+def evaluate(
+ model: Union[DetrForObjectDetection, ConditionalDetrForObjectDetection],
+ dataloader: DataLoader[Any],
+ device: torch.device,
+ epoch: int,
+ processor: DetrImageProcessor,
+):
+ """Evaluate on the validation set and compute COCO-style mAP."""
+ model.eval()
+
+ total_loss = 0.0
+ metric = MeanAveragePrecision()
+
+ progress = tqdm(dataloader, desc=f"Epoch {epoch+1} [VAL]")
+ for batch in progress:
+ # Move tensor inputs to device
+ inputs: dict[str, Any] = {k: v.to(device)
+ for k, v in batch.items() if isinstance(v, torch.Tensor)}
+ # Labels need special handling (list[dict])
+ inputs["labels"] = [{k: v.to(device) for k, v in t.items()}
+ for t in batch["labels"]]
+
+ # Forward pass
+ outputs = model(**inputs)
+ loss = outputs.loss
+ total_loss += loss.item()
+
+ # ------------------------------------------------------------------
+ # Prepare predictions & targets for mAP computation
+ # ------------------------------------------------------------------
+ # Determine original image sizes (h, w) for post-processing
+ # The processor adds 'orig_size' to the labels
+ target_sizes = torch.stack([lbl["orig_size"]
+ for lbl in batch["labels"]]).to(device)
+
+ processed_outputs = processor.post_process_object_detection(
+ # no threshold, metric handles scores
+ outputs, target_sizes=target_sizes.tolist(), threshold=0.0
+ )
+
+ preds_for_metric = []
+ for pred in processed_outputs:
+ preds_for_metric.append({
+ "boxes": pred["boxes"].cpu(),
+ "scores": pred["scores"].cpu(),
+ "labels": pred["labels"].cpu(),
+ })
+
+ targets_for_metric = []
+ for tgt in batch["labels"]:
+ # Get original image size to scale the target boxes to absolute pixel coords
+ h, w = tgt["orig_size"].cpu().tolist()
+ scaler = torch.tensor([w, h, w, h])
+
+ # Convert boxes from relative (cx, cy, w, h) to absolute (x1, y1, x2, y2)
+ boxes_cxcywh_abs = tgt["boxes"].cpu() * scaler
+ cx, cy, width, height = boxes_cxcywh_abs.unbind(-1)
+ x1 = cx - 0.5 * width
+ y1 = cy - 0.5 * height
+ x2 = cx + 0.5 * width
+ y2 = cy + 0.5 * height
+ boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=-1)
+
+ targets_for_metric.append({
+ "boxes": boxes_xyxy,
+ "labels": tgt["class_labels"].cpu(),
+ })
+
+ metric.update(preds_for_metric, targets_for_metric)
+
+ progress.set_postfix(loss=loss.item())
+
+ # Aggregate metrics
+ avg_loss = total_loss / len(dataloader)
+ metric_results = metric.compute()
+ map_score = metric_results["map"].item()
+
+ print(
+ f"Epoch {epoch+1} Validation Loss: {avg_loss:.4f} | mAP: {map_score:.4f}")
+ return avg_loss, map_score
+
+
+# ---------------------------------------------------------------------------
+# Main script
+# ---------------------------------------------------------------------------
+
+def validate_model_name(model_name: str) -> None:
+ """Validate that the model name is supported and provide helpful error messages."""
+ supported_models = [
+ "facebook/detr-resnet-50",
+ "facebook/detr-resnet-101",
+ "microsoft/conditional-detr-resnet-50"
+ ]
+
+ if model_name not in supported_models:
+ # Check if it's a similar name that might work
+ model_name_lower = model_name.lower()
+ if "conditional-detr" in model_name_lower:
+ raise ValueError(
+ f"Model '{model_name}' requires ConditionalDetrForObjectDetection, "
+ f"but it's not available in your transformers version. "
+ f"Please update transformers: pip install transformers>=4.21.0"
+ )
+ elif "detr" not in model_name_lower:
+ raise ValueError(
+ f"Model '{model_name}' doesn't appear to be a DETR model. "
+ f"Supported models: {', '.join(supported_models)}"
+ )
+ else:
+ print(f"Warning: '{model_name}' is not in the tested model list.")
+ print(f"Tested models: {', '.join(supported_models)}")
+ print("Proceeding anyway - this may work if it's a compatible DETR variant.")
+
+
+def main(args: argparse.Namespace) -> None:
+ """Main function to run the training and evaluation."""
+ # Validate model name
+ validate_model_name(args.model_name)
+
+ # Setup
+ # Select GPU/CPU device based on the provided gpu_id
+ if torch.cuda.is_available():
+ device = torch.device(f"cuda:{args.gpu_id}")
+ else:
+ device = torch.device("cpu")
+ print(f"Using device: {device}")
+
+ if args.use_bdd_direct_mapping:
+ class_mapping, bdd_classes = create_bdd_direct_mapping()
+ num_classes = 12 # BDD100K has 12 classes
+ else:
+ class_mapping = create_bdd_to_detr_mapping()
+ bdd_classes = None
+ num_classes = 91 # Keep original DETR model's 91 classes
+
+ # Determine model class and info based on model name
+ if "conditional-detr" in args.model_name:
+ base_model_class = ConditionalDetrForObjectDetection
+ model_info = {
+ "name": "Conditional DETR",
+ "type": "conditional"
+ }
+ else:
+ base_model_class = DetrForObjectDetection
+ model_info = {
+ "name": "DETR",
+ "type": "standard"
+ }
+ print(
+ f"Using {model_info['name']} ({model_info['type']}) model: {args.model_name}")
+
+ # Model and Processor
+ print("Loading DETR model and processor...")
+ # Conditional DETR doesn't have the no_timm revision
+ if "conditional-detr" in args.model_name:
+ processor = DetrImageProcessor.from_pretrained(args.model_name)
+ else:
+ processor = DetrImageProcessor.from_pretrained(
+ args.model_name, revision="no_timm")
+
+ # Check if we should load a pre-trained model or train from scratch
+ # Use a GPU-specific directory so parallel runs don't overwrite each other
+ output_dir = Path(f"detr_finetuned_model_gpu{args.gpu_id}")
+ best_model_path = output_dir / "best_model"
+
+ model: Union[DetrForObjectDetection, ConditionalDetrForObjectDetection]
+ if args.load_model and best_model_path.exists():
+ print(f"Loading pre-trained model from {best_model_path}")
+ model = base_model_class.from_pretrained(best_model_path)
+ assert model is not None
+ model = model.to(device)
+ processor = DetrImageProcessor.from_pretrained(best_model_path)
+ print("β Pre-trained model loaded.")
+ skip_training = True
+ else:
+ print("Loading base model for training...")
+
+ # Create custom id2label and label2id if using direct BDD mapping
+ id2label, label2id = None, None
+ if args.use_bdd_direct_mapping and bdd_classes is not None:
+ print(f"Using direct BDD100K mapping with {num_classes} classes")
+ id2label = {i: bdd_classes[i] for i in range(num_classes)}
+ label2id = {v: k for k, v in id2label.items()}
+
+ model_class = base_model_class
+
+ model_kwargs: dict[str, Any] = {
+ "num_labels": num_classes,
+ "id2label": id2label,
+ "label2id": label2id,
+ "ignore_mismatched_sizes": True,
+ }
+ # Regular DETR models use no_timm revision, conditional DETR doesn't
+ if "conditional-detr" not in args.model_name:
+ model_kwargs["revision"] = "no_timm"
+
+ model = model_class.from_pretrained(args.model_name, **model_kwargs)
+
+ # ------------------------------------------------------------------
+ # Freeze parameters except heads + last k layers if requested
+ # ------------------------------------------------------------------
+ trainable_params = None
+ if args.train_last_k >= 0:
+ for p in model.parameters():
+ p.requires_grad = False
+
+ # Unfreeze heads
+ for p in model.class_labels_classifier.parameters():
+ p.requires_grad = True
+ for p in model.bbox_predictor.parameters():
+ p.requires_grad = True
+
+ k = args.train_last_k
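+ # e.g. --train_last_k 2 keeps the heads trainable, unfreezes the last
+ # encoder/decoder transformer layer, and unfreezes the last 2 ResNet bottleneck blocks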
+ if k > 0:
+ # Unfreeze last layer only for the transformer
+ transformer_k = 1
+ if hasattr(model.model.encoder, 'layers'):
+ for layer in model.model.encoder.layers[-transformer_k:]:
+ for p in layer.parameters():
+ p.requires_grad = True
+ if hasattr(model.model.decoder, 'layers'):
+ for layer in model.model.decoder.layers[-transformer_k:]:
+ for p in layer.parameters():
+ p.requires_grad = True
+
+ # Unfreeze last k ResNet layers if backbone exists
+ if hasattr(model.model.backbone, 'conv_encoder'):
+ conv_encoder = model.model.backbone.conv_encoder
+ # Flatten all bottleneck layers from layer1 through layer4
+ all_resnet_layers = []
+ for i in range(1, 5):
+ stage_name = f'layer{i}'
+ if hasattr(conv_encoder, stage_name):
+ stage = getattr(conv_encoder, stage_name)
+ all_resnet_layers.extend(list(stage))
+
+ # Unfreeze the last k bottleneck layers
+ for layer in all_resnet_layers[-k:]:
+ for p in layer.parameters():
+ p.requires_grad = True
+
+ # After setting requires_grad, filter for trainable parameters for the optimizer
+ trainable_params = [
+ p for p in model.parameters() if p.requires_grad]
+
+ # Set loss coefficients from args
+ model.config.eos_coefficient = args.eos_coefficient
+ if args.bbox_loss_coef is not None:
+ model.config.bbox_loss_coefficient = args.bbox_loss_coef
+ if args.giou_loss_coef is not None:
+ model.config.giou_loss_coefficient = args.giou_loss_coef
+
+ model = model.to(device)
+ print("β Model and processor loaded.")
+ skip_training = False
+
+ # Datasets and Dataloaders
+ print("Loading BDD100K datasets...")
+ train_dataset = Bdd100kDataset(
+ split="train",
+ # Use BDD categories for direct mapping, COCO for COCO mapping
+ use_original_categories=args.use_bdd_direct_mapping,
+ use_time_filtered=True
+ )
+ val_dataset = Bdd100kDataset(
+ split="val",
+ use_original_categories=args.use_bdd_direct_mapping,
+ use_time_filtered=True
+ )
+ print(
+ f"β Loaded {len(train_dataset)} training images and {len(val_dataset)} validation images.")
+
+ # -------------------------------------------------------------------
+ # Optional: apply class-rebalancing when using direct BDD mapping
+ # -------------------------------------------------------------------
+ class_weights = None
+ if args.use_bdd_direct_mapping and not skip_training and not args.no_class_weighting:
+ print("Computing class-balancing weights (median-frequency)...")
+ weights_fg = compute_median_freq_weights(
+ train_dataset, num_classes=num_classes)
+ min_weight = 1e-6
+ weights_fg = torch.clamp(weights_fg, min=min_weight)
+ class_weights = torch.cat(
+ [weights_fg, torch.tensor([args.eos_coefficient])])
+ print("Class-weight vector:", class_weights.cpu().numpy())
+ if bdd_classes:
+ print("Per-class weights (BDD order):")
+ for cls_id in range(num_classes):
+ cls_name = bdd_classes.get(cls_id, str(cls_id))
+ print(f" {cls_name}: {weights_fg[cls_id].item():.3f}")
+
+ # Apply all custom loss modifications after model is loaded
+ apply_custom_losses(model, args.area_loss_power, class_weights)
+
+ # -------------------------------------------------------------------
+ # DataLoaders (validation always needed)
+ # -------------------------------------------------------------------
+ collator = DetrDataCollator(processor, class_mapping)
+ val_dataloader = DataLoader(
+ val_dataset,
+ batch_size=args.batch_size,
+ collate_fn=collator,
+ num_workers=4,
+ )
+
+ if not skip_training:
+ # Init wandb only if training
+ import uuid
+ wandb.init(
+ project=args.wandb_project,
+ entity=args.wandb_entity,
+ config=vars(args),
+ name=f"detr-finetune-bdd-{str(uuid.uuid4())[:8]}"
+ )
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ collate_fn=collator,
+ num_workers=4,
+ )
+
+ # Optimizer
+ optimizer = torch.optim.AdamW(
+ trainable_params if trainable_params is not None else model.parameters(),
+ lr=args.lr
+ )
+
+ # ------------------------------------------------------------------
+ # Learning-rate scheduler (optional)
+ # ------------------------------------------------------------------
+ if args.lr_schedule == "step":
+ scheduler = StepLR(
+ optimizer, step_size=args.lr_step, gamma=args.lr_gamma)
+ print(
+ f"β Using StepLR: step={args.lr_step} epochs, gamma={args.lr_gamma}")
+ else:
+ scheduler = None
+
+ # Training loop
+ print("\nStarting training...")
+ best_val_loss = float('inf')
+ output_dir.mkdir(exist_ok=True)
+
+ for epoch in range(args.epochs):
+ train_loss = train_one_epoch(
+ model, train_dataloader, optimizer, device, epoch)
+ val_loss, val_map = evaluate(
+ model, val_dataloader, device, epoch, processor)
+
+ wandb.log({
+ "epoch": epoch + 1,
+ "train_loss_epoch": train_loss,
+ "val_loss_epoch": val_loss,
+ "val_map_epoch": val_map,
+ })
+
+ if scheduler is not None:
+ scheduler.step()
+
+ # Log current LR (scheduler or static)
+ current_lr = scheduler.get_last_lr()[0] if scheduler else args.lr
+ wandb.log({"learning_rate": current_lr})
+
+ if val_loss < best_val_loss:
+ best_val_loss = val_loss
+ model.save_pretrained(output_dir / "best_model")
+ processor.save_pretrained(output_dir / "best_model")
+ print(f"β New best model saved to {output_dir / 'best_model'}")
+
+ print("\nβ Training finished.")
+ wandb.finish()
+ else:
+ print("\nSkipping training - using pre-trained model.")
+
+ # ------------------------------------------------------------------
+ # Evaluate loaded model on validation set to compute fresh mAP
+ # ------------------------------------------------------------------
+ print("\nEvaluating loaded model on validation set β¦")
+ val_loss, val_map = evaluate(
+ model, val_dataloader, device, epoch=0, processor=processor)
+ print(f"Validation Loss: {val_loss:.4f} | mAP: {val_map:.4f}")
+
+ # Visualization
+ print("\nStarting visualization on validation images...")
+ # GPU-specific visualization directory to avoid overwriting between parallel runs
+ vis_dir = Path(f"detr_finetune_results_gpu{args.gpu_id}")
+ vis_dir.mkdir(exist_ok=True)
+
+ model.eval()
+ for i in range(6):
+ item = val_dataset[i]
+ # The dataset might return different image formats, ensure it's a PIL Image
+ image = item["image"]
+
+ # Convert to PIL Image if it's not already
+ if isinstance(image, torch.Tensor):
+ # Convert tensor to PIL Image
+ if image.shape[0] == 3: # CHW format
+ image = image.permute(1, 2, 0) # Convert to HWC
+
+ # FIX: Handle different tensor value ranges
+ if image.max() > 1.0:
+ # Image is in [0, 255] range, convert to uint8
+ image = image.clamp(0, 255).byte().cpu().numpy()
+ else:
+ # Image is in [0, 1] range, scale to [0, 255]
+ image = (image * 255).clamp(0, 255).byte().cpu().numpy()
+
+ pil_image = Image.fromarray(image)
+ elif hasattr(image, 'convert'): # Already a PIL Image
+ pil_image = image
+ else:
+ # Try to convert array-like to PIL Image (numpy is already imported at module level)
+ if isinstance(image, np.ndarray):
+ # Handle different numpy array value ranges
+ if image.max() > 1.0:
+ # Array is in [0, 255] range
+ if image.dtype != np.uint8:
+ image = image.astype(np.uint8)
+ else:
+ # Array is in [0, 1] range
+ image = (image * 255).astype(np.uint8)
+ pil_image = Image.fromarray(image)
+ else:
+ print(
+ f"Warning: Unexpected image type {type(image)}, skipping visualization")
+ continue
+
+ inputs = processor(images=pil_image, return_tensors="pt").to(device)
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # The target size for post-processing should be based on the original PIL image size
+ image_size = pil_image.size # This should be a tuple (width, height)
+ # Convert to [[height, width]]
+ target_sizes = [image_size[::-1]]
+
+ results = processor.post_process_object_detection(
+ outputs, target_sizes=target_sizes, threshold=0.5
+ )[0]
+
+ vis_image = _draw_boxes(
+ pil_image.copy(),
+ results["boxes"].cpu().tolist(),
+ results["scores"].cpu().tolist(),
+ results["labels"].cpu().tolist(),
+ model,
+ bdd_classes if args.use_bdd_direct_mapping else None
+ )
+
+ save_path = vis_dir / f"vis_{item['name'].replace('/', '_')}.png"
+ Image.fromarray(vis_image).save(save_path)
+ print(f"β Saved visualization to {save_path}")
+
+ print("\n=== Done. ===")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Fine-tune DETR on BDD100K.")
+ parser.add_argument("--model_name", type=str,
+ default="facebook/detr-resnet-50",
+ help="HuggingFace model name. Supported models: "
+ "facebook/detr-resnet-50, facebook/detr-resnet-101, "
+ "microsoft/conditional-detr-resnet-50")
+ parser.add_argument("--epochs", type=int, default=10,
+ help="Number of training epochs.")
+ parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
+ parser.add_argument("--batch_size", type=int, default=4, help="Batch size.")
+ parser.add_argument("--load_model", action="store_true", default=False,
+ help="Load pre-trained model instead of training from scratch.")
+ parser.add_argument("--wandb_project", type=str,
+ default="graid-detr-finetune", help="Wandb project name.")
+ parser.add_argument("--wandb_entity", type=str, default=None,
+ help="Wandb entity (username or team).")
+ parser.add_argument("--use_bdd_direct_mapping", action="store_true", default=False,
+ help="Use direct BDD100K class mapping (0-11) instead of COCO mapping.")
+ parser.add_argument("--eos_coefficient", type=float, default=0.1,
+ help="Adjust the 'No-object' class weight (eos_coefficient) during fine-tuning.")
+ parser.add_argument("--gpu_id", type=int, default=7,
+ help="CUDA device id to run the training on (e.g. 0, 1, 2 β¦).")
+ parser.add_argument("--no_class_weighting", action="store_true", default=False,
+ help="Disable class weighting when using direct BDD mapping.")
+ parser.add_argument("--bbox_loss_coef", type=float, default=None,
+ help="Weight for the L1 box loss.")
+ parser.add_argument("--giou_loss_coef", type=float, default=None,
+ help="Weight for the GIoU loss.")
+ parser.add_argument("--lr_schedule", type=str, default="none", choices=["none", "step"],
+ help="Type of LR scheduler to use.")
+ parser.add_argument("--lr_step", type=int, default=3,
+ help="StepLR: number of epochs between LR decays.")
+ parser.add_argument("--lr_gamma", type=float, default=0.1,
+ help="Multiplicative factor for LR decay (gamma).")
+ parser.add_argument("--train_last_k", type=int, default=-1,
+ help="If >0, unfreeze heads plus last k transformer & ResNet layers. -1 trains all layers.")
+ parser.add_argument("--area_loss_power", type=float, default=0.0,
+ help="If >0, enable area weighting with this power (0.5=sqrt, 1=linear, etc.). Use 0 to disable.")
+
+ cli_args = parser.parse_args()
+ main(cli_args)
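+
+
+# Illustrative invocation (all flags are defined above; the values are examples only):
+#   python finetune_detr_on_bdd.py --model_name facebook/detr-resnet-50 \
+#       --epochs 10 --batch_size 4 --lr 1e-5 --use_bdd_direct_mapping --gpu_id 0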
diff --git a/graid/src/graid/graid.py b/graid/src/graid/graid.py
index beb4214..418bbd6 100644
--- a/graid/src/graid/graid.py
+++ b/graid/src/graid/graid.py
@@ -5,17 +5,121 @@
using various models and datasets.
"""
+import logging
+import os
import sys
+import warnings
from pathlib import Path
-from typing import Optional
+from typing import Optional, List, Dict
import typer
+
+
+# Suppress common warnings for better user experience
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings(
+ "ignore", message=".*TorchScript.*functional optimizers.*deprecated.*"
+)
+
+# Suppress mmengine info messages
+logging.getLogger("mmengine").setLevel(logging.WARNING)
+
+
# Add the project root to Python path for imports
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))
+def _configure_logging():
+ # Simple logic: GRAID_DEBUG_VERBOSE controls console debug, file always gets debug
+ debug_verbose = bool(os.getenv("GRAID_DEBUG_VERBOSE"))
+ console_level = logging.DEBUG if debug_verbose else logging.INFO
+ file_level = logging.DEBUG # Always debug to file
+ root_level = logging.DEBUG # Root logger must be permissive for debug messages
+
+ # Configure root logger once with both console and file handlers
+ logger = logging.getLogger()
+ if logger.handlers:
+ # If already configured, update levels
+ logger.setLevel(root_level)
+ for handler in logger.handlers:
+ if isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler):
+ handler.setLevel(console_level)
+ elif isinstance(handler, logging.FileHandler):
+ handler.setLevel(file_level)
+ return
+
+ logger.setLevel(root_level)
+ formatter = logging.Formatter("%(asctime)s %(levelname)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
+
+ # Create a custom filter to only show GRAID logs on console
+ class GraidLogFilter(logging.Filter):
+ def filter(self, record):
+ # Only show logs from graid modules (and a few important system messages)
+ return (record.name.startswith('graid.') or
+ record.name == 'graid' or
+ record.levelno >= logging.WARNING) # Always show warnings/errors from any source
+
+ # Console handler with GRAID-only filter
+ ch = logging.StreamHandler()
+ ch.setLevel(console_level)
+ ch.setFormatter(formatter)
+ ch.addFilter(GraidLogFilter()) # Only show GRAID logs on console
+ logger.addHandler(ch)
+
+ # File handler with timestamp
+ log_dir = os.getenv("GRAID_LOG_DIR", "logs")
+ # Create log directory with proper error handling
+ try:
+ Path(log_dir).mkdir(parents=True, exist_ok=True)
+ except PermissionError:
+ # Log to stderr if we can't create log directory
+ print(f"Warning: Permission denied creating log directory: {log_dir}", file=sys.stderr)
+ log_dir = "/tmp" # Fallback to /tmp
+ try:
+ Path(log_dir).mkdir(parents=True, exist_ok=True)
+ except Exception as fallback_e:
+ print(f"Warning: Could not create fallback log directory: {fallback_e}", file=sys.stderr)
+ log_dir = None
+ except OSError as e:
+ print(f"Warning: OS error creating log directory {log_dir}: {e}", file=sys.stderr)
+ log_dir = None
+ except Exception as e:
+ print(f"Warning: Unexpected error creating log directory {log_dir}: {e}", file=sys.stderr)
+ log_dir = None
+
+ # Generate timestamped log filename and create file handler
+ from datetime import datetime
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M")
+ log_filename = f"graid_{timestamp}.log"
+
+ # Only create file handler if we have a valid log directory
+ if log_dir is not None:
+ try:
+ fh = logging.FileHandler(Path(log_dir) / log_filename)
+ fh.setLevel(file_level)
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
+ except Exception as e:
+ print(f"Warning: Failed to create log file handler: {e}", file=sys.stderr)
+ print("Logging will only go to console", file=sys.stderr)
+ else:
+ print("Warning: No log directory available, logging only to console", file=sys.stderr)
+ # Quiet noisy libraries more aggressively
+ logging.getLogger("mmengine").setLevel(logging.WARNING)
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
+ logging.getLogger("PIL").setLevel(logging.WARNING)
+ logging.getLogger("PIL.Image").setLevel(logging.WARNING)
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
+ logging.getLogger("datasets").setLevel(logging.WARNING)
+ logging.getLogger("transformers").setLevel(logging.WARNING)
+ logging.getLogger("torch").setLevel(logging.WARNING)
+
+
+_configure_logging()
+
+
app = typer.Typer(
name="graid",
help="GRAID: Generating Reasoning questions from Analysis of Images via Discriminative artificial intelligence",
@@ -33,12 +137,18 @@ def print_welcome():
typer.echo()
typer.echo("GRAID provides two main capabilities:")
typer.echo()
- typer.secho("π Database Generation (generate):", fg=typer.colors.BLUE, bold=True)
- typer.echo("β’ Multiple datasets: BDD100K, NuImages, Waymo")
- typer.echo("β’ Various model backends: Detectron, MMDetection, Ultralytics")
+ typer.secho("ποΈ Dataset Generation (generate-dataset):",
+ fg=typer.colors.BLUE, bold=True)
+ typer.echo("β’ Multiple datasets: BDD100K, NuImages, Waymo + Custom datasets")
+ typer.echo("β’ Multi-backend support: Detectron2, MMDetection, Ultralytics")
+ typer.echo("β’ Ensemble models with Weighted Box Fusion (WBF)")
typer.echo("β’ Ground truth data or custom model predictions")
+ typer.echo("β’ Unlabeled image support (models generate detections)")
+ typer.echo("β’ Standard formats with COCO-style annotations")
+ typer.echo("β’ Interactive configuration and batch processing")
typer.echo()
- typer.secho("π§ VLM Evaluation (eval-vlms):", fg=typer.colors.BLUE, bold=True)
+ typer.secho("π§ VLM Evaluation (eval-vlms):",
+ fg=typer.colors.BLUE, bold=True)
typer.echo("β’ Evaluate Vision Language Models: GPT, Gemini, Llama")
typer.echo("β’ Multiple evaluation metrics: LLMJudge, ExactMatch, Contains")
typer.echo(
@@ -65,14 +175,23 @@ def get_dataset_choice() -> str:
)
typer.echo()
+ typer.secho("π‘ Custom Dataset Support:", fg=typer.colors.YELLOW, bold=True)
+ typer.echo(
+ " GRAID supports any PyTorch-compatible dataset. Only images are required for VQA."
+ )
+ typer.echo(" Annotations are optional (only needed for mAP/mAR evaluation).")
+ typer.echo()
+
while True:
choice = typer.prompt("Select dataset (1-3)")
if choice in datasets:
dataset_name = datasets[choice][0]
- typer.secho(f"β Selected: {dataset_name.upper()}", fg=typer.colors.GREEN)
+ typer.secho(
+ f"β Selected: {dataset_name.upper()}", fg=typer.colors.GREEN)
typer.echo()
return dataset_name
- typer.secho("Invalid choice. Please enter 1, 2, or 3.", fg=typer.colors.RED)
+ typer.secho("Invalid choice. Please enter 1, 2, or 3.",
+ fg=typer.colors.RED)
def get_split_choice() -> str:
@@ -95,7 +214,8 @@ def get_split_choice() -> str:
choice = typer.prompt("Select split (1-2)")
if choice in splits:
split_name = splits[choice][0]
- typer.secho(f"β Selected: {split_name.upper()}", fg=typer.colors.GREEN)
+ typer.secho(
+ f"β Selected: {split_name.upper()}", fg=typer.colors.GREEN)
typer.echo()
return split_name
typer.secho("Invalid choice. Please enter 1 or 2.", fg=typer.colors.RED)
@@ -107,8 +227,11 @@ def get_model_choice() -> tuple[Optional[str], Optional[str], Optional[dict]]:
typer.echo()
typer.echo(" 1. Ground Truth - Use original dataset annotations (fastest)")
- typer.echo(" 2. Pre-configured Models - Choose from built-in model configurations")
- typer.echo(" 3. Custom Model - Bring your own Detectron/MMDetection model")
+ typer.echo(
+ " 2. Pre-configured Models - Choose from built-in model configurations")
+ typer.echo(
+ " 3. Custom Model - Bring your own Detectron/MMDetection/Ultralytics model"
+ )
typer.echo()
while True:
@@ -125,7 +248,8 @@ def get_model_choice() -> tuple[Optional[str], Optional[str], Optional[dict]]:
elif choice == "3":
return get_custom_model()
- typer.secho("Invalid choice. Please enter 1, 2, or 3.", fg=typer.colors.RED)
+ typer.secho("Invalid choice. Please enter 1, 2, or 3.",
+ fg=typer.colors.RED)
def get_preconfigured_model() -> tuple[str, str, None]:
@@ -134,13 +258,15 @@ def get_preconfigured_model() -> tuple[str, str, None]:
typer.secho("π§ Pre-configured Models", fg=typer.colors.BLUE, bold=True)
typer.echo()
+ # Local import to avoid heavy dependencies
from graid.data.generate_db import list_available_models
available_models = list_available_models()
backends = list(available_models.keys())
typer.echo("Available backends:")
for i, backend in enumerate(backends, 1):
- typer.echo(f" {i}. {typer.style(backend.upper(), fg=typer.colors.GREEN)}")
+ typer.echo(
+ f" {i}. {typer.style(backend.upper(), fg=typer.colors.GREEN)}")
typer.echo()
while True:
@@ -149,9 +275,10 @@ def get_preconfigured_model() -> tuple[str, str, None]:
if 0 <= backend_choice < len(backends):
backend = backends[backend_choice]
break
- except ValueError:
- pass
- typer.secho("Invalid choice. Please enter a valid number.", fg=typer.colors.RED)
+ except ValueError:
+ typer.secho("Invalid input: expected a number.", fg=typer.colors.RED)
+ typer.secho("Invalid choice. Please enter a valid number.",
+ fg=typer.colors.RED)
typer.echo()
models = available_models[backend]
@@ -166,11 +293,13 @@ def get_preconfigured_model() -> tuple[str, str, None]:
if 0 <= model_choice < len(models):
model_name = models[model_choice]
break
- except ValueError:
- pass
- typer.secho("Invalid choice. Please enter a valid number.", fg=typer.colors.RED)
+ except ValueError:
+ typer.secho("Invalid input: expected a number.", fg=typer.colors.RED)
+ typer.secho("Invalid choice. Please enter a valid number.",
+ fg=typer.colors.RED)
- typer.secho(f"β Selected: {backend.upper()} - {model_name}", fg=typer.colors.GREEN)
+ typer.secho(
+ f"β Selected: {backend.upper()} - {model_name}", fg=typer.colors.GREEN)
typer.echo()
return backend, model_name, None
@@ -178,23 +307,29 @@ def get_preconfigured_model() -> tuple[str, str, None]:
def get_custom_model() -> tuple[str, str, dict]:
"""Interactive custom model configuration."""
typer.echo()
- typer.secho("π οΈ Custom Model Configuration", fg=typer.colors.BLUE, bold=True)
+ typer.secho("π οΈ Custom Model Configuration",
+ fg=typer.colors.BLUE, bold=True)
typer.echo()
typer.echo("Supported backends for custom models:")
typer.echo(" 1. Detectron2 - Facebook's object detection framework")
typer.echo(" 2. MMDetection - OpenMMLab's detection toolbox")
+ typer.echo(" 3. Ultralytics - YOLO and RT-DETR models")
typer.echo()
while True:
- choice = typer.prompt("Select backend (1-2)")
+ choice = typer.prompt("Select backend (1-3)")
if choice == "1":
backend = "detectron"
break
elif choice == "2":
backend = "mmdetection"
break
- typer.secho("Invalid choice. Please enter 1 or 2.", fg=typer.colors.RED)
+ elif choice == "3":
+ backend = "ultralytics"
+ break
+ typer.secho("Invalid choice. Please enter 1, 2, or 3.",
+ fg=typer.colors.RED)
typer.echo()
custom_config = {}
@@ -207,13 +342,15 @@ def get_custom_model() -> tuple[str, str, dict]:
config_file = typer.prompt(
"Config file path (e.g., 'COCO-Detection/retinanet_R_50_FPN_3x.yaml')"
)
- weights_file = typer.prompt("Weights file path (e.g., 'path/to/model.pth')")
+ weights_file = typer.prompt(
+ "Weights file path (e.g., 'path/to/model.pth')")
custom_config = {"config": config_file, "weights": weights_file}
elif backend == "mmdetection":
typer.echo("MMDetection Configuration:")
- typer.echo("You need to provide paths to configuration and checkpoint files.")
+ typer.echo(
+ "You need to provide paths to configuration and checkpoint files.")
typer.echo()
config_file = typer.prompt(
@@ -223,17 +360,29 @@ def get_custom_model() -> tuple[str, str, dict]:
custom_config = {"config": config_file, "checkpoint": checkpoint}
+ elif backend == "ultralytics":
+ typer.echo("Ultralytics Configuration:")
+ typer.echo("You need to provide the path to a custom trained model file.")
+ typer.echo()
+
+ model_path = typer.prompt(
+ "Model file path (e.g., 'path/to/custom_model.pt')")
+
+ custom_config = {"model_path": model_path}
+
# Generate a custom model name
- model_name = f"custom_{Path(custom_config.get('config', 'model')).stem}"
+ model_name = f"custom_{Path(custom_config.get('config', custom_config.get('model_path', 'model'))).stem}"
- typer.secho(f"β Custom model configured: {backend.upper()}", fg=typer.colors.GREEN)
+ typer.secho(
+ f"β Custom model configured: {backend.upper()}", fg=typer.colors.GREEN)
typer.echo()
return backend, model_name, custom_config
def get_confidence_threshold() -> float:
"""Interactive confidence threshold selection."""
- typer.secho("π― Step 4: Set confidence threshold", fg=typer.colors.BLUE, bold=True)
+ typer.secho("π― Step 4: Set confidence threshold",
+ fg=typer.colors.BLUE, bold=True)
typer.echo()
typer.echo("Confidence threshold filters out low-confidence detections.")
typer.echo("β’ Lower values (0.1-0.3): More detections, some false positives")
@@ -242,9 +391,11 @@ def get_confidence_threshold() -> float:
while True:
try:
- conf = float(typer.prompt("Enter confidence threshold", default="0.2"))
+ conf = float(typer.prompt(
+ "Enter confidence threshold", default="0.2"))
if 0.0 <= conf <= 1.0:
- typer.secho(f"β Confidence threshold: {conf}", fg=typer.colors.GREEN)
+ typer.secho(
+ f"β Confidence threshold: {conf}", fg=typer.colors.GREEN)
typer.echo()
return conf
typer.secho(
@@ -307,6 +458,8 @@ def generate(
)
raise typer.Exit(1)
+ # Local import for dataset validation
+ from graid.data.generate_db import DATASET_TRANSFORMS
if dataset not in DATASET_TRANSFORMS:
typer.secho(
f"Error: Invalid dataset '{dataset}'. Choose from: {list(DATASET_TRANSFORMS.keys())}",
@@ -323,22 +476,13 @@ def generate(
if conf is None:
conf = 0.2
- custom_config = None
- if config and checkpoint:
- if backend == "detectron":
- custom_config = {"config": config, "weights": checkpoint}
- elif backend == "mmdetection":
- custom_config = {"config": config, "checkpoint": checkpoint}
-
- # Handle custom model configuration
- if "custom_config" in locals() and custom_config:
- # Add custom model to MODEL_CONFIGS temporarily
- if backend not in MODEL_CONFIGS:
- MODEL_CONFIGS[backend] = {}
- MODEL_CONFIGS[backend][model] = custom_config
+ # Note: Custom model configuration support would need to be implemented
+ # in the model creation logic. For now, custom models are not supported
+ # in non-interactive mode.
# Start generation
- typer.secho("π Starting database generation...", fg=typer.colors.BLUE, bold=True)
+ typer.secho("π Starting database generation...",
+ fg=typer.colors.BLUE, bold=True)
typer.echo()
typer.echo(f"Dataset: {dataset}")
typer.echo(f"Split: {split}")
@@ -350,6 +494,7 @@ def generate(
typer.echo()
try:
+ from graid.data.generate_db import generate_db
db_name = generate_db(
dataset_name=dataset,
split=split,
@@ -366,12 +511,364 @@ def generate(
)
typer.echo(f"Database created: {db_name}")
+ except KeyboardInterrupt:
+ typer.echo("\nβΉοΈ Generation cancelled by user")
+ raise typer.Exit(130) # Standard exit code for SIGINT
+ except PermissionError as e:
+ typer.echo()
+ typer.secho(f"β Permission Error: {e}", fg=typer.colors.RED, bold=True)
+ typer.secho("Check file/directory permissions and try again.", fg=typer.colors.YELLOW)
+ raise typer.Exit(1)
+ except FileNotFoundError as e:
+ typer.echo()
+ typer.secho(f"β File Not Found: {e}", fg=typer.colors.RED, bold=True)
+ raise typer.Exit(1)
+ except ValueError as e:
+ typer.echo()
+ typer.secho(f"β Invalid Value: {e}", fg=typer.colors.RED, bold=True)
+ typer.secho("Check your input parameters and try again.", fg=typer.colors.YELLOW)
+ raise typer.Exit(1)
except Exception as e:
typer.echo()
- typer.secho(f"β Error during generation: {e}", fg=typer.colors.RED, bold=True)
+ typer.secho(f"β Unexpected Error: {e}", fg=typer.colors.RED, bold=True)
+ if os.getenv("GRAID_DEBUG_VERBOSE"):
+ import traceback
+ typer.echo("Detailed traceback:")
+ typer.echo(traceback.format_exc())
+ else:
+ typer.secho("Use GRAID_DEBUG_VERBOSE=1 for detailed error information.", fg=typer.colors.CYAN)
raise typer.Exit(1)
+def _handle_list_commands(
+ list_valid_objects: bool,
+ list_questions: bool
+) -> bool:
+ """Handle --list-objects and --list-questions commands."""
+ if list_valid_objects:
+ from graid.utilities.coco import coco_labels
+ typer.secho("π Valid COCO Objects", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ valid_objects = list(coco_labels.values())
+ # Remove undefined as it's not a real COCO class
+ if "undefined" in valid_objects:
+ valid_objects.remove("undefined")
+ valid_objects.sort()
+
+ for i, obj in enumerate(valid_objects, 1):
+ typer.echo(f" {i:2d}. {obj}")
+
+ typer.echo()
+ typer.echo(f"Total: {len(valid_objects)} objects")
+ return True
+
+ if list_questions:
+ from graid.data.generate_dataset import list_available_questions
+ typer.secho("π Available Questions", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+
+ questions = list_available_questions()
+ for i, (name, info) in enumerate(questions.items(), 1):
+ typer.secho(f"{i:2d}. {name}", fg=typer.colors.GREEN, bold=True)
+ typer.echo(f" {info['question']}")
+ if info["parameters"]:
+ typer.echo(" Parameters:")
+ for param_name, param_info in info["parameters"].items():
+ typer.echo(
+ f" β’ {param_name}: {param_info['description']} (default: {param_info['default']})"
+ )
+ typer.echo()
+
+ typer.echo(f"Total: {len(questions)} question types")
+ return True
+
+ return False
+
+
+def _handle_interactive_questions(interactive_questions: bool) -> Optional[List[Dict]]:
+ """Handle interactive question selection."""
+ if interactive_questions:
+ from graid.data.generate_dataset import interactive_question_selection
+ return interactive_question_selection()
+ return None
+
+
+def _load_and_validate_config(
+ config_file: Optional[str],
+ **cli_args
+):
+ """Load configuration from file and apply CLI overrides."""
+ from graid.cli_helpers import ConfigurationManager
+
+ if config_file:
+ config = ConfigurationManager.load_from_file(config_file)
+ config = ConfigurationManager.apply_cli_overrides(config, **cli_args)
+ else:
+ # Create config from CLI args only
+ if not cli_args.get('dataset') or not cli_args.get('split'):
+ raise ValueError("Either --config file or both --dataset and --split are required")
+
+ # Local import for config creation
+ from graid.data.config_support import DatasetGenerationConfig
+ config = DatasetGenerationConfig(
+ dataset_name=cli_args['dataset'],
+ split=cli_args['split'],
+ models=[],
+ use_wbf=False,
+ confidence_threshold=0.0,
+ batch_size=32,
+ device=None,
+ allowable_set=cli_args['allowable_set'].split(',') if cli_args.get('allowable_set') else None,
+ num_workers=cli_args.get('num_workers', 4),
+ qa_workers=cli_args.get('qa_workers', 4),
+ save_path=cli_args.get('save_path'),
+ upload_to_hub=cli_args.get('upload_to_hub', False),
+ hub_repo_id=cli_args.get('hub_repo_id'),
+ hub_private=cli_args.get('hub_private', False),
+ num_samples=None,
+ use_original_filenames=True,
+ filename_prefix='img',
+ force=cli_args.get('force', False),
+ question_configs=[{'name': 'HowMany', 'params': {}}] # Default question
+ )
+
+ # Final validation
+ ConfigurationManager.validate_final_config(config)
+ return config
+
+
+def _process_dataset_generation(config, question_configs: Optional[List] = None):
+ """Process the actual dataset generation."""
+ from graid.cli_helpers import DatasetProcessor, ArgumentValidator
+
+ # Override question configs if provided interactively
+ if question_configs:
+ config.question_configs = question_configs
+
+ # Determine processing strategy
+ splits = ArgumentValidator.validate_split_format(config.split)
+
+ if len(splits) == 1:
+ return DatasetProcessor.process_single_split(config)
+ else:
+ return DatasetProcessor.process_multiple_splits(config)
+
+
+@app.command("generate-dataset")
+def generate_dataset_cmd(
+ config_file: Optional[str] = typer.Option(
+ None, "--config", "-c", help="Path to configuration file"
+ ),
+ dataset: Optional[str] = typer.Option(
+ None,
+ help="Dataset name (bdd, nuimage, waymo) - supports custom PyTorch datasets",
+ ),
+ split: Optional[str] = typer.Option(
+ None, help="Data split (train, val, test)"),
+ allowable_set: Optional[str] = typer.Option(
+ None, help="Comma-separated list of allowed COCO objects"
+ ),
+ save_path: str = typer.Option(
+ "./graid-datasets", help="Path to save the generated dataset"
+ ),
+ upload_to_hub: bool = typer.Option(
+ False, help="Upload dataset to HuggingFace Hub"),
+ hub_repo_id: Optional[str] = typer.Option(
+ None, help="HuggingFace Hub repository ID"
+ ),
+ hub_private: bool = typer.Option(
+ False, help="Make HuggingFace Hub repository private"
+ ),
+ interactive: bool = typer.Option(True, help="Use interactive mode"),
+ list_valid_objects: bool = typer.Option(
+ False, "--list-objects", help="List valid COCO objects and exit"
+ ),
+ list_questions: bool = typer.Option(
+ False, "--list-questions", help="List available questions and exit"
+ ),
+ interactive_questions: bool = typer.Option(
+ False, "--interactive-questions", help="Use interactive question selection"
+ ),
+ num_workers: int = typer.Option(
+ 4, "--num-workers", "-j", help="DataLoader workers for parallel image loading"
+ ),
+ qa_workers: int = typer.Option(
+ 4, "--qa-workers", help="Parallel threads for QA generation"
+ ),
+ force: bool = typer.Option(
+ False, "--force", help="Force restart from scratch, ignore existing checkpoints"
+ ),
+):
+ """
+ Generate HuggingFace datasets for object detection question-answering.
+
+ Supports built-in datasets (BDD100K, NuImages, Waymo) and custom PyTorch datasets
+ with COCO-style annotations. Use interactive mode or config files for easy setup.
+ """
+
+ from graid.cli_helpers import (
+ ValidationError, ConfigurationError, ProcessingError, ErrorHandler
+ )
+
+ try:
+ # Handle list commands first
+ if _handle_list_commands(list_valid_objects, list_questions):
+ return
+
+ print_welcome()
+
+ # Handle interactive question selection
+ question_configs = _handle_interactive_questions(interactive_questions)
+
+ # Load and validate configuration
+ typer.echo("π Loading and validating configuration...")
+
+ # Only include CLI args that were explicitly provided (not defaults)
+ cli_args = {}
+
+ # String/Optional args (None means not provided)
+ if dataset is not None:
+ cli_args['dataset'] = dataset
+ if split is not None:
+ cli_args['split'] = split
+ if allowable_set is not None:
+ cli_args['allowable_set'] = allowable_set
+ if hub_repo_id is not None:
+ cli_args['hub_repo_id'] = hub_repo_id
+
+ # Boolean args - need to check if explicitly set vs default
+ # For typer, we need to detect if these were provided by user
+ # Since typer doesn't provide an easy way, we'll use a different approach
+ cli_flags = sys.argv[1:]  # Get CLI arguments (sys is already imported at module level)
+
+ if '--upload-to-hub' in cli_flags:
+ cli_args['upload_to_hub'] = upload_to_hub
+ if '--hub-private' in cli_flags:
+ cli_args['hub_private'] = hub_private
+ if '--force' in cli_flags:
+ cli_args['force'] = force
+
+ # For numeric args, check if they differ from defaults
+ if '--num-workers' in cli_flags or '-j' in cli_flags:
+ cli_args['num_workers'] = num_workers
+ if '--qa-workers' in cli_flags:
+ cli_args['qa_workers'] = qa_workers
+
+ # For save_path, only override if explicitly provided
+ if '--save-path' in cli_flags:
+ cli_args['save_path'] = save_path
+
+ config = _load_and_validate_config(config_file, **cli_args)
+ typer.secho("β Configuration validated successfully", fg=typer.colors.GREEN)
+
+ # Display configuration summary
+ typer.echo()
+ typer.secho("π Configuration Summary", fg=typer.colors.BLUE, bold=True)
+ typer.echo(f"Dataset: {config.dataset_name}")
+ typer.echo(f"Split: {config.split}")
+ typer.echo(f"Models: {len(getattr(config, 'models', [])) if getattr(config, 'models', None) else 0} (using ground truth if 0)")
+ typer.echo(f"Batch size: {getattr(config, 'batch_size', 32)}")
+ typer.echo(f"Workers: {config.num_workers} (loading), {config.qa_workers} (QA)")
+ typer.echo(f"Save path: {config.save_path}")
+ if config.allowable_set:
+ typer.echo(f"COCO filter: {', '.join(config.allowable_set[:3])}{'...' if len(config.allowable_set) > 3 else ''}")
+ typer.echo(f"Upload to Hub: {'Yes' if config.upload_to_hub else 'No'}")
+ if config.upload_to_hub:
+ typer.echo(f"Hub repo: {config.hub_repo_id}")
+ typer.echo()
+
+ # Start dataset generation
+ typer.secho("π Starting dataset generation...", fg=typer.colors.BLUE, bold=True)
+ dataset_dict = _process_dataset_generation(config, question_configs)
+
+ # Success reporting
+ typer.echo()
+ typer.secho("β
Dataset generation completed successfully!", fg=typer.colors.GREEN, bold=True)
+
+ total_pairs = sum(len(dataset) for dataset in dataset_dict.values())
+ typer.echo(f"π Generated {total_pairs} question-answer pairs")
+
+ if len(dataset_dict) > 1:
+ counts = ", ".join(f"{s}={len(dataset_dict[s])}" for s in dataset_dict.keys())
+ typer.echo(f"π Per-split counts: {counts}")
+
+ if config.save_path:
+ typer.echo(f"πΎ Saved to: {config.save_path}")
+
+ if config.upload_to_hub:
+ typer.echo(f"π€ Uploaded to HuggingFace Hub: {config.hub_repo_id}")
+
+ except ValidationError as e:
+ ErrorHandler.handle_validation_error(e)
+ except ConfigurationError as e:
+ ErrorHandler.handle_configuration_error(e)
+ except ProcessingError as e:
+ ErrorHandler.handle_processing_error(e)
+ except Exception as e:
+ ErrorHandler.handle_unexpected_error(e)
+
+
+def _load_configuration(config_file, interactive, interactive_questions, **cli_args):
+ """Load configuration from file or interactive mode with CLI overrides."""
+ from graid.cli import ConfigurationManager
+
+ if config_file:
+ # Load from file and apply CLI overrides
+ config = ConfigurationManager.load_from_file(config_file)
+ config = ConfigurationManager.apply_cli_overrides(config, **cli_args)
+
+ typer.secho("β Configuration loaded from:", fg=typer.colors.GREEN)
+ typer.echo(f" {config_file}")
+ typer.echo()
+ else:
+ # Interactive configuration
+ config = ConfigurationManager.create_interactive_config(
+ interactive_questions=interactive_questions, **cli_args
+ )
+
+ return config
+
+
+def _validate_configuration(config):
+ """Validate final configuration."""
+ from graid.cli import ConfigurationManager
+ ConfigurationManager.validate_configuration(config)
+
+
+def _report_success(dataset_dict, config):
+ """Report successful completion with summary."""
+ from graid.cli.validators import ArgumentValidator
+
+ requested_splits = ArgumentValidator.parse_and_validate_split(config.split)
+
+ # Success message
+ typer.echo()
+ typer.secho(
+ "β
Dataset generation completed successfully!",
+ fg=typer.colors.GREEN,
+ bold=True,
+ )
+
+ # Show summary
+ if len(requested_splits) == 1:
+ split_dataset = dataset_dict[requested_splits[0]]
+ typer.echo(f"π Generated {len(split_dataset)} question-answer pairs")
+ else:
+ counts = ", ".join(f"{s}={len(dataset_dict[s])}" for s in requested_splits)
+ typer.echo(f"π Generated per-split counts: {counts}")
+
+ if config.save_path:
+ typer.echo(f"πΎ Saved to: {config.save_path}")
+
+ if config.upload_to_hub:
+ typer.echo(f"π€ Uploaded to HuggingFace Hub: {config.hub_repo_id}")
+
+
@app.command("eval-vlms")
def eval_vlms(
db_path: Optional[str] = typer.Option(
@@ -413,6 +910,7 @@ def eval_vlms(
if list_vlms:
typer.secho("π€ Available VLM Types:", fg=typer.colors.BLUE, bold=True)
typer.echo()
+ # Local import to avoid heavy dependencies
from graid.evaluator.eval_vlms import VLM_CONFIGS
for vlm_type, config in VLM_CONFIGS.items():
typer.secho(f"{vlm_type}:", fg=typer.colors.GREEN, bold=True)
@@ -425,6 +923,7 @@ def eval_vlms(
if list_metrics:
typer.secho("π Available Metrics:", fg=typer.colors.BLUE, bold=True)
typer.echo()
+ # Local import to avoid heavy dependencies
from graid.evaluator.eval_vlms import METRIC_CONFIGS
for metric_type, config in METRIC_CONFIGS.items():
typer.secho(f"{metric_type}:", fg=typer.colors.GREEN, bold=True)
@@ -435,6 +934,7 @@ def eval_vlms(
if list_prompts:
typer.secho("π¬ Available Prompts:", fg=typer.colors.BLUE, bold=True)
typer.echo()
+ # Local import to avoid heavy dependencies
from graid.evaluator.eval_vlms import PROMPT_CONFIGS
for prompt_type, config in PROMPT_CONFIGS.items():
typer.secho(f"{prompt_type}:", fg=typer.colors.GREEN, bold=True)
@@ -446,7 +946,8 @@ def eval_vlms(
if interactive and not db_path:
typer.secho("π VLM Evaluation", fg=typer.colors.CYAN, bold=True)
typer.echo()
- typer.echo("This tool evaluates Vision Language Models using SQLite databases")
+ typer.echo(
+ "This tool evaluates Vision Language Models using SQLite databases")
typer.echo("containing questions and answers about images.")
typer.echo()
@@ -458,6 +959,7 @@ def eval_vlms(
raise typer.Exit(1)
# Check if model name is required
+ # Local import to avoid heavy dependencies
from graid.evaluator.eval_vlms import VLM_CONFIGS
vlm_config = VLM_CONFIGS.get(vlm)
if not vlm_config:
@@ -468,12 +970,14 @@ def eval_vlms(
raise typer.Exit(1)
if vlm_config["requires_model_selection"] and not model:
- typer.secho(f"Error: Model selection required for {vlm}.", fg=typer.colors.RED)
+ typer.secho(
+ f"Error: Model selection required for {vlm}.", fg=typer.colors.RED)
typer.echo(f"Available models: {', '.join(vlm_config['models'])}")
typer.echo("Use --model to specify a model.")
raise typer.Exit(1)
# Start evaluation
+ from graid.evaluator.eval_vlms import evaluate_vlm, METRIC_CONFIGS, PROMPT_CONFIGS, VLM_CONFIGS
typer.secho("π Starting VLM evaluation...", fg=typer.colors.BLUE, bold=True)
typer.echo()
typer.echo(f"Database: {db_path}")
@@ -506,9 +1010,38 @@ def eval_vlms(
)
typer.echo(f"Final accuracy: {accuracy:.4f}")
+ except KeyboardInterrupt:
+ typer.echo("\nβΉοΈ Evaluation cancelled by user")
+ raise typer.Exit(130) # Standard exit code for SIGINT
+ except FileNotFoundError as e:
+ typer.echo()
+ typer.secho(f"β File Not Found: {e}", fg=typer.colors.RED, bold=True)
+ typer.secho("Check that the database file exists and try again.", fg=typer.colors.YELLOW)
+ raise typer.Exit(1)
+ except PermissionError as e:
+ typer.echo()
+ typer.secho(f"β Permission Error: {e}", fg=typer.colors.RED, bold=True)
+ typer.secho("Check file permissions and try again.", fg=typer.colors.YELLOW)
+ raise typer.Exit(1)
+ except ValueError as e:
+ typer.echo()
+ typer.secho(f"β Invalid Parameter: {e}", fg=typer.colors.RED, bold=True)
+ typer.secho("Check your evaluation parameters and try again.", fg=typer.colors.YELLOW)
+ raise typer.Exit(1)
+ except ImportError as e:
+ typer.echo()
+ typer.secho(f"β Import Error: {e}", fg=typer.colors.RED, bold=True)
+ typer.secho("Check that VLM dependencies are installed.", fg=typer.colors.YELLOW)
+ raise typer.Exit(1)
except Exception as e:
typer.echo()
- typer.secho(f"β Error during evaluation: {e}", fg=typer.colors.RED, bold=True)
+ typer.secho(f"β Unexpected Error during evaluation: {e}", fg=typer.colors.RED, bold=True)
+ if os.getenv("GRAID_DEBUG_VERBOSE"):
+ import traceback
+ typer.echo("Detailed traceback:")
+ typer.echo(traceback.format_exc())
+ else:
+ typer.secho("Use GRAID_DEBUG_VERBOSE=1 for detailed error information.", fg=typer.colors.CYAN)
raise typer.Exit(1)
@@ -517,7 +1050,7 @@ def list_models():
"""List all available pre-configured models."""
typer.secho("π Available Models", fg=typer.colors.BLUE, bold=True)
typer.echo()
-
+ # Local import to avoid heavy dependencies
from graid.data.generate_db import list_available_models
models = list_available_models()
for backend, model_list in models.items():
@@ -527,6 +1060,32 @@ def list_models():
typer.echo()
+@app.command()
+def list_questions():
+ """List available questions with their parameters."""
+ typer.secho("π Available Questions:", fg=typer.colors.BLUE, bold=True)
+ typer.echo()
+ # Local import to avoid heavy dependencies
+ from graid.data.generate_dataset import list_available_questions
+ questions = list_available_questions()
+ for i, (name, info) in enumerate(questions.items(), 1):
+ typer.secho(f"{i:2d}. {name}", fg=typer.colors.GREEN, bold=True)
+ typer.echo(f" {info['question']}")
+ if info["parameters"]:
+ typer.echo(" Parameters:")
+ for param_name, param_info in info["parameters"].items():
+ typer.echo(
+ f" β’ {param_name}: {param_info['description']} (default: {param_info['default']})"
+ )
+ typer.echo()
+
+ typer.secho("π‘ Usage:", fg=typer.colors.YELLOW, bold=True)
+ typer.echo(
+ "Use --interactive-questions flag with generate-dataset for interactive selection"
+ )
+ typer.echo("Or configure questions in a config file")
+
+
@app.command()
def info():
"""Show information about GRAID and supported datasets/models."""
@@ -534,12 +1093,14 @@ def info():
from graid.data.generate_db import DATASET_TRANSFORMS, MODEL_CONFIGS
typer.secho("π Supported Datasets:", fg=typer.colors.BLUE, bold=True)
+ # Local import to avoid heavy dependencies
+ from graid.data.generate_db import DATASET_TRANSFORMS
for dataset in DATASET_TRANSFORMS.keys():
typer.echo(f" β’ {dataset.upper()}")
typer.echo()
typer.secho("π§ Supported Model Backends:", fg=typer.colors.BLUE, bold=True)
- for backend in MODEL_CONFIGS.keys():
+ for backend in ["detectron", "mmdetection", "ultralytics"]:
typer.echo(f" β’ {backend.upper()}")
typer.echo()
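As a usage note for the logging setup added in `graid.py`, the two environment variables it reads (`GRAID_DEBUG_VERBOSE`, `GRAID_LOG_DIR`) can be set before invoking the CLI; a small hypothetical sketch, assuming the `graid` entry point is installed:

```python
import os
import subprocess

# Hypothetical invocation: enable debug output on the console and redirect
# the timestamped log files to a custom directory. Both variable names come
# from _configure_logging above; the target command is just an example.
env = dict(os.environ, GRAID_DEBUG_VERBOSE="1", GRAID_LOG_DIR="/tmp/graid-logs")
subprocess.run(["graid", "generate-dataset", "--list-questions"], env=env, check=False)
```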
diff --git a/graid/src/graid/graid_cli.py b/graid/src/graid/graid_cli.py
deleted file mode 100644
index 51e42ce..0000000
--- a/graid/src/graid/graid_cli.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#!/usr/bin/env python3
-"""
-GRAID CLI Entry Point
-
-Run this script to launch the GRAID interactive interface.
-"""
-
-import sys
-from pathlib import Path
-
-# Add the project to Python path
-project_root = Path(__file__).parent.parent.parent.parent
-sys.path.insert(0, str(project_root))
-
-if __name__ == "__main__":
- try:
- # Add the src path to make imports work
- src_path = Path(__file__).parent
- sys.path.insert(0, str(src_path))
-
- from graid import app
-
- app()
- except ImportError as e:
- print(f"Error importing GRAID modules: {e}")
- print(
- "Make sure you're in the project root directory and all dependencies are installed."
- )
- print("Try: poetry install")
- sys.exit(1)
diff --git a/graid/src/graid/interfaces/DepthPerceptionI.py b/graid/src/graid/interfaces/DepthPerceptionI.py
index d965ed0..3baf363 100644
--- a/graid/src/graid/interfaces/DepthPerceptionI.py
+++ b/graid/src/graid/interfaces/DepthPerceptionI.py
@@ -5,7 +5,6 @@
import numpy as np
import PIL
import PIL.Image
-from matplotlib import pyplot as plt
from PIL.Image import Image
from torch import Tensor
@@ -59,6 +58,8 @@ def visualize_inverse_depth(dpr: DepthPerceptionResult) -> Image:
inverse_depth_normalized = (inverse_depth - min_invdepth_vizu) / (
max_invdepth_vizu - min_invdepth_vizu
)
+ # Local import to avoid loading matplotlib unless visualization is needed
+ from matplotlib import pyplot as plt
cmap = plt.get_cmap("turbo")
color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(np.uint8)
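The matplotlib import above is deferred into the visualization function so that importing the module stays cheap; a minimal sketch of the pattern, independent of GRAID (assumes matplotlib is installed):

```python
def visualize(values):
    # Deferred import: matplotlib is only loaded when visualization is
    # actually requested, keeping module import time low for CLI paths
    # that never plot anything.
    import matplotlib.pyplot as plt

    plt.plot(values)
    plt.show()


if __name__ == "__main__":
    visualize([1.0, 0.5, 0.25])
```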
diff --git a/graid/src/graid/interfaces/ObjectDetectionI.py b/graid/src/graid/interfaces/ObjectDetectionI.py
index 19f7bfd..893342b 100644
--- a/graid/src/graid/interfaces/ObjectDetectionI.py
+++ b/graid/src/graid/interfaces/ObjectDetectionI.py
@@ -14,6 +14,7 @@
from PIL import Image
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from ultralytics.engine.results import Boxes as UltralyticsBoxes
+import threading
class BBox_Format(Enum):
@@ -267,6 +268,8 @@ def get_area(self) -> torch.Tensor:
class ObjectDetectionUtils:
+ # Thread-local storage for per-image context
+ _ctx_local = threading.local()
@staticmethod
def pairwise_iou(
boxes1: ObjectDetectionResultI, boxes2: ObjectDetectionResultI
@@ -438,6 +441,130 @@ def compute_metrics_for_single_img(
# tn += 1
# return {"TN": tn}
+ @staticmethod
+ def normalize_detections(
+ detections: List[ObjectDetectionResultI],
+ bbox_format: BBox_Format = BBox_Format.XYXY,
+ ) -> Dict[str, Any]:
+ """
+ Normalize a mixed list of detection results into per-box, tensor-safe structures.
+
+ Returns a dictionary with:
+ - detections: List[ObjectDetectionResultI] (one per box)
+ - labels: List[str]
+ - bboxes_xyxy: torch.Tensor of shape (N, 4)
+ - bbox_list: List[Dict[str, float]] [{'x1','y1','x2','y2'}]
+ - counts: Dict[str, int] class β count
+ """
+ # Flatten into one detection per box
+ flattened: List[ObjectDetectionResultI] = []
+ for det in detections:
+ flattened.extend(det.flatten())
+
+ labels: List[str] = []
+ boxes_xyxy: List[torch.Tensor] = []
+ counts: Dict[str, int] = {}
+
+ for det in flattened:
+ # Label as string
+ lbl = det.label
+ lbl_str = str(lbl.item()) if isinstance(lbl, torch.Tensor) else str(lbl)
+ labels.append(lbl_str)
+ counts[lbl_str] = counts.get(lbl_str, 0) + 1
+
+ # Bbox in XYXY first 4 coords
+ if bbox_format == BBox_Format.XYXY:
+ xyxy = det.as_xyxy()
+ elif bbox_format == BBox_Format.XYWH:
+ xyxy = det.as_xywh()
+ elif bbox_format == BBox_Format.XYWHN:
+ xyxy = det.as_xywhn()
+ elif bbox_format == BBox_Format.XYXYN:
+ xyxy = det.as_xyxyn()
+ else:
+ # Default to xyxy
+ xyxy = det.as_xyxy()
+
+ # Ensure shape (4,) tensor for this single detection
+ if xyxy.dim() == 2:
+ # expected (1, 6) or (1, 4+) layout from UltralyticsBoxes
+ coords = xyxy[0][:4]
+ else:
+ coords = xyxy[:4]
+ boxes_xyxy.append(coords)
+
+ # Stack xyxy to (N, 4)
+ bxyxy = torch.stack(boxes_xyxy) if boxes_xyxy else torch.empty((0, 4), dtype=torch.float32)
+
+ # Generate list-of-dicts format commonly used when writing out
+ bbox_list: List[Dict[str, float]] = [
+ {"x1": float(b[0]), "y1": float(b[1]), "x2": float(b[2]), "y2": float(b[3])}
+ for b in bxyxy
+ ]
+
+ return {
+ "detections": flattened,
+ "labels": labels,
+ "bboxes_xyxy": bxyxy,
+ "bbox_list": bbox_list,
+ "counts": counts,
+ }
+
+ @staticmethod
+ def build_question_context(
+ image: Optional[Union[np.ndarray, torch.Tensor, Image.Image]],
+ detections: List[ObjectDetectionResultI],
+ ) -> Dict[str, Any]:
+ """Precompute per-image features for questions to avoid recomputation.
+
+ Returns a dictionary with:
+ - detections, labels, bboxes_xyxy, bbox_list, counts (from normalize_detections)
+ - centers: Tensor (N,2)
+ - areas: Tensor (N,)
+ - aspects: Tensor (N,) width/height
+ - class_to_indices: Dict[str, List[int]]
+ """
+ norm = ObjectDetectionUtils.normalize_detections(detections)
+ bxyxy: torch.Tensor = norm["bboxes_xyxy"]
+ if bxyxy.numel() > 0:
+ widths = (bxyxy[:, 2] - bxyxy[:, 0]).clamp(min=1.0)
+ heights = (bxyxy[:, 3] - bxyxy[:, 1]).clamp(min=1.0)
+ centers = torch.stack([(bxyxy[:, 0] + bxyxy[:, 2]) / 2.0, (bxyxy[:, 1] + bxyxy[:, 3]) / 2.0], dim=1)
+ areas = widths * heights
+ aspects = widths / heights
+ else:
+ centers = torch.empty((0, 2), dtype=torch.float32)
+ areas = torch.empty((0,), dtype=torch.float32)
+ aspects = torch.empty((0,), dtype=torch.float32)
+
+ class_to_indices: Dict[str, List[int]] = {}
+ for idx, lbl in enumerate(norm["labels"]):
+ class_to_indices.setdefault(lbl, []).append(idx)
+
+ ctx = {
+ **norm,
+ "centers": centers,
+ "areas": areas,
+ "aspects": aspects,
+ "class_to_indices": class_to_indices,
+ "image": image,
+ }
+ return ctx
+
+ @staticmethod
+ def set_current_context(ctx: Optional[Dict[str, Any]]) -> None:
+ ObjectDetectionUtils._ctx_local.value = ctx
+
+ @staticmethod
+ def get_current_context() -> Optional[Dict[str, Any]]:
+ return getattr(ObjectDetectionUtils._ctx_local, "value", None)
+
+ @staticmethod
+ def clear_current_context() -> None:
+ """Clear any previously set per-image QuestionContext."""
+ if hasattr(ObjectDetectionUtils._ctx_local, "value"):
+ ObjectDetectionUtils._ctx_local.value = None
+
@staticmethod
def show_image_with_detections(
image: Image.Image, detections: List[ObjectDetectionResultI]
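The per-image question context above is stored in `threading.local()` so each QA worker thread sees only its own value; a self-contained sketch of that pattern (names here are illustrative, not GRAID APIs):

```python
import threading

_ctx_local = threading.local()

def set_ctx(value):
    _ctx_local.value = value

def get_ctx():
    # getattr() avoids AttributeError in threads that never set a context
    return getattr(_ctx_local, "value", None)

def worker(image_id):
    set_ctx({"image_id": image_id})
    # Each thread reads back only the context it set itself.
    print(image_id, get_ctx())

threads = [threading.Thread(target=worker, args=(f"img{i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```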
diff --git a/graid/src/graid/models/DepthPro.py b/graid/src/graid/models/DepthPro.py
index 63cd086..66e4500 100644
--- a/graid/src/graid/models/DepthPro.py
+++ b/graid/src/graid/models/DepthPro.py
@@ -1,9 +1,9 @@
from itertools import islice
-from typing import Iterator, List, Union, override
+from typing import Iterator, List, Union
import depth_pro
import torch
-from PIL import Image, ImageSequence
+from PIL import Image
from graid.interfaces.DepthPerceptionI import (
DepthPerceptionI,
DepthPerceptionResult,
@@ -18,8 +18,8 @@ def __init__(self, **kwargs) -> None:
)
if not model_path.exists():
raise FileNotFoundError(
- f"Model path does not exist: {model_path}",
- "Please follow the project's readme to install all components.",
+ f"Model path does not exist: {model_path}. "
+ "Please follow the project's readme to install all components."
)
depth_pro.depth_pro.DEFAULT_MONODEPTH_CONFIG_DICT.checkpoint_uri = model_path
@@ -30,16 +30,34 @@ def __init__(self, **kwargs) -> None:
)
self.model.eval()
- self._prediction = None
- self._depth_prediction = None
- self._focallength_px = None
- self._depth_map = None
-
- @override
def predict_depth(self, image: Image.Image) -> DepthPerceptionResult:
- image, _, f_px = depth_pro.load_rgb(image)
- image = self.transform(image)
- prediction = self.model.infer(image, f_px=f_px)
+ """
+ Predict depth for a single image.
+
+ Args:
+ image: PIL Image to process
+
+ Returns:
+ DepthPerceptionResult containing depth prediction and focal length
+ """
+ # Convert PIL Image to numpy array for direct processing
+ # (bypassing depth_pro.load_rgb which expects file paths)
+ import numpy as np
+
+ # Convert to RGB if needed
+ if image.mode != 'RGB':
+ image = image.convert('RGB')
+
+ # Convert to numpy array
+ image_array = np.array(image)
+
+ # Apply transform directly to the image array
+ # depth_pro.load_rgb normally returns (image_array, icc_profile, f_px)
+ # We'll set f_px to None and let the model estimate it
+ image_tensor = self.transform(image_array)
+ f_px = None # Let the model estimate focal length
+
+ prediction = self.model.infer(image_tensor, f_px=f_px)
depth_prediction = prediction["depth"]
focallength_px = prediction["focallength_px"]
@@ -49,7 +67,6 @@ def predict_depth(self, image: Image.Image) -> DepthPerceptionResult:
)
return result
- @override
def predict_depths(
self,
video: Union[Iterator[Image.Image], List[Image.Image]],
@@ -57,93 +74,53 @@ def predict_depths(
) -> Iterator[DepthPerceptionResult]:
"""
Predicts the depth of each frame in the input video.
- Note: The video must be a list of PIL images
- In this way, we force the callers to do any preprocessing they need.
- For example, skipping frames to reduce computation time.
-
+
Args:
video: An iterator or list of PIL images
batch_size: The number of frames to predict in one forward pass
Yields:
- An iterator of batches of DepthPerceptionResult objects
+ Iterator of DepthPerceptionResult objects (one per frame)
"""
def _batch_iterator(iterable, n):
iterator = iter(iterable)
return iter(lambda: list(islice(iterator, n)), [])
- # If video is a list, convert it to an iterator of batches
- if isinstance(video, list):
- video_iterator = _batch_iterator(video, batch_size)
- else:
- # If video is already an iterator, create batches from it
- video_iterator = _batch_iterator(video, batch_size)
+ # Convert to batch iterator regardless of input type
+ video_iterator = _batch_iterator(video, batch_size)
for batch in video_iterator:
if not batch: # End of iterator
break
+
images, f_px_list = [], []
for img in batch:
- img, _, f_px = depth_pro.load_rgb(img)
- img = self.transform(img)
- images.append(img)
+ img_tensor, _, f_px = depth_pro.load_rgb(img)
+ img_tensor = self.transform(img_tensor)
+ images.append(img_tensor)
f_px_list.append(f_px)
- images = torch.stack(images)
- f_px_list = torch.stack(f_px_list)
+ images_batch = torch.stack(images)
+ f_px_batch = torch.stack(f_px_list)
- predictions = self.model.infer(
- images, f_px=f_px_list
- ) # tensor of shape (batch_size, 1, H, W)
- batch_results = []
+ predictions = self.model.infer(images_batch, f_px=f_px_batch)
+
+ # Extract individual results from batch
+ depth_batch = predictions["depth"] # shape: (batch_size, H, W)
+ focallength_batch = predictions["focallength_px"] # shape: (batch_size,)
- for j in range(predictions.shape[0]):
- depth_perception = DepthPerceptionI(
- image=batch[j],
- prediction=predictions[j],
- focallength_px=f_px_list[j],
+ for j in range(depth_batch.shape[0]):
+ result = DepthPerceptionResult(
+ depth_prediction=depth_batch[j],
+ focallength_px=focallength_batch[j],
)
- batch_results.append(depth_perception)
- yield batch_results
+ yield result
+ def to(self, device: Union[str, torch.device]) -> "DepthPro":
+ """Move model to specified device."""
+ self.device = device
+ self.model = self.model.to(device)
+ return self
-class DepthProV:
- # TODO: ImageSequence is the wrong type. Should be list of PIL images but requires
- # fixing the for loop as well
- def __init__(self, video: ImageSequence, batch_size: int, **kwargs) -> None:
- model_path = kwargs.get(
- "model_path", project_root_dir() / "checkpoints" / "depth_pro.pt"
- )
- depth_pro.depth_pro.DEFAULT_MONODEPTH_CONFIG_DICT.checkpoint_uri = model_path
- self.device = kwargs.get("device", get_default_device())
- self.model, self.transform = depth_pro.create_model_and_transforms(
- device=self.device
- )
- self.model.eval()
-
- self._depth_map: List[DepthPerceptionI] = []
-
- # split the video into batches
- for i in range(0, len(video), batch_size):
- batch = video[i : i + batch_size]
- images, f_px_list = [], []
- for img in batch:
- img, _, f_px = depth_pro.load_rgb(img)
- img = self.transform(img)
- images.append(img)
- f_px_list.append(f_px)
-
- images = torch.stack(images)
- f_px_list = torch.stack(f_px_list)
-
- predictions = self.model.infer(images, f_px=f_px_list)
-
- for j in range(predictions.shape[0]):
- depth_perception = DepthPerceptionI(
- image=batch[j],
- depth_prediction=predictions[j]["depth"],
- focallength_px=predictions[j]["focallength_px"],
- )
- self._depth_map.append(depth_perception)
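A hypothetical usage sketch for the reworked `predict_depth` (the image path is a placeholder, and the DepthPro checkpoint must already be installed under `checkpoints/depth_pro.pt` as the constructor expects):

```python
from PIL import Image
from graid.models.DepthPro import DepthPro

# Placeholder image path; predict_depth now accepts an in-memory PIL image
# directly instead of a file path.
model = DepthPro()
result = model.predict_depth(Image.open("example.jpg"))
print(result.depth_prediction.shape, result.focallength_px)
```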
diff --git a/graid/src/graid/models/Detectron.py b/graid/src/graid/models/Detectron.py
index e35c440..585e40c 100644
--- a/graid/src/graid/models/Detectron.py
+++ b/graid/src/graid/models/Detectron.py
@@ -1,19 +1,24 @@
+import os
+import tempfile
+import urllib.request
from itertools import islice
from pathlib import Path
-from typing import Iterator, List, Optional, Union
+from collections.abc import Iterator
+from typing import Optional, Union
import cv2
-import matplotlib.pyplot as plt
import numpy as np
import torch
from detectron2 import model_zoo
-from detectron2.config import get_cfg
+from detectron2.config import LazyConfig, get_cfg, instantiate
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.structures import BitMasks
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
from PIL import Image
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.data import transforms as T
from graid.interfaces.InstanceSegmentationI import (
InstanceSegmentationModelI,
@@ -26,7 +31,6 @@
ObjectDetectionResultI,
)
from graid.utilities.common import (
- convert_batch_to_numpy,
convert_image_to_numpy,
get_default_device,
)
@@ -34,6 +38,27 @@
setup_logger()
+def _resolve_cfg_file(path_or_url: str) -> str:
+ """Resolve config file path, downloading if it's a URL."""
+ if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
+ # Download to a temporary file
+ local_path = tempfile.NamedTemporaryFile(
+ delete=False, suffix=".py" if path_or_url.endswith(".py") else ".yaml"
+ ).name
+ try:
+ urllib.request.urlretrieve(path_or_url, local_path)
+ return local_path
+ except Exception as e:
+ raise RuntimeError(
+ f"Failed to download Detectron2 config from {path_or_url}: {e}"
+ )
+ elif os.path.isfile(path_or_url):
+ return path_or_url
+ else:
+ # Treat as model_zoo shorthand
+ return model_zoo.get_config_file(path_or_url)
+
+
class DetectronBase:
"""Base class for Detectron2 models with shared functionality."""
@@ -41,20 +66,101 @@ def __init__(
self,
config_file: str,
weights_file: str,
- threshold: float = 0.1,
+ threshold: float = 0.5,
device: Optional[Union[str, torch.device]] = None,
):
- # Input Detectron2 config file and weights file
- cfg = get_cfg()
- cfg.MODEL.DEVICE = str(get_default_device()) if device is None else str(device)
- cfg.merge_from_file(model_zoo.get_config_file(config_file))
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
- cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(weights_file)
+ # ------------------------------------------------------------------
+ # Input Detectron2 config & weights β support either:
+ # 1) Built-in model_zoo shorthand (e.g. "COCO-InstanceSegmentation/...yaml")
+ # 2) Local absolute/relative file path
+ # 3) Remote HTTP(S) URL (auto-download to a temp file)
+ # ------------------------------------------------------------------
+
+ cfg_path = _resolve_cfg_file(config_file)
+
+ # ---- setup config -----------------------------------------------
+ if cfg_path.endswith(".py"):
+ # Use LazyConfig for .py files
+ cfg = LazyConfig.load(cfg_path)
+ cfg.model.device = (
+ str(get_default_device()) if device is None else str(device)
+ )
+ cfg.model.roi_heads.box_predictor.test_score_thresh = threshold
+ else:
+ # Use traditional config for .yaml files
+ cfg = get_cfg()
+ # allow config files to introduce new keys (e.g. custom backbones)
+ if hasattr(cfg, "set_new_allowed"):
+ cfg.set_new_allowed(True)
+ else:
+ cfg.new_allowed = True
+ cfg.MODEL.DEVICE = (
+ str(get_default_device()) if device is None else str(device)
+ )
+ cfg.merge_from_file(cfg_path)
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
+
+ # ---- resolve weights ---------------------------------------------
+ if weights_file.startswith("http://") or weights_file.startswith("https://"):
+ if cfg_path.endswith(".py"):
+ cfg.model.weights = weights_file # LazyConfig
+ else:
+ cfg.MODEL.WEIGHTS = weights_file # traditional config
+ elif os.path.isfile(weights_file):
+ if cfg_path.endswith(".py"):
+ cfg.model.weights = weights_file # LazyConfig
+ else:
+ cfg.MODEL.WEIGHTS = weights_file # traditional config
+ else:
+ # treat as model_zoo shorthand (will raise if unavailable)
+ weights_url = model_zoo.get_checkpoint_url(weights_file)
+ if cfg_path.endswith(".py"):
+ cfg.model.weights = weights_url # LazyConfig
+ else:
+ cfg.MODEL.WEIGHTS = weights_url # traditional config
+
+ # ---- create predictor --------------------------------------------
+ if cfg_path.endswith(".py"):
+ # For LazyConfig, create a traditional config for DefaultPredictor
+ # DefaultPredictor expects a traditional CfgNode, not LazyConfig
+ traditional_cfg = get_cfg()
+ traditional_cfg.MODEL.DEVICE = (
+ str(get_default_device()) if device is None else str(device)
+ )
+ traditional_cfg.MODEL.WEIGHTS = cfg.model.weights
+ traditional_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
+ # Copy other essential config values
+ traditional_cfg.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
+ traditional_cfg.MODEL.BACKBONE.NAME = "RegNet"
+ traditional_cfg.MODEL.ROI_HEADS.NAME = "StandardROIHeads"
+ traditional_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 80 # COCO classes
+ traditional_cfg.INPUT.FORMAT = "BGR"
+ self._predictor = DefaultPredictor(traditional_cfg)
+ else:
+ self._predictor = DefaultPredictor(cfg)
+
+ # ---- metadata ----------------------------------------------------
+ if cfg_path.endswith(".py"):
+ # For LazyConfig, we need to handle metadata differently
+ try:
+ self._metadata = MetadataCatalog.get(
+ cfg.dataloader.train.dataset.names[0]
+ )
+ except (KeyError, IndexError, AttributeError) as e:
+ # Fallback to COCO metadata if dataset metadata not available
+ print(f"Warning: Could not get dataset metadata, using COCO fallback: {e}")
+ self._metadata = MetadataCatalog.get("coco_2017_train")
+ else:
+ self._metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
+
+ # Store config and other attributes
self.cfg = cfg
- self._predictor = DefaultPredictor(cfg)
- self._metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
self.model_name = config_file
- self.threshold = threshold # Store threshold for reference
+ self.threshold = threshold
+
+ # Store config for cleanup
+ self._cfg_path = cfg_path
+ self._config_file = config_file
def to(self, device: Union[str, torch.device]):
"""Move model to specified device."""
@@ -84,7 +190,7 @@ def __init__(
):
super().__init__(config_file, weights_file, threshold, device)
- def identify_for_image(self, image, **kwargs) -> List[ObjectDetectionResultI]:
+ def identify_for_image(self, image, **kwargs) -> list[ObjectDetectionResultI]:
"""
Run object detection on an image or a batch of images.
Args:
@@ -125,7 +231,7 @@ def identify_for_image(self, image, **kwargs) -> List[ObjectDetectionResultI]:
print(f"image should be HWC: {image.shape}")
return self._process_single_image(image)
- def _process_single_image(self, image: np.ndarray) -> List[ObjectDetectionResultI]:
+ def _process_single_image(self, image: np.ndarray) -> list[ObjectDetectionResultI]:
predictions = self._predictor(image)
if len(predictions) == 0:
@@ -164,7 +270,7 @@ def _process_single_image(self, image: np.ndarray) -> List[ObjectDetectionResult
def identify_for_image_batch(
self, batched_images, debug: bool = False, **kwargs
- ) -> List[ObjectDetectionResultI]:
+ ) -> list[ObjectDetectionResultI]:
assert (
batched_images.ndimension() == 4
), "Input tensor must be of shape (B, C, H, W) in RGB format"
@@ -217,9 +323,9 @@ def identify_for_image_batch(
def identify_for_video(
self,
- video: Union[Iterator[Image.Image], List[Image.Image]],
+ video: Union[Iterator[Image.Image], list[Image.Image]],
batch_size: int = 1,
- ) -> Iterator[List[List[ObjectDetectionResultI]]]:
+ ) -> Iterator[list[list[ObjectDetectionResultI]]]:
"""
Run object detection on a video represented as an iterator or list of images.
Args:
@@ -267,7 +373,7 @@ def identify_for_image(
],
debug: bool = False,
**kwargs,
- ) -> List[InstanceSegmentationResultI]:
+ ) -> list[InstanceSegmentationResultI]:
"""
Run instance segmentation on an image.
Args:
@@ -280,7 +386,7 @@ def identify_for_image(
def _process_single_image(
self, image: np.ndarray
- ) -> List[InstanceSegmentationResultI]:
+ ) -> list[InstanceSegmentationResultI]:
"""Process a single image for instance segmentation."""
predictions = self._predictor(image)
@@ -331,7 +437,7 @@ def identify_for_image_batch(
],
debug: bool = False,
**kwargs,
- ) -> List[List[InstanceSegmentationResultI]]:
+ ) -> list[list[InstanceSegmentationResultI]]:
"""
Run instance segmentation on a batch of images.
Args:
@@ -423,9 +529,9 @@ def identify_for_image_batch(
def identify_for_video(
self,
- video: Union[Iterator[Image.Image], List[Image.Image]],
+ video: Union[Iterator[Image.Image], list[Image.Image]],
batch_size: int = 1,
- ) -> Iterator[List[InstanceSegmentationResultI]]:
+ ) -> Iterator[list[InstanceSegmentationResultI]]:
"""
Run instance segmentation on a video represented as iterator/list of images.
Args:
@@ -452,6 +558,9 @@ def batch_iterator(iterable, n):
def visualize(self, image: Union[np.ndarray, torch.Tensor]):
"""Visualize segmentation results on an image."""
+ # Local import to avoid loading matplotlib unless visualization is needed
+ import matplotlib.pyplot as plt
+
image = convert_image_to_numpy(image)
outputs = self._predictor(image)
v = Visualizer(image[:, :, ::-1], self._metadata, scale=1.2)
@@ -459,3 +568,208 @@ def visualize(self, image: Union[np.ndarray, torch.Tensor]):
plt.figure(figsize=(14, 10))
plt.imshow(cv2.cvtColor(v.get_image()[:, :, ::-1], cv2.COLOR_BGR2RGB))
plt.show()
+
+
+class DetectronLazy(ObjectDetectionModelI):
+ """
+ Detectron2 model using a Python-based 'lazy' config file.
+ This is common for newer models like ViTDet.
+ """
+
+ def __init__(
+ self,
+ config_file: str,
+ weights_file: str,
+ threshold: float = 0.5,
+ device: Optional[Union[str, torch.device]] = None,
+ ):
+ self.device = device if device is not None else get_default_device()
+ self.threshold = threshold
+
+ # Load lazy config
+ cfg = LazyConfig.load(config_file)
+
+ # Set score threshold for Cascade R-CNN, which has multiple roi_heads
+ if hasattr(cfg.model, "roi_heads") and hasattr(cfg.model.roi_heads, "box_head"):
+ # It's a cascade model, iterate through stages
+ if isinstance(cfg.model.roi_heads.box_head, list):
+ for head in cfg.model.roi_heads.box_head:
+ if hasattr(head, "test_score_thresh"):
+ head.test_score_thresh = threshold
+ else: # It's a single head
+ if hasattr(cfg.model.roi_heads, "box_predictor"):
+ if hasattr(cfg.model.roi_heads.box_predictor, "test_score_thresh"):
+ cfg.model.roi_heads.box_predictor.test_score_thresh = threshold
+
+ # Build model
+ self.model = instantiate(cfg.model)
+ self.model.to(self.device)
+ self.model.eval()
+
+ # Load checkpoint
+ checkpointer = DetectionCheckpointer(self.model)
+ checkpointer.load(weights_file)
+
+ self.cfg = cfg
+ self.model_name = Path(config_file).stem
+
+ # Get preprocessing info from config, with defaults
+ self.short_edge_length = 800
+ self.max_size = 1333
+ try:
+ # This path might differ for other lazy configs
+ aug = cfg.dataloader.test.mapper.augmentations[0]
+ if aug.short_edge_length and aug.max_size:
+ self.short_edge_length = aug.short_edge_length
+ self.max_size = aug.max_size
+ except (AttributeError, IndexError, KeyError):
+ pass # Use defaults
+
+ # Get metadata for class names
+ try:
+ # This path might differ for other lazy configs
+ dataset_names = cfg.dataloader.test.dataset.names
+ self.metadata = MetadataCatalog.get(dataset_names)
+ except (AttributeError, IndexError, KeyError):
+ print("Warning: Could not find dataset metadata in config. Fallback to COCO.")
+ self.metadata = MetadataCatalog.get("coco_2017_train")
+
+ def to(self, device: Union[str, torch.device]):
+ """Move model to specified device."""
+ self.device = device
+ self.model.to(self.device)
+
+ def set_threshold(self, threshold: float):
+ """Set confidence threshold for detections."""
+ self.threshold = threshold
+ # Also update the running model config
+ if hasattr(self.cfg.model, "roi_heads") and hasattr(self.cfg.model.roi_heads, "box_head"):
+ if isinstance(self.cfg.model.roi_heads.box_head, list):
+ for head in self.cfg.model.roi_heads.box_head:
+ if hasattr(head, "test_score_thresh"):
+ head.test_score_thresh = threshold
+ else:
+ if hasattr(self.cfg.model.roi_heads, "box_predictor"):
+ if hasattr(self.cfg.model.roi_heads.box_predictor, "test_score_thresh"):
+ self.cfg.model.roi_heads.box_predictor.test_score_thresh = threshold
+
+ def __str__(self):
+ return self.model_name
+
+ def identify_for_image_batch(
+ self, batched_images, debug: bool = False, **kwargs
+ ) -> list[list[ObjectDetectionResultI]]:
+ assert (
+ batched_images.ndimension() == 4
+ ), "Input tensor must be of shape (B, C, H, W) in RGB format"
+
+ list_of_inputs = []
+ original_shapes = []
+
+ for i in range(batched_images.shape[0]):
+ image_tensor_chw_rgb = batched_images[i]
+
+ # Convert to numpy HWC RGB
+ image_np_hwc_rgb = image_tensor_chw_rgb.permute(1, 2, 0).cpu().numpy()
+
+ # Convert RGB to BGR for model
+ image_np_hwc_bgr = cv2.cvtColor(image_np_hwc_rgb, cv2.COLOR_RGB2BGR)
+
+ original_height, original_width = image_np_hwc_bgr.shape[:2]
+ original_shapes.append((original_height, original_width))
+
+ transform_gen = T.ResizeShortestEdge(
+ [self.short_edge_length, self.short_edge_length], self.max_size
+ )
+
+ transformed_image = transform_gen.get_transform(
+ image_np_hwc_bgr
+ ).apply_image(image_np_hwc_bgr)
+ transformed_image_tensor = torch.as_tensor(
+ transformed_image.astype("float32").transpose(2, 0, 1)
+ )
+
+ inputs = {
+ "image": transformed_image_tensor.to(self.device),
+ "height": original_height,
+ "width": original_width,
+ }
+ list_of_inputs.append(inputs)
+
+ with torch.no_grad():
+ predictions = self.model(list_of_inputs)
+
+ formatted_results = []
+ for i, prediction in enumerate(predictions):
+ img_result = []
+ instances = prediction["instances"]
+ image_hw = original_shapes[i]
+
+ if len(instances) > 0:
+ for j in range(len(instances)):
+ box = instances.pred_boxes[j].tensor.cpu().numpy().tolist()[0]
+ score = instances.scores[j].item()
+ cls_id = int(instances.pred_classes[j].item())
+ label = self.metadata.thing_classes[cls_id]
+
+ odr = ObjectDetectionResultI(
+ score=score,
+ cls=cls_id,
+ label=label,
+ bbox=box,
+ image_hw=image_hw,
+ bbox_format=BBox_Format.XYXY,
+ )
+ img_result.append(odr)
+
+ formatted_results.append(img_result)
+
+ return formatted_results
+
+ def identify_for_image(self, image, **kwargs) -> list[ObjectDetectionResultI]:
+ """Runs detection on a single image."""
+ numpy_image = convert_image_to_numpy(image) # This should be RGB HWC
+ # to tensor, CHW
+ image_tensor = torch.from_numpy(
+ np.ascontiguousarray(numpy_image.transpose(2, 0, 1))
+ )
+ if image_tensor.ndimension() == 3:
+ image_tensor = image_tensor.unsqueeze(0)
+
+ results_batch = self.identify_for_image_batch(image_tensor, **kwargs)
+ return results_batch[0] if results_batch else []
+
+ def identify_for_video(
+ self,
+ video: Union[Iterator[Image.Image], list[Image.Image]],
+ batch_size: int = 1,
+ ) -> Iterator[list[list[ObjectDetectionResultI]]]:
+ """
+ Run object detection on a video represented as an iterator or list of images.
+ Args:
+ video: An iterator or list of PIL images.
+ batch_size: Number of images to process at a time.
+ Returns:
+ An iterator of lists of lists of ObjectDetectionResultI, where the outer
+ list represents the batches, the middle list represents frames, and the
+ inner list represents detections within a frame.
+ """
+
+ def batch_iterator(iterable, n):
+ iterator = iter(iterable)
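+ # Two-argument iter(): keep calling the lambda until it returns the
+ # sentinel [], i.e. until the underlying iterator is exhausted.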
+ return iter(lambda: list(islice(iterator, n)), [])
+
+ video_iterator = batch_iterator(video, batch_size)
+
+ for batch in video_iterator:
+ if not batch: # End of iterator
+ break
+
+ # Convert all images in batch to numpy arrays
+ numpy_batch = [convert_image_to_numpy(img) for img in batch]
+
+ # Convert to a tensor (B, H, W, C) -> (B, C, H, W)
+ tensor_batch = torch.from_numpy(np.array(numpy_batch)).permute(0, 3, 1, 2)
+
+ batch_results = self.identify_for_image_batch(tensor_batch)
+ yield batch_results
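(For orientation: a minimal, hypothetical usage sketch of the `DetectronLazy` wrapper added above. The config and weights paths, and the dummy frame, are placeholder assumptions, not files or values provided by this change.)

```python
# Hypothetical usage of DetectronLazy; paths below are placeholders.
import numpy as np

from graid.models.Detectron import DetectronLazy

model = DetectronLazy(
    config_file="configs/vitdet/cascade_mask_rcnn_vitdet_h_75ep.py",  # assumed path
    weights_file="checkpoints/vitdet_h.pkl",                          # assumed path
    threshold=0.5,
)
model.to("cpu")  # or "cuda" if available

# identify_for_image accepts an RGB HWC image; a blank frame is used here.
image = np.zeros((720, 1280, 3), dtype=np.uint8)
for det in model.identify_for_image(image):
    print(det.label, det.as_xyxy())
```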
diff --git a/graid/src/graid/models/Ultralytics.py b/graid/src/graid/models/Ultralytics.py
index 4e7df7f..0876c9a 100644
--- a/graid/src/graid/models/Ultralytics.py
+++ b/graid/src/graid/models/Ultralytics.py
@@ -24,6 +24,7 @@ class Yolo(ObjectDetectionModelI):
def __init__(self, model: Union[str, Path]) -> None:
self.model_name = model
self._model = YOLO(model)
+ self.threshold = 0.25 # Default YOLO confidence threshold
def identify_for_image(
self,
@@ -64,7 +65,9 @@ def identify_for_image(
image = image[:, [2, 1, 0], ...]
image = image / 255.0
with torch.no_grad():
- predictions = self._model.predict(image, verbose, **kwargs)
+ predictions = self._model.predict(
+ image, verbose=verbose, conf=self.threshold, **kwargs
+ )
# undo the conversion
image = image[:, [2, 1, 0], ...]
image = image * 255.0
@@ -160,7 +163,7 @@ def _batch_iterator(iterable, n):
break
images = torch.stack([torch.tensor(np.array(img)) for img in batch])
- batch_results = self._model(images)
+ batch_results = self._model(images, conf=self.threshold)
boxes_across_frames = []
@@ -225,6 +228,7 @@ def __init__(self, model: Union[str, Path]) -> None:
super().__init__()
self._model = YOLO(model)
self._instance_count = {}
+ self.threshold = 0.25 # Default YOLO confidence threshold
def identify_for_image(
self,
@@ -416,3 +420,6 @@ def _batch_iterator(iterable, n):
def to(self, device: Union[str, torch.device]):
self._model.to(device)
+
+ def set_threshold(self, threshold: float):
+ self.threshold = threshold
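(A brief sketch of how the confidence-threshold plumbing added to the Ultralytics wrappers is intended to be used; the checkpoint name is illustrative.)

```python
# Illustrative use of the per-model confidence threshold introduced in this patch.
from graid.models.Ultralytics import Yolo

model = Yolo("yolov8x.pt")  # example checkpoint name
model.threshold = 0.4       # overrides the 0.25 default set in __init__
                            # (set_threshold() is also added further down in this file)

# Subsequent calls to identify_for_image / identify_for_video now forward
# conf=0.4 to Ultralytics' predict(), dropping lower-confidence boxes.
```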
diff --git a/graid/src/graid/models/WBF.py b/graid/src/graid/models/WBF.py
new file mode 100644
index 0000000..c68b97d
--- /dev/null
+++ b/graid/src/graid/models/WBF.py
@@ -0,0 +1,506 @@
+from collections.abc import Iterator
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from ensemble_boxes import weighted_boxes_fusion
+
+from graid.interfaces.ObjectDetectionI import (
+ BBox_Format,
+ ObjectDetectionModelI,
+ ObjectDetectionResultI,
+)
+from graid.models.Detectron import Detectron_obj
+from graid.models.MMDetection import MMdetection_obj
+from graid.models.Ultralytics import Yolo
+from graid.utilities.coco import coco_labels
+from graid.utilities.common import convert_image_to_numpy
+
+
+class WBF(ObjectDetectionModelI):
+ """Weighted Box Fusion ensemble across Detectron2, Ultralytics and MMDetection models."""
+
+ def __init__(
+ self,
+ detectron2_models: Optional[list["Detectron_obj"]] = None,
+ ultralytics_models: Optional[list["Yolo"]] = None,
+ mmdet_models: Optional[list["MMdetection_obj"]] = None,
+ model_weights: Optional[list[float]] = None,
+ iou_threshold: float = 0.55,
+ skip_box_threshold: float = 0.0,
+ ) -> None:
+ """Create a new Weighted Box Fusion ensemble.
+
+ Args:
+ detectron2_models: List of Detectron2 object detection models.
+ ultralytics_models: List of Ultralytics YOLO object detection models.
+ mmdet_models: List of MMDetection object detection models.
+ model_weights: Per-model weight for WBF (same ordering as the
+ concatenation of the three model lists). If
+ ``None``, all models get uniform weight.
+ iou_threshold: IoU threshold for box matching inside WBF.
+ skip_box_threshold: Boxes with *score < skip_box_threshold* will be
+ ignored by WBF.
+ """
+ super().__init__()
+
+ self.detectron2_models = detectron2_models or []
+ self.mmdet_models = mmdet_models or []
+ self.ultralytics_models = ultralytics_models or []
+
+ self._all_models: list[ObjectDetectionModelI] = (
+ self.detectron2_models + self.mmdet_models + self.ultralytics_models
+ ) # Flatten in a deterministic order so that weight list lines up
+
+ if model_weights is None:
+ self.model_weights = [1.0] * len(self._all_models)
+ else:
+ assert len(model_weights) == len(
+ self._all_models
+ ), "Length of model_weights must match total number of models."
+ self.model_weights = model_weights
+
+ self.iou_threshold = iou_threshold
+ self.skip_box_threshold = skip_box_threshold
+
+ self.model_name = "WBF_Ensemble"
+
+ # ---------------------------------------------------------------------
+ # Helper extraction functions
+ # ---------------------------------------------------------------------
+ @staticmethod
+ def _normalize_boxes_basic(
+ boxes: np.ndarray, image_hw: tuple[int, int]
+ ) -> np.ndarray:
+ """Convert absolute XYXY boxes to normalized format [0-1] without corrections."""
+ h, w = image_hw
+ # Ensure float32 for downstream operations
+ boxes_norm = boxes.copy().astype(np.float32)
+ boxes_norm[:, [0, 2]] /= w # x coords
+ boxes_norm[:, [1, 3]] /= h # y coords
+ # Clip to 0-1 just in case
+ boxes_norm = np.clip(boxes_norm, 0.0, 1.0)
+ return boxes_norm
+
+ @staticmethod
+ def _has_reversed_boxes(boxes: np.ndarray) -> bool:
+ """Check if boxes have reversed coordinates (x1 > x2 or y1 > y2)."""
+ if len(boxes) == 0:
+ return False
+ x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
+ return np.any(x2 < x1) or np.any(y2 < y1)
+
+ @staticmethod
+ def _fix_reversed_boxes(boxes: np.ndarray) -> np.ndarray:
+ """Fix reversed boxes by swapping coordinates where necessary."""
+ boxes_fixed = boxes.copy()
+
+ # Fix x coordinates where x1 > x2
+ swapped_x = boxes_fixed[:, 2] < boxes_fixed[:, 0]
+ if np.any(swapped_x):
+ # Swap x1 and x2 for reversed boxes
+ temp = boxes_fixed[swapped_x, 0].copy()
+ boxes_fixed[swapped_x, 0] = boxes_fixed[swapped_x, 2]
+ boxes_fixed[swapped_x, 2] = temp
+
+ # Fix y coordinates where y1 > y2
+ swapped_y = boxes_fixed[:, 3] < boxes_fixed[:, 1]
+ if np.any(swapped_y):
+ # Swap y1 and y2 for reversed boxes
+ temp = boxes_fixed[swapped_y, 1].copy()
+ boxes_fixed[swapped_y, 1] = boxes_fixed[swapped_y, 3]
+ boxes_fixed[swapped_y, 3] = temp
+
+ return boxes_fixed
+
+ def _normalize_boxes_detectron2(
+ self, boxes: np.ndarray, image_hw: tuple[int, int]
+ ) -> np.ndarray:
+ """Normalize boxes from Detectron2 models with targeted corrections."""
+ boxes_norm = self._normalize_boxes_basic(boxes, image_hw)
+
+ # Check if this model produces reversed boxes and fix if needed
+ if self._has_reversed_boxes(boxes_norm):
+ boxes_norm = self._fix_reversed_boxes(boxes_norm)
+
+ return boxes_norm
+
+ def _normalize_boxes_ultralytics(
+ self, boxes: np.ndarray, image_hw: tuple[int, int]
+ ) -> np.ndarray:
+ """Normalize boxes from Ultralytics models with targeted corrections."""
+ boxes_norm = self._normalize_boxes_basic(boxes, image_hw)
+
+ # Check if this model produces reversed boxes and fix if needed
+ if self._has_reversed_boxes(boxes_norm):
+ boxes_norm = self._fix_reversed_boxes(boxes_norm)
+
+ return boxes_norm
+
+ def _normalize_boxes_mmdet(
+ self, boxes: np.ndarray, image_hw: tuple[int, int]
+ ) -> np.ndarray:
+ """Normalize boxes from MMDetection models with targeted corrections."""
+ boxes_norm = self._normalize_boxes_basic(boxes, image_hw)
+
+ # Check if this model produces reversed boxes and fix if needed
+ if self._has_reversed_boxes(boxes_norm):
+ boxes_norm = self._fix_reversed_boxes(boxes_norm)
+
+ return boxes_norm
+
+ def _extract_detectron2_raw_predictions(
+ self, model: "Detectron_obj", image: np.ndarray
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Extract raw predictions (boxes, scores, classes) from a Detectron2 model.
+
+ NOTE: Detectron2 simplifies inference and returns post-NMS instances by
+ default. For a quick implementation we temporarily raise the NMS
+ threshold to 1.0, which effectively disables NMS while keeping the
+ existing pipeline intact.
+ """
+ # Backup original thresholds
+ orig_nms = model.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST
+ orig_score_thr = model.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST
+
+ # Disable NMS & lower score threshold to capture as many boxes as possible
+ model.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 1.0
+ model.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.001
+ # Re-create the predictor with the modified cfg (this rebuilds the model)
+ predictor = model._predictor.__class__(model.cfg)
+
+ outputs = predictor(image) # dict with key "instances"
+ instances = outputs.get("instances", None)
+
+ # Restore cfg (important for subsequent calls)
+ model.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = orig_nms
+ model.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = orig_score_thr
+ model._predictor = predictor.__class__(model.cfg) # revert predictor
+
+ if instances is None or len(instances) == 0:
+ return (np.empty((0, 4), dtype=np.float32), np.empty(0), np.empty(0))
+
+ boxes = instances.pred_boxes.tensor.cpu().numpy().astype(np.float32)
+ scores = instances.scores.cpu().numpy().astype(np.float32)
+ classes = instances.pred_classes.cpu().numpy().astype(int)
+ return boxes, scores, classes
+
+ def _extract_ultralytics_raw_predictions(
+ self, model: "Yolo", image: np.ndarray
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Extract raw predictions from Ultralytics YOLO before NMS."""
+ # Ensure the model predictor is initialised
+ model._model.predict(image, verbose=False) # warm-up call (does NMS)
+
+ # Pre-process like Ultralytics internal pipeline
+ im_tensor = model._model.predictor.preprocess(np.array(image)[np.newaxis, ...])
+ infer_out = model._model.predictor.inference(im_tensor)
+ # Ultralytics may return (preds, proto) or (preds, proto, loss). Handle both cases.
+ if isinstance(infer_out, tuple):
+ preds = infer_out[0]
+ else:
+ preds = infer_out
+
+ if preds is None or len(preds) == 0:
+ return (np.empty((0, 4), dtype=np.float32), np.empty(0), np.empty(0))
+
+ # `preds` has shape (batch, num_boxes, 6) OR (batch, 1, num_boxes, 6) depending on model.
+ pred0 = preds[0] # take first batch element
+ if pred0.ndim == 3 and pred0.shape[0] == 1:
+ # Some models return shape (1, num_boxes, 6)
+ pred0 = pred0[0]
+ elif pred0.ndim == 3 and pred0.shape[-1] == 6:
+ # shape (1, num_boxes, 6) or (num_levels, num_boxes, 6)
+ pred0 = pred0.reshape(-1, 6) # Flatten any leading dims
+
+ boxes = pred0[:, :4].cpu().numpy().astype(np.float32)
+ scores = pred0[:, 4].cpu().numpy().astype(np.float32)
+ classes = pred0[:, 5].cpu().numpy().astype(int)
+
+ return boxes, scores, classes
+
+ def _extract_mmdet_raw_predictions(
+ self, model: "MMdetection_obj", image: np.ndarray
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Extract raw predictions from MMDetection model prior to NMS."""
+ from mmdet.apis import (
+ inference_detector, # local import to avoid heavy dep if unused
+ )
+
+ cfg_test = model._model.cfg.model.test_cfg
+ original_nms = None
+ try:
+ # Two-stage detectors often keep NMS here
+ original_nms = cfg_test.rcnn.nms.iou_threshold # type: ignore
+ cfg_test.rcnn.nms.iou_threshold = 1.0 # type: ignore
+ except AttributeError:
+ # Single-stage / transformer models may store it directly under test_cfg
+ if hasattr(cfg_test, "nms") and hasattr(cfg_test.nms, "iou_threshold"):
+ original_nms = cfg_test.nms.iou_threshold # type: ignore
+ cfg_test.nms.iou_threshold = 1.0 # type: ignore
+
+ predictions = inference_detector(model._model, image)
+ pred = predictions[0] if isinstance(predictions, list) else predictions
+
+ if original_nms is not None:
+ try:
+ cfg_test.rcnn.nms.iou_threshold = original_nms # type: ignore
+ except AttributeError:
+ if hasattr(cfg_test, "nms"):
+ cfg_test.nms.iou_threshold = original_nms # type: ignore
+
+ instances = pred.pred_instances
+ boxes = instances.bboxes.cpu().numpy().astype(np.float32)
+ scores = instances.scores.cpu().numpy().astype(np.float32)
+ classes = instances.labels.cpu().numpy().astype(int)
+ return boxes, scores, classes
+
+ # ------------------------------------------------------------------
+ # Core ensemble routine
+ # ------------------------------------------------------------------
+ def _gather_all_predictions(
+ self, image: np.ndarray
+ ) -> tuple[list[list[float]], list[list[float]], list[list[int]]]:
+ """Collect raw predictions from every child model, in order."""
+ all_boxes: list[list[float]] = []
+ all_scores: list[list[float]] = []
+ all_labels: list[list[int]] = []
+
+ # Detectron2 models -------------------------------------------
+ for mdl in self.detectron2_models:
+ boxes, scores, classes = self._extract_detectron2_raw_predictions(
+ mdl, image
+ )
+ if len(boxes) > 0:
+ h, w = image.shape[:2]
+ normalized = self._normalize_boxes_detectron2(boxes, (h, w))
+ flat_boxes = [
+ [
+ (
+ float(coord[0])
+ if isinstance(coord, (list, tuple, np.ndarray))
+ else float(coord)
+ )
+ for coord in b
+ ]
+ for b in normalized.tolist()
+ ]
+ all_boxes.append(flat_boxes)
+ all_scores.append(
+ [
+ (
+ float(s[0])
+ if isinstance(s, (list, tuple, np.ndarray))
+ else float(s)
+ )
+ for s in scores.tolist()
+ ]
+ )
+ all_labels.append(
+ [
+ (
+ int(c[0])
+ if isinstance(c, (list, tuple, np.ndarray))
+ else int(c)
+ )
+ for c in classes.tolist()
+ ]
+ )
+
+ # Ultralytics models ------------------------------------------
+ for mdl in self.ultralytics_models:
+ boxes, scores, classes = self._extract_ultralytics_raw_predictions(
+ mdl, image
+ )
+ if len(boxes) > 0:
+ h, w = image.shape[:2]
+ normalized = self._normalize_boxes_ultralytics(boxes, (h, w))
+ flat_boxes = [
+ [
+ (
+ float(coord[0])
+ if isinstance(coord, (list, tuple, np.ndarray))
+ else float(coord)
+ )
+ for coord in b
+ ]
+ for b in normalized.tolist()
+ ]
+ all_boxes.append(flat_boxes)
+ all_scores.append(
+ [
+ (
+ float(s[0])
+ if isinstance(s, (list, tuple, np.ndarray))
+ else float(s)
+ )
+ for s in scores.tolist()
+ ]
+ )
+ all_labels.append(
+ [
+ (
+ int(c[0])
+ if isinstance(c, (list, tuple, np.ndarray))
+ else int(c)
+ )
+ for c in classes.tolist()
+ ]
+ )
+
+ # MMDetection models ------------------------------------------
+ for mdl in self.mmdet_models:
+ boxes, scores, classes = self._extract_mmdet_raw_predictions(mdl, image)
+ if len(boxes) > 0:
+ h, w = image.shape[:2]
+ normalized = self._normalize_boxes_mmdet(boxes, (h, w))
+ flat_boxes = [
+ [
+ (
+ float(coord[0])
+ if isinstance(coord, (list, tuple, np.ndarray))
+ else float(coord)
+ )
+ for coord in b
+ ]
+ for b in normalized.tolist()
+ ]
+ all_boxes.append(flat_boxes)
+ all_scores.append(
+ [
+ (
+ float(s[0])
+ if isinstance(s, (list, tuple, np.ndarray))
+ else float(s)
+ )
+ for s in scores.tolist()
+ ]
+ )
+ all_labels.append(
+ [
+ (
+ int(c[0])
+ if isinstance(c, (list, tuple, np.ndarray))
+ else int(c)
+ )
+ for c in classes.tolist()
+ ]
+ )
+
+ return all_boxes, all_scores, all_labels
+
+ def _fuse_predictions(
+ self, image: np.ndarray
+ ) -> dict[str, Union[np.ndarray, list[float], list[int]]]:
+ """Run Weighted Box Fusion on a single image and return fused detections."""
+ all_boxes, all_scores, all_labels = self._gather_all_predictions(image)
+
+ if not all_boxes:
+ return {
+ "boxes": np.empty((0, 4)),
+ "scores": np.empty(0),
+ "labels": np.empty(0, dtype=int),
+ }
+
+ fused_boxes, fused_scores, fused_labels = weighted_boxes_fusion(
+ all_boxes,
+ all_scores,
+ all_labels,
+ weights=self.model_weights,
+ iou_thr=self.iou_threshold,
+ skip_box_thr=self.skip_box_threshold,
+ )
+
+ # Convert back to pixel coordinates
+ h, w = image.shape[:2]
+ if len(fused_boxes) > 0:
+ fused_boxes[:, [0, 2]] *= w # x coords
+ fused_boxes[:, [1, 3]] *= h # y coords
+
+ return {
+ "boxes": fused_boxes,
+ "scores": fused_scores,
+ "labels": fused_labels.astype(int),
+ }
+
+ # ------------------------------------------------------------------
+ # Public API (ObjectDetectionModelI)
+ # ------------------------------------------------------------------
+ def identify_for_image(
+ self,
+ image: Union[np.ndarray, torch.Tensor],
+ debug: bool = False,
+ **kwargs,
+ ) -> list[ObjectDetectionResultI]:
+ image_np = convert_image_to_numpy(image)
+ fused = self._fuse_predictions(image_np)
+
+ boxes = fused["boxes"]
+ scores = fused["scores"]
+ labels = fused["labels"]
+
+ results: list[ObjectDetectionResultI] = []
+ for box, score, cls_id in zip(boxes, scores, labels):
+ results.append(
+ ObjectDetectionResultI(
+ score=float(score),
+ cls=int(cls_id),
+ label=coco_labels.get(int(cls_id), str(int(cls_id))),
+ bbox=box.tolist(),
+ image_hw=image_np.shape[:2],
+ bbox_format=BBox_Format.XYXY,
+ )
+ )
+
+ # Optional visualization is left to the caller or the debug flag
+ if debug and len(results) == 0:
+ print("[WBF] No detections for this image.")
+
+ return results
+
+ def identify_for_image_batch(
+ self,
+ image: Union[np.ndarray, torch.Tensor],
+ debug: bool = False,
+ **kwargs,
+ ) -> list[list[ObjectDetectionResultI]]:
+ if isinstance(image, torch.Tensor):
+ # Assume batch shape (B, C, H, W) or (C, H, W)
+ if image.ndimension() == 3:
+ return [self.identify_for_image(image, debug=debug)]
+ elif image.ndimension() == 4:
+ return [self.identify_for_image(img, debug=debug) for img in image]
+ else:
+ raise ValueError("Unsupported tensor shape for batch images")
+ elif isinstance(image, list):
+ return [self.identify_for_image(img, debug=debug) for img in image]
+ else:
+ # Single image numpy array
+ return [self.identify_for_image(image, debug=debug)]
+
+ def identify_for_video(
+ self,
+ video: Union[
+ Iterator[Union[np.ndarray, torch.Tensor]],
+ list[Union[np.ndarray, torch.Tensor]],
+ ],
+ batch_size: int = 1,
+ ) -> Iterator[list[Optional[ObjectDetectionResultI]]]:
+ # Simple implementation: iterate frame-by-frame (no batching)
+ for frame in video:
+ yield self.identify_for_image(frame)
+
+ # ------------------------------------------------------------------
+ # Utilities
+ # ------------------------------------------------------------------
+ def to(self, device: Union[str, torch.device]):
+ """Move underlying models to given device."""
+ for mdl in self._all_models:
+ mdl.to(device)
+
+ def set_threshold(self, threshold: float):
+ """Set skip_box_threshold (score threshold) for WBF."""
+ self.skip_box_threshold = threshold
+
+ def __str__(self):
+ return self.model_name
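(A minimal sketch of assembling the `WBF` ensemble defined above; the checkpoint names, weights, and dummy frame are illustrative assumptions.)

```python
# Hypothetical WBF ensemble assembly; checkpoints and weights are examples only.
import numpy as np

from graid.models.Ultralytics import Yolo
from graid.models.WBF import WBF

ensemble = WBF(
    ultralytics_models=[Yolo("yolov8x.pt"), Yolo("yolov9e.pt")],
    model_weights=[2.0, 1.0],    # first model counts twice as much during fusion
    iou_threshold=0.55,          # boxes overlapping at IoU >= 0.55 are merged
    skip_box_threshold=0.05,     # discard near-zero-confidence boxes before fusion
)

image = np.zeros((720, 1280, 3), dtype=np.uint8)  # RGB HWC frame
detections = ensemble.identify_for_image(image)
print(f"{len(detections)} fused detections")
```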
diff --git a/graid/src/graid/questions/ObjectDetectionQ.py b/graid/src/graid/questions/ObjectDetectionQ.py
index 4a38805..cb6f4d5 100644
--- a/graid/src/graid/questions/ObjectDetectionQ.py
+++ b/graid/src/graid/questions/ObjectDetectionQ.py
@@ -1,7 +1,9 @@
import logging
import math
+import random
+import numpy as np
from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Optional
import torch
from PIL import Image
@@ -17,16 +19,17 @@
class Question(ABC):
@abstractmethod
def __init__(
- self, question: str, variables: List[str], predicates: List[Callable]
+ self, question: str, variables: list[str], predicates: list[Callable[..., bool]]
) -> None:
self.question = question
self.variables = variables
self.predicates = predicates
+ self.other_question: Optional[str] = None
def is_applicable(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
+ detections: list[ObjectDetectionResultI],
) -> bool:
"""
Check if the question is applicable to the given image and detections.
@@ -46,134 +49,47 @@ def is_applicable(
def _find_extremes(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]]:
- # for every kind (label) of object in the image, find the right most detection
- # label -> (center of bbox (x, y), bounding box (x1, y1, x2, y2))
- right_most_detections: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
- # also the left most
- left_most_detections: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
- # also the top most
- top_most_detections: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
- # also the lowest
- bottom_most_detections: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
-
- for detection in detections:
- class_name = detection.label
- center_box = detection.get_center() # shape == (# of boxes, 2)
- bbox = detection.as_xyxy() # shape == (# of boxes, 4)
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # find the right most bbox using the center of the bbox
- n = class_name.shape[0]
- for i in range(n):
- curr_class_name = class_name[i]
- curr_center_box = center_box[i]
- curr_bbox = bbox[i]
-
- # right most
- if curr_class_name not in right_most_detections:
- right_most_detections[curr_class_name] = (
- curr_center_box,
- curr_bbox,
- )
- else:
- if (
- curr_center_box[0]
- > right_most_detections[curr_class_name][0]
- ):
- right_most_detections[curr_class_name] = (
- curr_center_box,
- curr_bbox,
- )
-
- # left most
- if curr_class_name not in left_most_detections:
- left_most_detections[curr_class_name] = (
- curr_center_box,
- curr_bbox,
- )
- else:
- if (
- curr_center_box[0]
- < left_most_detections[curr_class_name][0]
- ):
- left_most_detections[curr_class_name] = (
- curr_center_box,
- curr_bbox,
- )
-
- # top most
- if curr_class_name not in top_most_detections:
- top_most_detections[curr_class_name] = (
- curr_center_box,
- curr_bbox,
- )
- else:
- if curr_center_box[1] < top_most_detections[curr_class_name][1]:
- top_most_detections[curr_class_name] = (
- curr_center_box,
- curr_bbox,
- )
-
- # bottom most
- if curr_class_name not in bottom_most_detections:
- bottom_most_detections[curr_class_name] = (
- curr_center_box,
- curr_bbox,
- )
- else:
- if (
- curr_center_box[1]
- > bottom_most_detections[curr_class_name][1]
- ):
- bottom_most_detections[curr_class_name] = (
- curr_center_box,
- curr_bbox,
- )
-
- else: # type(class_name) == str
- # bbox would be shape (1, 4) so let's just grab the only element
- # right most
- if class_name not in right_most_detections:
- right_most_detections[class_name] = (center_box[0], bbox[0])
- else:
- if center_box[0][0] > right_most_detections[class_name][0][0]:
- right_most_detections[class_name] = (center_box[0], bbox[0])
-
- # left most
- if class_name not in left_most_detections:
- left_most_detections[class_name] = (center_box[0], bbox[0])
- else:
- if center_box[0][0] < left_most_detections[class_name][0][0]:
- left_most_detections[class_name] = (center_box[0], bbox[0])
-
- # top most
- if class_name not in top_most_detections:
- top_most_detections[class_name] = (center_box[0], bbox[0])
- else:
- if center_box[0][1] < top_most_detections[class_name][0][1]:
- top_most_detections[class_name] = (center_box[0], bbox[0])
-
- # bottom most
- if class_name not in bottom_most_detections:
- bottom_most_detections[class_name] = (center_box[0], bbox[0])
- else:
- if center_box[0][1] > bottom_most_detections[class_name][0][1]:
- bottom_most_detections[class_name] = (center_box[0], bbox[0])
-
- return [
- left_most_detections,
- right_most_detections,
- top_most_detections,
- bottom_most_detections,
- ]
+ detections: list[ObjectDetectionResultI],
+ ) -> list[dict[str, tuple[torch.Tensor, torch.Tensor]]]:
+ # Compute extremes using precomputed context
+ ctx = ObjectDetectionUtils.get_current_context()
+ if ctx is None:
+ ctx = ObjectDetectionUtils.build_question_context(image, detections)
+ ObjectDetectionUtils.set_current_context(ctx)
+
+ labels: list[str] = ctx.get("labels", [])
+ centers: torch.Tensor = ctx.get("centers", torch.empty((0, 2)))
+ bxyxy: torch.Tensor = ctx.get("bboxes_xyxy", torch.empty((0, 4)))
+
+ left_most: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
+ right_most: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
+ top_most: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
+ bottom_most: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
+
+ for idx, cls in enumerate(labels):
+ c = centers[idx]
+ bb = bxyxy[idx]
+ # Left-most (min x-center)
+ if cls not in left_most or c[0] < left_most[cls][0][0]:
+ left_most[cls] = (c, bb)
+ # Right-most (max x-center)
+ if cls not in right_most or c[0] > right_most[cls][0][0]:
+ right_most[cls] = (c, bb)
+ # Top-most (min y-center)
+ if cls not in top_most or c[1] < top_most[cls][0][1]:
+ top_most[cls] = (c, bb)
+ # Bottom-most (max y-center)
+ if cls not in bottom_most or c[1] > bottom_most[cls][0][1]:
+ bottom_most[cls] = (c, bb)
+
+ return [left_most, right_most, top_most, bottom_most]
@abstractmethod
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
"""
Apply the question to the image and detections.
@@ -204,196 +120,287 @@ def __repr__(self):
return representation
+ # Helper utilities to reduce duplication across questions
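+ # The shared question context ("ctx") produced by
+ # ObjectDetectionUtils.build_question_context is assumed to expose at least:
+ # "detections" (flattened), "labels", "centers", "bboxes_xyxy", "areas",
+ # "aspects", "counts", "class_to_indices" and a mutable "pred_cache" dict,
+ # as consumed by the helpers and Question subclasses below.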
+ def _get_ctx(self, image: Image.Image, detections: list[ObjectDetectionResultI]):
+ ctx = ObjectDetectionUtils.get_current_context()
+ if ctx is None:
+ # Build on-demand if missing (e.g., tests)
+ ctx = ObjectDetectionUtils.build_question_context(image, detections)
+ ObjectDetectionUtils.set_current_context(ctx)
+ return ctx
+
+ def iterate_labels(self, image: Image.Image, detections: list[ObjectDetectionResultI]):
+ """Yield (label_str, detection) over flattened detections."""
+ ctx = self._get_ctx(image, detections)
+ flat = ctx.get("detections", [])
+ labels = ctx.get("labels", [])
+ for det, lbl in zip(flat, labels):
+ yield lbl, det
+
+ def iterate_label_indices(self, image: Image.Image, detections: list[ObjectDetectionResultI]):
+ """Yield (idx, label_str, detection) aligned with context order."""
+ ctx = self._get_ctx(image, detections)
+ flat = ctx.get("detections", [])
+ labels = ctx.get("labels", [])
+ for idx, (det, lbl) in enumerate(zip(flat, labels)):
+ yield idx, lbl, det
+
+ def classes_with_single_detection(self, image: Image.Image, detections: list[ObjectDetectionResultI]) -> set[str]:
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ return {k for k, v in counts.items() if v == 1}
+
+ def sort_detections_by(self, image: Image.Image, detections: list[ObjectDetectionResultI], key: str, reverse: bool = False) -> list[int]:
+ """Sort detection indices by key: 'x1'|'x2'|'cx'|'cy'."""
+ ctx = self._get_ctx(image, detections)
+ bxyxy: torch.Tensor = ctx.get("bboxes_xyxy", torch.empty((0, 4)))
+ centers: torch.Tensor = ctx.get("centers", torch.empty((0, 2)))
+ if key == "x1":
+ metric = bxyxy[:, 0]
+ elif key == "x2":
+ metric = bxyxy[:, 2]
+ elif key == "cx":
+ metric = centers[:, 0]
+ elif key == "cy":
+ metric = centers[:, 1]
+ else:
+ raise ValueError(f"Unsupported sort key: {key}")
+ values = metric.cpu().tolist()
+ return sorted(range(len(values)), key=lambda i: values[i], reverse=reverse)
+
+ def per_class_reduce(self, image: Image.Image, detections: list[ObjectDetectionResultI], tensor_key: str, reduce: str = "max") -> dict[str, float]:
+ """Reduce per-class over a tensor from ctx: 'areas'|'aspects'."""
+ ctx = self._get_ctx(image, detections)
+ values: torch.Tensor = ctx.get(tensor_key, torch.empty((0,)))
+ class_to_indices: dict[str, list[int]] = ctx.get("class_to_indices", {})
+ result: dict[str, float] = {}
+ for cls, idxs in class_to_indices.items():
+ if not idxs:
+ continue
+ v = values[idxs]
+ if reduce == "max":
+ agg = float(v.max().item())
+ elif reduce == "min":
+ agg = float(v.min().item())
+ elif reduce == "sum":
+ agg = float(v.sum().item())
+ else:
+ raise ValueError(f"Unsupported reduce op: {reduce}")
+ result[cls] = agg
+ return result
+
class ObjectDetectionPredicates:
@staticmethod
def at_least_one_single_detection(
- image: Image, detections: List[ObjectDetectionResultI]
+ image: Image.Image, detections: list[ObjectDetectionResultI]
) -> bool:
- if len(detections) == 0:
- return False
- if len(detections) == 1:
- # if there is only one detection, we can just return True
- return True
-
- # check if there is at least one detection with a single class
- counts = {}
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for class_name in class_name:
- counts[class_name] = counts.get(class_name, 0) + 1
+ ctx = ObjectDetectionUtils.get_current_context()
+ if ctx is not None:
+ counts = ctx.get("counts", {})
+ if not counts and "detections" in ctx:
+ # No detections
+ return False
+ return any(c == 1 for c in counts.values()) or (len(ctx.get("detections", [])) == 1)
+
+ # Fallback without context
+ if len(detections) <= 1:
+ return len(detections) == 1
+ counts: dict[str, int] = {}
+ for det in detections:
+ lbl = det.label
+ if isinstance(lbl, torch.Tensor):
+ for l in lbl:
+ key = str(l.item())
+ counts[key] = counts.get(key, 0) + 1
else:
- counts[class_name] = counts.get(class_name, 0) + 1
-
- return any(count == 1 for count in counts.values())
+ key = str(lbl)
+ counts[key] = counts.get(key, 0) + 1
+ return any(c == 1 for c in counts.values())
@staticmethod
def at_least_x_many_class_detections(
- image: Image, detections: List[ObjectDetectionResultI], x: int
+ image: Image.Image, detections: list[ObjectDetectionResultI], x: int
) -> bool:
- counts = {}
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for single_class_name in class_name:
- counts[single_class_name] = counts.get(single_class_name, 0) + 1
+ ctx = ObjectDetectionUtils.get_current_context()
+ if ctx is not None:
+ counts = ctx.get("counts", {})
+ return len(counts) >= x
+
+ counts: dict[str, int] = {}
+ for det in detections:
+ lbl = det.label
+ if isinstance(lbl, torch.Tensor):
+ for l in lbl:
+ key = str(l.item())
+ counts[key] = counts.get(key, 0) + 1
else:
- counts[class_name] = counts.get(class_name, 0) + 1
-
+ key = str(lbl)
+ counts[key] = counts.get(key, 0) + 1
return len(counts) >= x
@staticmethod
def at_least_x_detections(
- image: Image, detections: List[ObjectDetectionResultI], x: int
- ) -> bool:
- return len(detections) >= 3
-
- @staticmethod
- def at_least_x_detections(
- image: Image, detections: List[ObjectDetectionResultI], x: int
+ image: Image.Image, detections: list[ObjectDetectionResultI], x: int
) -> bool:
- return len(detections) >= 3
+ from graid.interfaces.ObjectDetectionI import ObjectDetectionUtils
+ ctx = ObjectDetectionUtils.get_current_context()
+ if ctx is not None:
+ return len(ctx.get("detections", [])) >= x
+ return len(detections) >= x
@staticmethod
def exists_non_overlapping_detections(
- image: Image, detections: List[ObjectDetectionResultI]
+ image: Image.Image, detections: list[ObjectDetectionResultI]
) -> bool:
- for i, detection1 in enumerate(detections):
- for j in range(i + 1, len(detections)):
- detection2 = detections[j]
+ from graid.interfaces.ObjectDetectionI import ObjectDetectionUtils
+ ctx = ObjectDetectionUtils.get_current_context()
+ # Use cached computation
+ if ctx is not None:
+ cache = ctx.setdefault("pred_cache", {})
+ key = ("exists_non_overlapping_detections",)
+ if key in cache:
+ return cache[key]
+
+ # Work on flattened detections for clarity
+ flat: list[ObjectDetectionResultI] = ctx.get("detections", [])
+ # Group by label
+ by_label: dict[str, list[ObjectDetectionResultI]] = {}
+ for det in flat:
+ lbl = str(det.label) if not isinstance(det.label, torch.Tensor) else str(det.label.item())
+ by_label.setdefault(lbl, []).append(det)
+ labels = list(by_label.keys())
+ # Try pairs of different classes; early out on first non-overlap
+ for i in range(len(labels)):
+ for j in range(i + 1, len(labels)):
+ for d1 in by_label[labels[i]]:
+ for d2 in by_label[labels[j]]:
+ iou = ObjectDetectionUtils.pairwise_iou(d1, d2)
+ if iou.max() == 0:
+ cache[key] = True
+ return True
+ cache[key] = False
+ return False
- if detection1.label != detection2.label:
- iou: torch.Tensor = ObjectDetectionUtils.pairwise_iou(
- detection1, detection2
- )
+ # Fallback without context
+ for i, d1 in enumerate(detections):
+ for j in range(i + 1, len(detections)):
+ d2 = detections[j]
+ if str(d1.label) != str(d2.label):
+ iou = ObjectDetectionUtils.pairwise_iou(d1, d2)
if iou.max() == 0:
return True
-
return False
@staticmethod
def has_clusters(
- image: Image, detections: List[ObjectDetectionResultI], threshold=50
+ image: Image.Image, detections: list[ObjectDetectionResultI], threshold=50
) -> bool:
import numpy as np
- from scipy.spatial.distance import pdist, squareform
-
- # Get centers of all detections
- centers = []
- for detection in detections:
- bbox = detection.as_xyxy().squeeze(0)
- x_center = (bbox[0] + bbox[2]) / 2
- y_center = (bbox[1] + bbox[3]) / 2
- centers.append((x_center, y_center))
-
- centers = np.array(centers)
-
- # Compute pairwise distances
- dists = squareform(pdist(centers))
-
- # Simple clustering by distance threshold (e.g., 50 pixels)
- visited = set()
- clusters = []
+ ctx = ObjectDetectionUtils.get_current_context()
+ if ctx is not None:
+ cache = ctx.setdefault("pred_cache", {})
+ key = ("has_clusters", float(threshold))
+ if key in cache:
+ return cache[key]
+ centers_t: torch.Tensor = ctx.get("centers", torch.empty((0, 2)))
+ if centers_t.numel() == 0 or centers_t.shape[0] < 2:
+ cache[key] = False
+ return False
+ centers = centers_t.cpu().numpy()
+ # Simple O(n^2) proximity check; no heavy scipy
+ n = centers.shape[0]
+ clustered = False
+ for i in range(n):
+ for j in range(i + 1, n):
+ dx = centers[i, 0] - centers[j, 0]
+ dy = centers[i, 1] - centers[j, 1]
+ if (dx * dx + dy * dy) ** 0.5 < threshold:
+ clustered = True
+ break
+ if clustered:
+ break
+ cache[key] = clustered
+ return clustered
- for i in range(len(centers)):
- if i in visited:
- continue
- cluster = [i]
- visited.add(i)
- for j in range(len(centers)):
- if j not in visited and dists[i][j] < threshold:
- cluster.append(j)
- visited.add(j)
- if len(cluster) >= 2:
- clusters.append(cluster)
-
- if not clusters:
- return False
- else:
- return True
+ # Fallback: trivial no-cluster without context
+ return False
class IsObjectCentered(Question):
- def __init__(self) -> None:
+ def __init__(self, buffer_ratio: float = 0.05) -> None:
+ """Create an *Is-Object-Centered* question.
+
+ Args:
+ buffer_ratio: Fraction of the image width to treat as a no-ask buffer
+ around the one-third and two-thirds vertical lines. A value such as
+ ``0.05`` means 5% of the image width on either side of each grid
+ boundary is treated as *ambiguous*: if any side of the bounding
+ box falls in that zone, the question is skipped for that object.
+ """
super().__init__(
- question="Divide the image into thirds. Is the {object_1} centered in the image, or is it off to the left or right?",
+ question=(
+ "Divide the image into thirds. In which third does the "
+ "{object_1} primarily appear? Respond with the letter only: "
+ "A) left third, B) middle third, C) right third."
+ ),
variables=["object_1"],
predicates=[
ObjectDetectionPredicates.at_least_one_single_detection,
],
)
+ if buffer_ratio < 0 or buffer_ratio > 0.5:
+ raise ValueError(
+ "Buffer ratio provided does not make sense. Must be between 0 (no buffer) and 0.5 (half the image width)")
+ self.buffer_ratio: float = buffer_ratio
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_one_single_detection(image, detections) == True
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ bxyxy: torch.Tensor = ctx.get("bboxes_xyxy", torch.empty((0, 4)))
+ # classes with single instance
+ single_classes = {k for k, v in counts.items() if v == 1}
- # get all the classes that have only one detection
- detection_counts = {}
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for class_name in class_name:
- detection_counts[class_name] = (
- detection_counts.get(class_name, 0) + 1
- )
- else:
- detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
-
- single_detections = [
- class_name for class_name, count in detection_counts.items() if count == 1
- ]
-
- image_width, image_height = image.size
-
- object_positions = []
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for single_class_name in class_name:
- if single_class_name in single_detections:
- object_positions.append(
- (
- single_class_name,
- detection.as_xyxy()[0][0],
- detection.as_xyxy()[0][2],
- )
- )
- else:
- if class_name in single_detections:
- object_positions.append(
- (
- class_name,
- detection.as_xyxy()[0][0],
- detection.as_xyxy()[0][2],
- )
- )
+ image_width, _ = image.size
question_answer_pairs = []
- for class_name, x_min, x_max in object_positions:
+ for idx, lbl, _det in self.iterate_label_indices(image, detections):
+ if lbl not in single_classes:
+ continue
+ x_min = float(bxyxy[idx, 0])
+ x_max = float(bxyxy[idx, 2])
+ class_name = lbl
question = self.question.format(object_1=class_name)
- # TODO: verify this design decision manually
- # edge case: if the object is big enough to cover more than 1/3rd
- # then it's ambiguous so we will not answer
- if x_min < image_width / 3 and x_max < image_width / 3:
- answer = "left"
- elif x_min > image_width / 3 and x_max < 2 * image_width / 3:
- answer = "centered"
- elif x_min > 2 * image_width / 3 and x_max > 2 * image_width / 3:
- answer = "right"
+ left_line = image_width / 3
+ right_line = 2 * image_width / 3
+ buffer = self.buffer_ratio * image_width
+
+ # Discard if bbox is too close to a boundary (ambiguous)
+ if (
+ abs(x_min - left_line) < buffer
+ or abs(x_max - left_line) < buffer
+ or abs(x_min - right_line) < buffer
+ or abs(x_max - right_line) < buffer
+ ):
+ logger.debug("IsObjectCentered skipped due to ambiguity buffer")
+ continue
+
+ # Determine third based on buffered grid
+ if x_max < left_line - buffer:
+ answer = "A"
+ elif x_min > left_line + buffer and x_max < right_line - buffer:
+ answer = "B"
+ elif x_min > right_line + buffer:
+ answer = "C"
else:
- # object is too big to be centered so skip
- logger.debug(
- "Object is too big to be left, right or centered. Skipping question."
- )
+ # Large object spans multiple thirds, so the answer is ambiguous
continue
question_answer_pairs.append((question, answer))
@@ -401,8 +408,11 @@ def apply(
class WidthVsHeight(Question):
- # TODO: try a bunch of different thresholds for width vs height
- def __init__(self, threshold: float = 0.30) -> None:
+ def __init__(
+ self,
+ threshold: float = 0.75,
+ non_articulated_classes: Optional[list[str]] = None,
+ ) -> None:
super().__init__(
question="Is the width of the {object_1} appear to be larger than the height?",
variables=["object_1"],
@@ -410,97 +420,79 @@ def __init__(self, threshold: float = 0.30) -> None:
ObjectDetectionPredicates.at_least_one_single_detection,
],
)
- self.threshold = threshold
- self.other_question = "Is the height of the {object_1} larger than the width?"
+ # ask recall. if object is detected, then ask for unique description
+ if non_articulated_classes is not None and len(non_articulated_classes) == 0:
+ raise ValueError(
+ "non_articulated_classes must be a non-empty list of class names"
+ )
+ self.non_articulated_classes: Optional[list[str]] = non_articulated_classes
+ self.threshold: float = threshold
+ self.other_question: Optional[str] = (
+ "Is the height of the {object_1} larger than the width?"
+ )
def __repr__(self):
return f"Question: {self.question} (threshold: {self.threshold})"
- def _question_answer(
- self, class_name: str, detection: ObjectDetectionResultI, reverse: bool = False
- ) -> Optional[Tuple[str, str]]:
- width = detection.as_xywh().squeeze()[2].item()
- height = detection.as_xywh().squeeze()[3].item()
- # TODO: should we check for a minimum width or height?
- if abs(width - height) / width < self.threshold:
- logger.debug(
- "Width and height are roughly equal (within threshold) so can't ask WidthVsHeight"
- )
+ def _question_answer_ratio(
+ self, class_name: str, ratio_wh: float, reverse: bool = False
+ ) -> Optional[tuple[str, str]]:
+ # Skip if near-square within threshold band
+ if abs(ratio_wh - 1.0) < self.threshold:
return None
-
- if width > height:
- answer = "yes"
- other_answer = "no"
- else:
- answer = "no"
- other_answer = "yes"
-
+ answer = "yes" if ratio_wh > 1.0 else "no"
if reverse:
question = self.other_question.format(object_1=class_name)
- answer = other_answer
+ answer = "no" if answer == "yes" else "yes"
else:
question = self.question.format(object_1=class_name)
-
return (question, answer)
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
+ detections: list[ObjectDetectionResultI],
reverse: bool = False,
- ) -> List[Tuple[str, str]]:
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_one_single_detection(image, detections) == True
-
- # get all the classes that have only one detection
- detection_counts = {}
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for single_class_name in class_name:
- detection_counts[single_class_name] = (
- detection_counts.get(single_class_name, 0) + 1
- )
- else:
- detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
-
- single_detections = [
- class_name for class_name, count in detection_counts.items() if count == 1
- ]
-
- question_answer_pairs = []
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for single_class_name in class_name:
- if single_class_name in single_detections:
- question_answer_pair = self._question_answer(
- single_class_name, detection, reverse=reverse
- )
- if question_answer_pair is not None:
- question_answer_pairs.append(question_answer_pair)
- else:
- if class_name in single_detections:
- question_answer_pair = self._question_answer(
- class_name, detection, reverse=reverse
- )
- if question_answer_pair is not None:
- question_answer_pairs.append(question_answer_pair)
-
- return question_answer_pairs
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ aspects: torch.Tensor = ctx.get("aspects", torch.empty((0,)))
+ qa: list[tuple[str, str]] = []
+ for idx, lbl, _ in self.iterate_label_indices(image, detections):
+ if counts.get(lbl, 0) != 1:
+ continue
+ if self.non_articulated_classes is not None and lbl not in self.non_articulated_classes:
+ continue
+ ratio = float(aspects[idx]) if aspects.numel() > idx else None
+ if ratio is None:
+ # Fallback using bbox
+ bboxes_xyxy = ctx.get("bboxes_xyxy")
+ if bboxes_xyxy is not None:
+ b = bboxes_xyxy[idx]
+ w = float(b[2] - b[0])
+ h = float(b[3] - b[1])
+ ratio = w / max(h, 1e-6)
+ else:
+ continue # Skip if no bbox data available
+ qa_pair = self._question_answer_ratio(lbl, ratio, reverse=reverse)
+ if qa_pair is not None:
+ qa.append(qa_pair)
+ return qa
class Quadrants(Question):
- def __init__(self, N: int, M: int) -> None:
+ def __init__(self, N: int, M: int, margin_ratio: float = 0.1) -> None:
if N <= 0 or M <= 0:
raise ValueError("N and M must be positive integers")
- # TODO: verify this design decision manually
- # we will support at most a 3x3 grid
- if N * M > 9:
- raise ValueError("N * M must be less than or equal to 9")
- self.rows = N
- self.cols = M
+ if N * M > 12:
+ raise ValueError("N * M must be less than or equal to 12")
+ if margin_ratio < 0 or margin_ratio > 0.5:
+ raise ValueError(
+ "Margin ratio must be between 0 (no margin) and 0.5 (half the quadrant width/height)")
+ self.rows: int = N
+ self.cols: int = M
+ self.margin_ratio: float = margin_ratio
super().__init__(
question="Divide the image into a {N} x {M} grid. Number the quadrants from left to right, top to bottom, starting with 1. In what quadrant does the {object_1} appear?",
variables=["object_1", "N", "M"],
@@ -511,7 +503,7 @@ def __init__(self, N: int, M: int) -> None:
def _question_answer(
self, image: Image.Image, class_name: str, detection: ObjectDetectionResultI
- ) -> Optional[Tuple[str, str]]:
+ ) -> Optional[tuple[str, str]]:
x_min, y_min, x_max, y_max = detection.as_xyxy()[0]
detection_width = x_max - x_min
detection_height = y_max - y_min
@@ -521,8 +513,14 @@ def _question_answer(
quadrant_width = image_width / self.cols
quadrant_height = image_height / self.rows
+ # Margin inside each quadrant that the bbox must fully respect
+ margin_x = self.margin_ratio * quadrant_width
+ margin_y = self.margin_ratio * quadrant_height
+
+ # Require bbox to fit wholly inside a quadrant with the margin buffer
if not (
- detection_width < quadrant_width and detection_height < quadrant_height
+ detection_width < quadrant_width - 2 * margin_x
+ and detection_height < quadrant_height - 2 * margin_y
):
return None
@@ -536,6 +534,17 @@ def _question_answer(
if col != math.floor(x_max / quadrant_width):
logger.debug("Object spans multiple columns")
return None
+
+ # Ensure bbox respects margin inside the identified quadrant
+ if not (
+ x_min >= col * quadrant_width + margin_x
+ and x_max <= (col + 1) * quadrant_width - margin_x
+ and y_min >= row * quadrant_height + margin_y
+ and y_max <= (row + 1) * quadrant_height - margin_y
+ ):
+ logger.debug("Quadrants skipped due to margin ambiguity")
+ return None
+
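+ # Example: in a 2 x 3 grid (rows=2, cols=3), an object confined to row 1,
+ # col 2 (0-indexed) is reported as quadrant 1 * 3 + 2 + 1 = 6 (bottom-right).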
quadrant = row * self.cols + col + 1
question = self.question.format(
@@ -549,46 +558,20 @@ def _question_answer(
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_one_single_detection(image, detections) == True
-
- # get all the classes that have only one detection
- detection_counts = {}
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for single_class_name in class_name:
- detection_counts[single_class_name] = (
- detection_counts.get(single_class_name, 0) + 1
- )
- else:
- detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
-
- single_detections = [
- class_name for class_name, count in detection_counts.items() if count == 1
- ]
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ single_classes = {k for k, v in counts.items() if v == 1}
question_answer_pairs = []
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for single_class_name in class_name:
- if single_class_name in single_detections:
- question_answer_pair = self._question_answer(
- image, single_class_name, detection
- )
- if question_answer_pair is not None:
- question_answer_pairs.append(question_answer_pair)
- else:
- if class_name in single_detections:
- question_answer_pair = self._question_answer(
- image, class_name, detection
- )
- if question_answer_pair is not None:
- question_answer_pairs.append(question_answer_pair)
+ for _, lbl, det in self.iterate_label_indices(image, detections):
+ if lbl not in single_classes:
+ continue
+ qa = self._question_answer(image, lbl, det)
+ if qa is not None:
+ question_answer_pairs.append(qa)
return question_answer_pairs
@@ -604,6 +587,7 @@ def __init__(self, threshold: float = 0.3) -> None:
),
],
)
+ # in the R.O.S. verifier, black out every single box then ask
self.threshold = threshold
def __repr__(self):
@@ -612,39 +596,103 @@ def __repr__(self):
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_x_many_class_detections(image, detections, 2) == True
-
- if len(detections) == 0:
- logger.debug("No detections given to LargestAppearance question")
+ ctx = self._get_ctx(image, detections)
+ areas: torch.Tensor = ctx.get("areas", torch.empty((0,)))
+ labels: list[str] = ctx.get("labels", [])
+ if areas.numel() == 0:
+ return []
+ order = torch.argsort(areas, descending=True)
+ if order.numel() < 2:
return []
+ i0 = int(order[0].item())
+ i1 = int(order[1].item())
+ if float(areas[i0]) <= (1 + self.threshold) * float(areas[i1]):
+ return []
+ return [(self.question, str(labels[i0]))]
- # TODO: verify if this works
- # the same logic should apply here regardless of detections being a tensor or not
- areas = [detection.get_area() for detection in detections]
- largest_detection = detections[torch.argmax(torch.stack(areas))]
- second_largest_detection = detections[
- torch.argsort(torch.stack(areas).squeeze())[-2]
- ]
- # check if the largest detection is at least 30% larger than the second largest
- if not (
- largest_detection.get_area().item()
- > (1 + self.threshold) * second_largest_detection.get_area().item()
- ):
- logger.debug(
- f"Largest detection is not at least {self.threshold:.2%} larger than the second largest"
- )
+class RankLargestK(Question):
+ """Rank the *k* object classes that have the largest single-instance area.
+
+ Example question (for k=3):
+
+ "Rank the 3 kinds of objects that appear the largest in the image from
+ largest to smallest. Provide your answer as a comma-separated list of
+ object names only."
+ """
+
+ def __init__(self, k: int, margin_ratio: float = 0.3) -> None:
+ """Create a RankLargestK question.
+
+ Args:
+ k: number of classes to rank.
+ margin_ratio: required multiplicative margin between consecutive
+ ranked areas. For class *i* to be considered larger than class
+ *i+1*, its area must be at least ``(1 + margin_ratio)`` times
+ larger. If any consecutive pair fails this criterion, the
+ question will be skipped for that image.
+ """
+ if k <= 0:
+ raise ValueError("k must be a positive integer")
+ if margin_ratio < 0:
+ raise ValueError("margin_ratio must be non-negative")
+
+ self.k: int = k
+ self.margin_ratio: float = margin_ratio
+ super().__init__(
+ question=(
+ "Rank the {k} kinds of objects that appear the largest (by pixel area) in the "
+ "image from largest to smallest. Provide your answer as a "
+ "comma-separated list of object names only."
+ ),
+ variables=["k"],
+ predicates=[
+ # Need at least k different classes detected
+ lambda image, detections, k=k: ObjectDetectionPredicates.at_least_x_many_class_detections(
+ image, detections, k
+ ),
+ ],
+ )
+
+ def apply(
+ self,
+ image: Image.Image,
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+ ctx = self._get_ctx(image, detections)
+ # per-class max of areas
+ class_max_area = self.per_class_reduce(image, detections, tensor_key="areas", reduce="max")
+ if len(class_max_area) < self.k:
+ logger.debug("Not enough unique classes for RankLargestK question")
return []
- question = self.question
- answer = str(largest_detection.label)
+ # Sort classes by their largest instance area
+ sorted_classes = sorted(
+ class_max_area.items(), key=lambda item: item[1], reverse=True
+ )
+
+ # Verify margin criterion among top-k areas
+ top_k = sorted_classes[: self.k]
+ for i in range(len(top_k) - 1):
+ area_i = top_k[i][1]
+ area_next = top_k[i + 1][1]
+ if area_i < (1 + self.margin_ratio) * area_next:
+ logger.debug(
+ "RankLargestK margin threshold not met between %s and %s", top_k[i][0], top_k[i + 1][0])
+ return []
+
+ top_k_labels = [cls for cls, _ in top_k]
+
+ question = self.question.format(k=self.k)
+ answer = ", ".join(map(str, top_k_labels))
return [(question, answer)]
class MostAppearance(Question):
- def __init__(self) -> None:
+ def __init__(self, margin_ratio: float = 0.2) -> None:
super().__init__(
question="What kind of object appears the most frequently in the image?",
variables=[],
@@ -654,47 +702,32 @@ def __init__(self) -> None:
),
],
)
+ if margin_ratio < 0 or margin_ratio >= 1:
+ raise ValueError(
+ "The margin ratio between the classes that appear most frequently must be non-negative and less than 1")
+ self.margin_ratio: float = margin_ratio
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_x_many_class_detections(image, detections, 2) == True
-
- if len(detections) == 0:
- logger.debug("No detections given to MostAppearance question")
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ if len(counts) < 2:
return []
-
- detections_counts = {}
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for single_class_name in class_name:
- detections_counts[single_class_name] = (
- detections_counts.get(single_class_name, 0) + 1
- )
- else:
- detections_counts[class_name] = detections_counts.get(class_name, 0) + 1
-
- sorted_detections = sorted(
- detections_counts.items(), key=lambda x: x[1], reverse=True
- )
- if sorted_detections[0][1] == sorted_detections[1][1]:
- # we will not handle ties so better not to answer
- logger.debug("Tie in MostAppearance question")
+ sorted_counts = sorted(counts.items(), key=lambda x: x[1], reverse=True)
+ top_count = sorted_counts[0][1]
+ second_count = sorted_counts[1][1]
+ if top_count < (1 + self.margin_ratio) * second_count:
return []
-
- most_detections = sorted_detections[0][0]
-
- question = self.question
- answer = str(most_detections)
- return [(question, answer)]
+ most = sorted_counts[0][0]
+ return [(self.question, str(most))]
class LeastAppearance(Question):
- def __init__(self) -> None:
+ def __init__(self, margin_ratio: float = 0.2) -> None:
super().__init__(
question="What kind of object appears the least frequently in the image?",
variables=[],
@@ -704,40 +737,26 @@ def __init__(self) -> None:
),
],
)
+ if margin_ratio < 0 or margin_ratio >= 1:
+ raise ValueError(
+ "margin_ratio must be non-negative and less than 1")
+ self.margin_ratio: float = margin_ratio
def apply(
- self, image: Image.Image, detections: List[ObjectDetectionResultI]
- ) -> List[Tuple[str, str]]:
+ self, image: Image.Image, detections: list[ObjectDetectionResultI]
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_x_many_class_detections(image, detections, 2) == True
-
- if len(detections) == 0:
- logger.debug("No detections given to LeastAppearance question")
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ if len(counts) < 2:
return []
-
- detections_counts = {}
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for single_class_name in class_name:
- detections_counts[single_class_name] = (
- detections_counts.get(single_class_name, 0) + 1
- )
- else:
- detections_counts[class_name] = detections_counts.get(class_name, 0) + 1
-
- sorted_detections = sorted(detections_counts.items(), key=lambda x: x[1])
-
- if sorted_detections[0][1] == sorted_detections[1][1]:
- # we will not handle ties so better not to answer
- logger.debug("Tie in LeastAppearance question")
+ sorted_counts = sorted(counts.items(), key=lambda x: x[1])
+ lowest = sorted_counts[0][1]
+ second_lowest = sorted_counts[1][1]
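+ # Require a clear minimum: the runner-up count must be at least (1 + margin_ratio) times the least-frequent count (1.2x for the default margin_ratio=0.2).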
+ if second_lowest < (1 + self.margin_ratio) * lowest:
return []
-
- least_detections = sorted_detections[0][0]
-
- question = self.question
- answer = str(least_detections)
- return [(question, answer)]
+ least = sorted_counts[0][0]
+ return [(self.question, str(least))]
class LeftOf(Question):
@@ -756,49 +775,40 @@ def __init__(self) -> None:
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_x_many_class_detections(image, detections, 2) == True
# @precondition: exists_non_overlapping_detections(image, detections) == True
-
- left_most_detections, right_most_detections, _, _ = self._find_extremes(
- image, detections
- )
-
- # iterate over the right most detections and check if there is a different class
- # that is to the left and non-overlapping of the instances we found above
- question_answer_pairs = []
- for obj_2_class_name, (_, right_most_bbox) in right_most_detections.items():
- for obj_1_class_name, (_, left_most_bbox) in left_most_detections.items():
- if obj_2_class_name == obj_1_class_name:
- continue
-
- # check if the left most detection of obj_1 is to the left
- # of the right most detection of obj_2
- if not (left_most_bbox[2] < right_most_bbox[0]): # not (x2 < x1)
- continue
-
- # and non-overlapping
- x1_inter = max(left_most_bbox[0], right_most_bbox[0])
- x2_inter = min(left_most_bbox[2], right_most_bbox[2])
- y1_inter = max(left_most_bbox[1], right_most_bbox[1])
- y2_inter = min(left_most_bbox[3], right_most_bbox[3])
-
- inter_width = max(0, x2_inter - x1_inter + 1)
- inter_height = max(0, y2_inter - y1_inter + 1)
- inter_area = inter_width * inter_height
-
- if inter_area > 0:
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ labels = list(counts.keys())
+ bxyxy: torch.Tensor = ctx.get("bboxes_xyxy", torch.empty((0, 4)))
+ # Group indices per class
+ class_to_indices: dict[str, list[int]] = ctx.get("class_to_indices", {})
+ qa = []
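+ # For each ordered class pair (c1, c2), answer "Yes" if some c1 instance lies entirely to the left of, and does not overlap, some c2 instance; otherwise answer "No".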
+ for i in range(len(labels)):
+ for j in range(len(labels)):
+ if i == j:
continue
-
- question = self.question.format(
- object_1=obj_1_class_name,
- object_2=obj_2_class_name,
- )
- answer = "Yes"
- question_answer_pairs.append((question, answer))
-
- return question_answer_pairs
+ c1, c2 = labels[i], labels[j]
+ found_yes = False
+ for idx1 in class_to_indices.get(c1, []):
+ x2_1 = float(bxyxy[idx1, 2])
+ for idx2 in class_to_indices.get(c2, []):
+ x1_2 = float(bxyxy[idx2, 0])
+ if x2_1 < x1_2:
+ # Check non-overlap via IOU 0 using cached detections
+ det1 = ctx["detections"][idx1]
+ det2 = ctx["detections"][idx2]
+ if ObjectDetectionUtils.pairwise_iou(det1, det2).max() == 0:
+ qa.append((self.question.format(object_1=c1, object_2=c2), "Yes"))
+ found_yes = True
+ break
+ if found_yes:
+ break
+ if not found_yes:
+ qa.append((self.question.format(object_1=c1, object_2=c2), "No"))
+ return qa
class RightOf(Question):
@@ -817,49 +827,38 @@ def __init__(self) -> None:
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_x_many_class_detections(image, detections, 2) == True
# @precondition: exists_non_overlapping_detections(image, detections) == True
-
- left_most_detections, right_most_detections, _, _ = self._find_extremes(
- image, detections
- )
-
- # iterate over the left most detections and check if there is a different class
- # that is to the right and non-overlapping of the instances we found above
- question_answer_pairs = []
- for obj_1_class_name, (_, left_most_bbox) in right_most_detections.items():
- for obj_2_class_name, (_, right_most_bbox) in left_most_detections.items():
- if obj_1_class_name == obj_2_class_name:
- continue
-
- # check if the right most detection of obj_1 is to the right
- # of the left most detection of obj_2
- if not (right_most_bbox[2] < left_most_bbox[0]): # not (x2 < x1)
- continue
-
- # and non-overlapping
- x1_inter = max(left_most_bbox[0], right_most_bbox[0])
- x2_inter = min(left_most_bbox[2], right_most_bbox[2])
- y1_inter = max(left_most_bbox[1], right_most_bbox[1])
- y2_inter = min(left_most_bbox[3], right_most_bbox[3])
-
- inter_width = max(0, x2_inter - x1_inter + 1)
- inter_height = max(0, y2_inter - y1_inter + 1)
- inter_area = inter_width * inter_height
-
- if inter_area > 0:
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ labels = list(counts.keys())
+ bxyxy: torch.Tensor = ctx.get("bboxes_xyxy", torch.empty((0, 4)))
+ class_to_indices: dict[str, list[int]] = ctx.get("class_to_indices", {})
+ qa = []
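+ # Mirror of LeftOf: answer "Yes" if some c1 instance lies entirely to the right of, and does not overlap, some c2 instance; otherwise answer "No".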
+ for i in range(len(labels)):
+ for j in range(len(labels)):
+ if i == j:
continue
-
- question = self.question.format(
- object_1=obj_1_class_name,
- object_2=obj_2_class_name,
- )
- answer = "Yes"
- question_answer_pairs.append((question, answer))
-
- return question_answer_pairs
+ c1, c2 = labels[i], labels[j]
+ found_yes = False
+ for idx1 in class_to_indices.get(c1, []):
+ x1_1 = float(bxyxy[idx1, 0])
+ for idx2 in class_to_indices.get(c2, []):
+ x2_2 = float(bxyxy[idx2, 2])
+ if x1_1 > x2_2:
+ det1 = ctx["detections"][idx1]
+ det2 = ctx["detections"][idx2]
+ if ObjectDetectionUtils.pairwise_iou(det1, det2).max() == 0:
+ qa.append((self.question.format(object_1=c1, object_2=c2), "Yes"))
+ found_yes = True
+ break
+ if found_yes:
+ break
+ if not found_yes:
+ qa.append((self.question.format(object_1=c1, object_2=c2), "No"))
+ return qa
# One can imagine an AboveOf and BelowOf question as well
@@ -885,74 +884,33 @@ def __init__(self) -> None:
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_one_single_detection(image, detections) == True
-
- # TODO: Asking this question heavily depends on the accuracy of the object detection model.
- # It's possible that the model is not able to detect some objects because it was not trained on them.
- # For example, if the model was trained on COCO, it might not be able to detect objects that
- # are not in the COCO dataset, whereas a model trained on Imagenet-1k might be able to do so.
- #
- # One way to address this, would be to implement set-of-mark prompting (highlight what we can
- # detect via bounding boxes) and then ask the model to answer the question based on that.
-
- if len(detections) == 0:
+ ctx = self._get_ctx(image, detections)
+ if len(ctx.get("detections", [])) == 0:
return []
-
- if len(detections) == 1:
- image_width, _ = image.size
- # logic to check if the bbox is actually on the left side of the image
- if (
- detections[0].as_xyxy()[0][0] < image_width / 2
- and detections[0].as_xyxy()[0][2] < image_width / 2
- ):
- return [(self.question, detections[0].label)]
- else:
- return []
-
- flattened_detections = []
- for detection in detections:
- curr_bbox = detection.as_xyxy().squeeze(0)
- if type(detection.label) is torch.Tensor:
- for i in range(detection.label.shape[0]):
- label = detection.label[i]
- curr_bbox = curr_bbox[i]
- flattened_detections.append((label, curr_bbox))
- else:
- flattened_detections.append((detection.label, curr_bbox))
-
- sorted_detections = sorted(
- flattened_detections, key=lambda x: x[1][0]
- ) # sort by x1 coordinate
- leftmost_detection = sorted_detections[0]
- second_leftmost_detection = sorted_detections[1]
-
- x1_inter = max(leftmost_detection[1][0], second_leftmost_detection[1][0])
- x2_inter = min(leftmost_detection[1][2], second_leftmost_detection[1][2])
- y1_inter = max(leftmost_detection[1][1], second_leftmost_detection[1][1])
- y2_inter = min(leftmost_detection[1][3], second_leftmost_detection[1][3])
-
- inter_width = max(0, x2_inter - x1_inter + 1)
- inter_height = max(0, y2_inter - y1_inter + 1)
- inter_area = inter_width * inter_height
-
- if inter_area > 0: # overlapping
- logger.debug("LeftMost question not ask-able due to overlapping detections")
+ bxyxy: torch.Tensor = ctx.get("bboxes_xyxy", torch.empty((0, 4)))
+ labels: list[str] = ctx.get("labels", [])
+ order = self.sort_detections_by(image, detections, key="x1", reverse=False)
+ if len(order) < 2:
+ # Single detection case: ensure it's on the left half fully
+ idx = order[0]
+ x1, x2 = float(bxyxy[idx, 0]), float(bxyxy[idx, 2])
+ if x1 < image.size[0] / 2 and x2 < image.size[0] / 2:
+ return [(self.question, str(labels[idx]))]
+ return []
+ i0, i1 = order[0], order[1]
+ # Overlap check using IOU 0 between first two leftmost
+ det0 = ctx["detections"][i0]
+ det1 = ctx["detections"][i1]
+ if ObjectDetectionUtils.pairwise_iou(det0, det1).max() > 0:
+ return []
+ # Ensure leftmost is on left half fully
+ x1, x2 = float(bxyxy[i0, 0]), float(bxyxy[i0, 2])
+ if not (x1 < image.size[0] / 2 and x2 < image.size[0] / 2):
return []
-
- image_width, _ = image.size
- # logic to check if the bbox is actually on the left side of the image
- if not (
- leftmost_detection[1][0] < image_width / 2
- and leftmost_detection[1][2] < image_width / 2
- ):
- logger.debug(
- "LeftMost question not ask-able due to not being on the left side of the image"
- )
- return []
-
- return [(self.question, leftmost_detection[0])]
+ return [(self.question, str(labels[i0]))]
class RightMost(Question):
@@ -968,76 +926,30 @@ def __init__(self) -> None:
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_one_single_detection(image, detections) == True
-
- # TODO: Asking this question heavily depends on the accuracy of the object detection model.
- # It's possible that the model is not able to detect some objects because it was not trained on them.
- # For example, if the model was trained on COCO, it might not be able to detect objects that
- # are not in the COCO dataset, whereas a model trained on Imagenet-1k might be able to do so.
- #
- # One way to address this, would be to implement set-of-mark prompting (highlight what we can
- # detect via bounding boxes) and then ask the model to answer the question based on that.
-
- if len(detections) == 0:
+ ctx = self._get_ctx(image, detections)
+ if len(ctx.get("detections", [])) == 0:
return []
-
- if len(detections) == 1:
- image_width, _ = image.size
- # logic to check if the bbox is actually on the right side of the image
- if (
- detections[0].as_xyxy()[0][0] > image_width / 2
- and detections[0].as_xyxy()[0][2] > image_width / 2
- ):
- return [(self.question, detections[0].label)]
- else:
- return []
-
- flattened_detections = []
- for detection in detections:
- curr_bbox = detection.as_xyxy().squeeze(0)
- if type(detection.label) is torch.Tensor:
- for i in range(detection.label.shape[0]):
- label = detection.label[i]
- curr_bbox = curr_bbox[i]
- flattened_detections.append((label, curr_bbox))
- else:
- flattened_detections.append((detection.label, curr_bbox))
-
- sorted_detections = sorted(
- flattened_detections, key=lambda x: x[1][2], reverse=True
- ) # sort by x2 coordinate
- rightmost_detection = sorted_detections[0]
- second_rightmost_detection = sorted_detections[1]
-
- x1_inter = max(rightmost_detection[1][0], second_rightmost_detection[1][0])
- x2_inter = min(rightmost_detection[1][2], second_rightmost_detection[1][2])
- y1_inter = max(rightmost_detection[1][1], second_rightmost_detection[1][1])
- y2_inter = min(rightmost_detection[1][3], second_rightmost_detection[1][3])
-
- inter_width = max(0, x2_inter - x1_inter + 1)
- inter_height = max(0, y2_inter - y1_inter + 1)
- inter_area = inter_width * inter_height
-
- if inter_area > 0: # overlapping
- logger.debug(
- "RightMost question not ask-able due to overlapping detections"
- )
+ bxyxy: torch.Tensor = ctx.get("bboxes_xyxy", torch.empty((0, 4)))
+ labels: list[str] = ctx.get("labels", [])
+ order = self.sort_detections_by(image, detections, key="x2", reverse=True)
+ if len(order) < 2:
+ idx = order[0]
+ x1, x2 = float(bxyxy[idx, 0]), float(bxyxy[idx, 2])
+ if x1 > image.size[0] / 2 and x2 > image.size[0] / 2:
+ return [(self.question, str(labels[idx]))]
return []
-
- image_width, _ = image.size
- # logic to check if the bbox is actually on the right side of the image
- if not (
- rightmost_detection[1][0] > image_width / 2
- and rightmost_detection[1][2] > image_width / 2
- ):
- logger.debug(
- "RightMost question not ask-able due to not being on the right side of the image"
- )
+ i0, i1 = order[0], order[1]
+ det0 = ctx["detections"][i0]
+ det1 = ctx["detections"][i1]
+ if ObjectDetectionUtils.pairwise_iou(det0, det1).max() > 0:
return []
-
- return [(self.question, rightmost_detection[0])]
+ x1, x2 = float(bxyxy[i0, 0]), float(bxyxy[i0, 2])
+ if not (x1 > image.size[0] / 2 and x2 > image.size[0] / 2):
+ return []
+ return [(self.question, str(labels[i0]))]
class HowMany(Question):
@@ -1056,34 +968,26 @@ def __init__(self) -> None:
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_x_many_class_detections(image, detections, 1) == True
-
- detection_counts = {}
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor: # shape == (# of boxes,)
- # need to iterate over the tensor to get the class names
- for single_class_name in class_name:
- detection_counts[single_class_name] = (
- detection_counts.get(single_class_name, 0) + 1
- )
- else:
- detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
-
- question_answer_pairs = []
- for class_name, count in detection_counts.items():
- question_answer_pairs.append(
- (self.question.format(object_1=class_name), str(count))
- )
-
- return question_answer_pairs
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ return [
+ (self.question.format(object_1=cls), str(cnt)) for cls, cnt in counts.items()
+ ]
class AreMore(Question):
# TODO: Create a version of this question that is multiple choice
- def __init__(self) -> None:
+ def __init__(self, margin_ratio: float = 0.2) -> None:
+ """AreMore question with margin-based count filtering.
+
+ Args:
+ margin_ratio: Required margin between counts. Only asks question if
+ the larger count exceeds the smaller by at least this ratio.
+ E.g., margin_ratio=0.2 means count_1 must be >= 1.2 * count_2.
+ """
super().__init__(
question="Are there more {object_1}(s) than {object_2}(s) in this image?",
variables=["object_1", "object_2"],
@@ -1093,50 +997,40 @@ def __init__(self) -> None:
),
],
)
+ if margin_ratio < 0 or margin_ratio > 1:
+ raise ValueError("margin_ratio must be between 0 and 1")
+ self.margin_ratio = margin_ratio
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
-
- detection_counts = {}
- for detection in detections:
- class_name = detection.label
- if type(class_name) is torch.Tensor:
- for single_class_name in class_name:
- detection_counts[single_class_name] = (
- detection_counts.get(single_class_name, 0) + 1
- )
- else:
- detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
- question_answer_pairs = []
- detected_classes = list(detection_counts.keys())
-
- for i in range(len(detected_classes)):
- for j in range(i + 1, len(detected_classes)):
- object_1, object_2 = detected_classes[i], detected_classes[j]
- count_1, count_2 = (
- detection_counts[object_1],
- detection_counts[object_2],
- )
-
- if count_1 > count_2:
- answer = "Yes"
- elif count_2 > count_1:
- answer = "No"
- else:
- continue
-
- question_answer_pairs.append(
- (self.question.format(object_1=object_1, object_2=object_2), answer)
- )
-
- return question_answer_pairs
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ classes = list(counts.keys())
+ qa: list[tuple[str, str]] = []
+ for i in range(len(classes)):
+ for j in range(i + 1, len(classes)):
+ o1, o2 = classes[i], classes[j]
+ c1, c2 = counts[o1], counts[o2]
+ if c1 > c2:
+ if c1 >= (1 + self.margin_ratio) * c2:
+ qa.append((self.question.format(object_1=o1, object_2=o2), "Yes"))
+ elif c2 > c1:
+ if c2 >= (1 + self.margin_ratio) * c1:
+ qa.append((self.question.format(object_1=o1, object_2=o2), "No"))
+ return qa
class WhichMore(Question):
- def __init__(self) -> None:
+ def __init__(self, margin_ratio: float = 0.2) -> None:
+ """WhichMore question with margin-based count filtering.
+
+ Args:
+ margin_ratio: Required margin for clear winner. Only asks question if
+ the winning count exceeds the second-highest by at least this ratio.
+ """
super().__init__(
question="What appears the most in this image: {object_1}s, {object_2}s, or {object_3}s?",
variables=["object_1", "object_2", "objejct_3"],
@@ -1146,12 +1040,15 @@ def __init__(self) -> None:
),
],
)
+ if margin_ratio < 0 or margin_ratio > 1:
+ raise ValueError("margin_ratio must be between 0 and 1")
+ self.margin_ratio = margin_ratio
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
detection_counts = {}
for detection in detections:
@@ -1162,7 +1059,8 @@ def apply(
detection_counts.get(single_class_name, 0) + 1
)
else:
- detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
+ detection_counts[class_name] = detection_counts.get(
+ class_name, 0) + 1
question_answer_pairs = []
detected_classes = list(detection_counts.keys())
@@ -1181,6 +1079,15 @@ def apply(
)
max_count = max(count_1, count_2, count_3)
+ # Sort counts to find second highest
+ sorted_counts = sorted([count_1, count_2, count_3], reverse=True)
+ second_highest_count = sorted_counts[1]
+
+ # Check if winner has significant margin over second place
+ if max_count < (1 + self.margin_ratio) * second_highest_count:
+ # Winner not clear enough - skip question
+ continue
+
max_objects = []
if count_1 == max_count:
max_objects.append(object_1)
@@ -1205,7 +1112,15 @@ def apply(
class LeftMostWidthVsHeight(WidthVsHeight):
- def __init__(self, threshold: float = 0.3) -> None:
+ def __init__(self, threshold: float = 0.75, spatial_margin_ratio: float = 0.05) -> None:
+ """LeftMostWidthVsHeight with spatial stability checks.
+
+ Args:
+ threshold: Aspect ratio threshold
+ spatial_margin_ratio: Required spatial separation as fraction of image width.
+ The leftmost object must be separated from the second-leftmost by at least
+ this margin to ensure stable positioning.
+ """
super().__init__(threshold=threshold)
self.question = (
"Does the leftmost object in the image appear to be wider than it is tall?"
@@ -1213,206 +1128,436 @@ def __init__(self, threshold: float = 0.3) -> None:
self.other_question = (
"Does the leftmost object in the image appear to be taller than it is wide?"
)
+ if spatial_margin_ratio < 0 or spatial_margin_ratio > 1:
+ raise ValueError("spatial_margin_ratio must be between 0 and 1")
+ self.spatial_margin_ratio = spatial_margin_ratio
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
+ detections: list[ObjectDetectionResultI],
reverse: bool = False,
- ) -> List[Tuple[str, str]]:
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_one_single_detection(image, detections) == True
- im_width, im_height = image.size
-
- if len(detections) == 0:
- return []
-
- flattened_detections = [
- box for detection in detections for box in detection.flatten()
- ]
- detection_counts = {}
- for detection in flattened_detections:
- class_name = detection.label
- detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
-
- single_detections = [
- class_name for class_name, count in detection_counts.items() if count == 1
- ]
- if len(single_detections) == 0:
- return []
-
- sorted_detections = sorted(
- flattened_detections, key=lambda x: x.as_xyxy()[0][0]
- ) # sort by x1 coordinate
- leftmost_detection = None
- second_leftmost_detection = None
- for i, detection in enumerate(sorted_detections):
- if detection.label in single_detections:
- is_on_left = (
- detection.as_xyxy()[0][0] < im_width / 2
- and detection.as_xyxy()[0][2] < im_width / 2
- )
- if not is_on_left:
- # no point in continuing if the leftmost detection is not on the left side of the image
- logger.debug(
- "LeftMostWidthVsHeight question not ask-able due to not being on the left side of the image"
- )
- return []
- leftmost_detection = detection
- if i + 1 < len(sorted_detections):
- second_leftmost_detection = sorted_detections[i + 1]
- break
-
- if leftmost_detection is None:
- logger.debug("No leftmost detection found")
- return []
- if second_leftmost_detection is not None:
- # check if the leftmost detection is overlapping with the second leftmost detection
- x1_inter = max(
- leftmost_detection.as_xyxy()[0][0],
- second_leftmost_detection.as_xyxy()[0][0],
- )
- x2_inter = min(
- leftmost_detection.as_xyxy()[0][2],
- second_leftmost_detection.as_xyxy()[0][2],
- )
- y1_inter = max(
- leftmost_detection.as_xyxy()[0][1],
- second_leftmost_detection.as_xyxy()[0][1],
- )
- y2_inter = min(
- leftmost_detection.as_xyxy()[0][3],
- second_leftmost_detection.as_xyxy()[0][3],
- )
- inter_width = max(0, x2_inter - x1_inter + 1)
- inter_height = max(0, y2_inter - y1_inter + 1)
- inter_area = inter_width * inter_height
-
- if inter_area > 0: # overlapping
- logger.debug(
- "LeftMostWidthVsHeight question not ask-able due to overlapping detections"
- )
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ bxyxy: torch.Tensor = ctx.get("bboxes_xyxy", torch.empty((0, 4)))
+ aspects: torch.Tensor = ctx.get("aspects", torch.empty((0,)))
+ labels: list[str] = ctx.get("labels", [])
+ order = self.sort_detections_by(image, detections, key="x1", reverse=False)
+ im_width, _ = image.size
+
+ left_idx = None
+ second_idx = None
+ for pos, idx in enumerate(order):
+ lbl = labels[idx]
+ if counts.get(lbl, 0) != 1:
+ continue
+ x1, x2 = float(bxyxy[idx, 0]), float(bxyxy[idx, 2])
+ if not (x1 < im_width / 2 and x2 < im_width / 2):
return []
-
- # check if the leftmost detection is at least threshold % larger than the second largest
- question_answer_pair = self._question_answer(
- leftmost_detection.label,
- leftmost_detection,
- reverse=reverse,
- )
- if question_answer_pair is None:
- logger.debug(
- "LeftMostWidthVsHeight question not ask-able due to width and height being roughly equal"
- )
+ left_idx = idx
+ if pos + 1 < len(order):
+ second_idx = order[pos + 1]
+ break
+ if left_idx is None:
return []
- return question_answer_pair
+ if second_idx is not None:
+ left_x2 = float(bxyxy[left_idx, 2])
+ second_x1 = float(bxyxy[second_idx, 0])
+ required_margin = self.spatial_margin_ratio * im_width
+ if (second_x1 - left_x2) < required_margin:
+ return []
+ # ensure no overlap
+ if ObjectDetectionUtils.pairwise_iou(ctx["detections"][left_idx], ctx["detections"][second_idx]).max() > 0:
+ logger.debug("Leftmost object overlaps with second-leftmost object")
+ return []
+ ratio = float(aspects[left_idx]) if aspects.numel() > left_idx else None
+ if ratio is None:
+ b = bxyxy[left_idx]
+ ratio = float((b[2]-b[0]) / max(float(b[3]-b[1]), 1e-6))
+ qa = self._question_answer_ratio(labels[left_idx], ratio, reverse=reverse)
+ return [qa] if qa is not None else []
class RightMostWidthVsHeight(WidthVsHeight):
- def __init__(self, threshold: float = 0.3) -> None:
+ def __init__(self, threshold: float = 0.75, spatial_margin_ratio: float = 0.05) -> None:
+ """RightMostWidthVsHeight with spatial stability checks.
+
+ Args:
+ threshold: Aspect ratio threshold (inherited from WidthVsHeight)
+ spatial_margin_ratio: Required spatial separation as fraction of image width.
+ The rightmost object must be separated from the second-rightmost by at least
+ this margin to ensure stable positioning.
+ """
super().__init__(threshold=threshold)
self.question = (
"Does the rightmost object in the image appear to be wider than it is tall?"
)
self.other_question = "Does the rightmost object in the image appear to be taller than it is wide?"
+ if spatial_margin_ratio < 0 or spatial_margin_ratio > 1:
+ raise ValueError("spatial_margin_ratio must be between 0 and 1")
+ self.spatial_margin_ratio = spatial_margin_ratio
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
+ detections: list[ObjectDetectionResultI],
reverse: bool = False,
- ) -> List[Tuple[str, str]]:
+ ) -> list[tuple[str, str]]:
# @precondition: at_least_one_single_detection(image, detections) == True
- im_width, im_height = image.size
-
- if len(detections) == 0:
+ ctx = self._get_ctx(image, detections)
+ counts: dict[str, int] = ctx.get("counts", {})
+ bxyxy: torch.Tensor = ctx.get("bboxes_xyxy", torch.empty((0, 4)))
+ aspects: torch.Tensor = ctx.get("aspects", torch.empty((0,)))
+ labels: list[str] = ctx.get("labels", [])
+ order = self.sort_detections_by(image, detections, key="x2", reverse=True)
+ im_width, _ = image.size
+
+ right_idx = None
+ second_idx = None
+ for pos, idx in enumerate(order):
+ lbl = labels[idx]
+ if counts.get(lbl, 0) != 1:
+ continue
+ x1, x2 = float(bxyxy[idx, 0]), float(bxyxy[idx, 2])
+ if not (x1 > im_width / 2 and x2 > im_width / 2):
+ return []
+ right_idx = idx
+ if pos + 1 < len(order):
+ second_idx = order[pos + 1]
+ break
+ if right_idx is None:
return []
+ if second_idx is not None:
+ right_x1 = float(bxyxy[right_idx, 0])
+ second_x2 = float(bxyxy[second_idx, 2])
+ required_margin = self.spatial_margin_ratio * im_width
+ if (right_x1 - second_x2) < required_margin:
+ return []
+ if ObjectDetectionUtils.pairwise_iou(ctx["detections"][right_idx], ctx["detections"][second_idx]).max() > 0:
+ logger.debug("Rightmost object overlaps with second-rightmost object")
+ return []
+ ratio = float(aspects[right_idx]) if aspects.numel() > right_idx else None
+ if ratio is None:
+ b = bxyxy[right_idx]
+ ratio = float((b[2]-b[0]) / max(float(b[3]-b[1]), 1e-6))
+ qa = self._question_answer_ratio(labels[right_idx], ratio, reverse=reverse)
+ return [qa] if qa is not None else []
- flattened_detections = [
- box for detection in detections for box in detection.flatten()
- ]
- detection_counts = {}
- for detection in flattened_detections:
- class_name = detection.label
- detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
- single_detections = [
- class_name for class_name, count in detection_counts.items() if count == 1
- ]
- if len(single_detections) == 0:
+# drop this question
+class ObjectsInRow(Question):
+ def __init__(self, variance_threshold: float = 0.1) -> None:
+ """Linear regression-based row detection.
+
+ Args:
+ variance_threshold: Maximum normalized variance for y-centers to be
+ considered in a row. Lower values = stricter row detection.
+ """
+ super().__init__(
+ question="Are there any objects arranged in a row?",
+ variables=[],
+ predicates=[
+ lambda image, detections: ObjectDetectionPredicates.at_least_x_detections(
+ image, detections, 3
+ ),
+ ],
+ )
+ self.variance_threshold = variance_threshold
+
+ def apply(
+ self,
+ image: Image.Image,
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+ from sklearn.linear_model import LinearRegression
+
+ if len(detections) < 3:
+ return [(self.question, "No")]
+
+ # Get center points
+ centers = []
+ for detection in detections:
+ bbox = detection.as_xyxy().squeeze(0)
+ x_center = float((bbox[0] + bbox[2]) / 2)
+ y_center = float((bbox[1] + bbox[3]) / 2)
+ centers.append((x_center, y_center))
+
+ # Sort by x-coordinate
+ centers = sorted(centers, key=lambda p: p[0])
+
+ # Try sliding windows of 3+ objects
+ image_height = image.size[1]
+
+ for window_size in range(3, len(centers) + 1):
+ for start in range(len(centers) - window_size + 1):
+ window = centers[start:start + window_size]
+
+ # Extract x and y coordinates
+ x_coords = np.array([p[0] for p in window]).reshape(-1, 1)
+ y_coords = np.array([p[1] for p in window])
+
+ # Fit linear regression
+ reg = LinearRegression().fit(x_coords, y_coords)
+ y_pred = reg.predict(x_coords)
+
+ # Calculate normalized variance (by image height)
+ variance = np.var(y_coords - y_pred)
+ normalized_variance = variance / (image_height ** 2)
+
+ if normalized_variance < self.variance_threshold:
+ return [(self.question, "Yes")]
+
+ return [(self.question, "No")]
+
+
+class ObjectsInLine(Question):
+ def __init__(self, variance_threshold: float = 0.1) -> None:
+ """Multiple choice question about which objects are in a row.
+
+ Args:
+ variance_threshold: Same as ObjectsInRow for consistency.
+ """
+ super().__init__(
+ question="Which objects appear to be arranged in a row? A) {option_a}, B) {option_b}, C) {option_c}, D) No clear row arrangement. Respond with the letter only.",
+ variables=["option_a", "option_b", "option_c"],
+ predicates=[
+ lambda image, detections: ObjectDetectionPredicates.at_least_x_detections(
+ image, detections, 3
+ ),
+ ],
+ )
+ self.variance_threshold = variance_threshold
+
+ def apply(
+ self,
+ image: Image.Image,
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+ from sklearn.linear_model import LinearRegression
+
+ if len(detections) < 3:
return []
- sorted_detections = sorted(
- flattened_detections, key=lambda x: x.as_xyxy()[0][2], reverse=True
- ) # sort by x2 coordinate
- rightmost_detection = None
- second_rightmost_detection = None
- for i, detection in enumerate(sorted_detections):
- if detection.label in single_detections:
- is_on_right = (
- detection.as_xyxy()[0][0] > im_width / 2
- and detection.as_xyxy()[0][2] > im_width / 2
- )
- if not is_on_right:
- # no point in continuing if the rightmost detection is not on the right side of the image
- logger.debug(
- "RightMostWidthVsHeight question not ask-able due to not being on the right side of the image"
- )
- return []
- rightmost_detection = detection
- if i + 1 < len(sorted_detections):
- second_rightmost_detection = sorted_detections[i + 1]
- break
+ # Get centers with labels
+ centers_with_labels = []
+ for detection in detections:
+ bbox = detection.as_xyxy().squeeze(0)
+ x_center = float((bbox[0] + bbox[2]) / 2)
+ y_center = float((bbox[1] + bbox[3]) / 2)
+ label = str(detection.label)
+ centers_with_labels.append((x_center, y_center, label))
+
+ # Sort by x-coordinate
+ centers_with_labels = sorted(centers_with_labels, key=lambda p: p[0])
+
+ # Find best row arrangement
+ best_row = None
+ best_variance = float('inf')
+ image_height = image.size[1]
+
+ for window_size in range(3, len(centers_with_labels) + 1):
+ for start in range(len(centers_with_labels) - window_size + 1):
+ window = centers_with_labels[start:start + window_size]
+
+ x_coords = np.array([p[0] for p in window]).reshape(-1, 1)
+ y_coords = np.array([p[1] for p in window])
+
+ reg = LinearRegression().fit(x_coords, y_coords)
+ y_pred = reg.predict(x_coords)
+
+ variance = np.var(y_coords - y_pred)
+ normalized_variance = variance / (image_height ** 2)
+
+ if normalized_variance < self.variance_threshold and normalized_variance < best_variance:
+ best_variance = normalized_variance
+ best_row = [p[2] for p in window] # Extract labels
- if rightmost_detection is None:
- logger.debug("No rightmost detection found")
+ if best_row is None:
return []
- if second_rightmost_detection is not None:
- # check if the rightmost detection is overlapping with the second rightmost detection
- x1_inter = max(
- rightmost_detection.as_xyxy()[0][0],
- second_rightmost_detection.as_xyxy()[0][0],
- )
- x2_inter = min(
- rightmost_detection.as_xyxy()[0][2],
- second_rightmost_detection.as_xyxy()[0][2],
- )
- y1_inter = max(
- rightmost_detection.as_xyxy()[0][1],
- second_rightmost_detection.as_xyxy()[0][1],
- )
- y2_inter = min(
- rightmost_detection.as_xyxy()[0][3],
- second_rightmost_detection.as_xyxy()[0][3],
- )
- inter_width = max(0, x2_inter - x1_inter + 1)
- inter_height = max(0, y2_inter - y1_inter + 1)
- inter_area = inter_width * inter_height
+ # Create multiple choice options
+ row_text = ", ".join(best_row)
+
+ # Generate plausible distractors that are different from correct answer
+ all_labels = list(set(str(d.label) for d in detections))
+ random.shuffle(all_labels)
+
+ # Create distractors ensuring they're different from correct answer
+ distractor1 = ", ".join(all_labels[:min(3, len(all_labels))])
+ distractor2 = ", ".join(all_labels[-min(3, len(all_labels)):])
+
+ # Ensure distractors are different from correct answer
+ max_attempts = 10
+ attempt = 0
+ while (distractor1 == row_text or distractor2 == row_text or distractor1 == distractor2) and attempt < max_attempts:
+ random.shuffle(all_labels)
+ distractor1 = ", ".join(all_labels[:min(3, len(all_labels))])
+ # Use different slice
+ distractor2 = ", ".join(all_labels[-min(2, len(all_labels)):])
+ attempt += 1
+
+ # If still duplicates after attempts, skip this question
+ if distractor1 == row_text or distractor2 == row_text or distractor1 == distractor2:
+ return []
- if inter_area > 0:
- logger.debug(
- "RightMostWidthVsHeight question not ask-able due to overlapping detections"
- )
- return []
- # check if the rightmost detection is at least threshold % larger than the second largest
- question_answer_pair = self._question_answer(
- rightmost_detection.label,
- rightmost_detection,
- reverse=reverse,
+ # Randomly assign correct answer to A/B/C
+ options = [row_text, distractor1, distractor2]
+ random.shuffle(options)
+ correct_letter = ["A", "B", "C"][options.index(row_text)]
+
+ q = self.question.format(
+ option_a=options[0],
+ option_b=options[1],
+ option_c=options[2]
)
- if question_answer_pair is None:
- logger.debug(
- "RightMostWidthVsHeight question not ask-able due to width and height being roughly equal"
- )
+
+ return [(q, correct_letter)]
+
+
+# drop this question
+class MostClusteredObjects(Question):
+ def __init__(self, eps_ratio: float = 0.05, min_samples: int = 3) -> None:
+ """DBSCAN-based clustering with multiple choice answers.
+
+ Args:
+ eps_ratio: Maximum distance between points in a cluster as a fraction
+ of the image diagonal. Default 0.05 means 5% of image diagonal.
+ min_samples: Minimum points required to form a cluster.
+ """
+ super().__init__(
+ question="Which group of objects appears most tightly clustered? A) {option_a}, B) {option_b}, C) {option_c}, D) No clear clusters. Respond with the letter only.",
+ variables=["option_a", "option_b", "option_c"],
+ predicates=[
+ lambda image, detections: ObjectDetectionPredicates.at_least_x_detections(
+ image, detections, 9  # Need at least 3 clusters × 3 objects each
+ ),
+ ],
+ )
+ self.eps_ratio = eps_ratio
+ self.min_samples = min_samples
+
+ def apply(
+ self,
+ image: Image.Image,
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+ from sklearn.cluster import DBSCAN
+
+ if len(detections) < 9:
return []
- return question_answer_pair
+ # Get centers and labels
+ centers = []
+ labels = []
+ for detection in detections:
+ bbox = detection.as_xyxy().squeeze(0)
+ x_center = float((bbox[0] + bbox[2]) / 2)
+ y_center = float((bbox[1] + bbox[3]) / 2)
+ centers.append([x_center, y_center])
+ labels.append(str(detection.label))
-class ObjectsInRow(Question):
- def __init__(self) -> None:
+ centers = np.array(centers)
+
+ # Calculate eps as a fraction of image diagonal
+ image_width, image_height = image.size
+ image_diagonal = math.sqrt(image_width**2 + image_height**2)
+ eps = self.eps_ratio * image_diagonal
+
+ # Apply DBSCAN
+ clustering = DBSCAN(
+ eps=eps, min_samples=self.min_samples).fit(centers)
+ cluster_labels = clustering.labels_
+
+ # Group objects by cluster (ignore noise points with label -1)
+ clusters = {}
+ for i, cluster_id in enumerate(cluster_labels):
+ if cluster_id != -1: # Not noise
+ if cluster_id not in clusters:
+ clusters[cluster_id] = []
+ clusters[cluster_id].append(labels[i])
+
+ if len(clusters) < 2:
+ return [] # Need at least 2 clusters to compare
+
+ # Find most compact cluster
+ def cluster_compactness(cluster_id):
+ cluster_points = centers[cluster_labels == cluster_id]
+ if len(cluster_points) < 2:
+ return float('inf')
+ return np.mean(np.var(cluster_points, axis=0))
+
+ most_compact_id = min(clusters.keys(), key=cluster_compactness)
+ most_compact_objects = list(
+ set(clusters[most_compact_id])) # Remove duplicates
+
+ # Create multiple choice options
+ correct_text = ", ".join(sorted(most_compact_objects))
+
+ # Generate distractors from other clusters or random combinations
+ all_unique_labels = list(set(labels))
+ random.shuffle(all_unique_labels)
+
+ # Create distractors ensuring they're different from correct answer
+ distractor1 = ", ".join(
+ all_unique_labels[:min(3, len(all_unique_labels))])
+ distractor2 = ", ".join(
+ all_unique_labels[-min(2, len(all_unique_labels)):])
+
+ # Ensure distractors are different from correct answer
+ max_attempts = 10
+ attempt = 0
+ while (distractor1 == correct_text or distractor2 == correct_text or distractor1 == distractor2) and attempt < max_attempts:
+ random.shuffle(all_unique_labels)
+ distractor1 = ", ".join(
+ all_unique_labels[:min(3, len(all_unique_labels))])
+ distractor2 = ", ".join(
+ all_unique_labels[-min(2, len(all_unique_labels)):])
+ attempt += 1
+
+ # If still duplicates after attempts, skip this question
+ if distractor1 == correct_text or distractor2 == correct_text or distractor1 == distractor2:
+ return []
+
+ # Randomly assign correct answer
+ options = [correct_text, distractor1, distractor2]
+ random.shuffle(options)
+ correct_letter = ["A", "B", "C"][options.index(correct_text)]
+
+ q = self.question.format(
+ option_a=options[0],
+ option_b=options[1],
+ option_c=options[2]
+ )
+
+ return [(q, correct_letter)]
+
+
+class MoreThanThresholdHowMany(Question):
+ """More-than count question with built-in Yes/No balance.
+
+ For each detected object class with count *N* we generate two prompts:
+
+ 1. *Yes case*: target = ⌊N / threshold⌋.
+ The detector's count is safely above the target, so the correct answer is **Yes**.
+
+ 2. *No case*: target = ⌈N × threshold⌉.
+ The detector's count is well below the target, so the correct answer is **No**.
+
+ The gap created by the multiplicative buffer acts as a hedge against recall / precision noise
+ while keeping the overall Yes/No distribution roughly balanced.
+ """
+
+ def __init__(self, threshold: float = 2.0):
+ if threshold <= 1.0:
+ raise ValueError(
+ "threshold should be > 1.0 for 'more than' questions")
+
+ self.threshold: float = threshold
super().__init__(
- question="Are there any objects arranged in a row?",
- variables=[],
+ question="Are there more than {target} {object_1}(s) in this image? Respond Yes/No.",
+ variables=["object_1", "target"],
predicates=[
lambda image, detections: ObjectDetectionPredicates.at_least_x_many_class_detections(
image, detections, 1
@@ -1423,210 +1568,731 @@ def __init__(self) -> None:
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
- if len(detections) < 3:
- return [(self.question, "No")]
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+
+ # Count detections per class
+ counts: dict[str, int] = {}
+ for det in detections:
+ lbl = det.label
+ if isinstance(lbl, torch.Tensor):
+ for l in lbl:
+ counts[str(l)] = counts.get(str(l), 0) + 1
+ else:
+ counts[str(lbl)] = counts.get(str(lbl), 0) + 1
- bboxes = [detection.as_xyxy().squeeze(0) for detection in detections]
+ qa_pairs: list[tuple[str, str]] = []
+ for cls, n in counts.items():
+ if n == 0:
+ continue
- bboxes_sorted_by_x = sorted(
- bboxes, key=lambda bbox: bbox[0]
- ) # Sorted by left boundary
+ # Question that should be answered "Yes" (target below n)
+ target_yes = max(1, math.floor(n / self.threshold))
+ if target_yes >= n:
+ # happens only when n == 1: no target >= 1 keeps "Yes" truthful, so skip the Yes case for this class
+ target_yes = None
- def y_overlap(min_y1, max_y1, min_y2, max_y2):
- inter = max(0, min(max_y1, max_y2) - max(min_y1, min_y2))
- len1 = max_y1 - min_y1
- len2 = max_y2 - min_y2
- min_len = min(len1, len2)
+ if target_yes is not None:
+ q_yes = self.question.format(object_1=cls, target=target_yes)
+ qa_pairs.append((q_yes, "Yes"))
- # two objects are considered on the same line only if the y overlap is at least 50% of the smaller object.
- # TODO: add this as a threshold.
- return inter >= 0.5 * min_len
+ # Question that should be answered "No" (target well above n)
+ target_no = math.ceil(n * self.threshold)
+ if target_no == n:
+ target_no += 1
- def check_row_alignment(bboxes_sorted):
- for i in range(len(bboxes_sorted) - 2):
- box1, box2, box3 = (
- bboxes_sorted[i],
- bboxes_sorted[i + 1],
- bboxes_sorted[i + 2],
- )
+ q_no = self.question.format(object_1=cls, target=target_no)
+ qa_pairs.append((q_no, "No"))
- # Require >=50% y-overlap for each adjacent pair
- if y_overlap(box1[1], box1[3], box2[1], box2[3]) and y_overlap(
- box2[1], box2[3], box3[1], box3[3]
- ):
- return True
+ return qa_pairs
- return False
- row_detected = check_row_alignment(bboxes_sorted_by_x)
+class LessThanThresholdHowMany(Question):
+ """Less-than count question with symmetric Yes/No balance.
- answer = "Yes" if row_detected else "No"
- return [(self.question, answer)]
+ For detected count *N* we generate:
+ 1. *Yes case*: target = ⌈N / threshold⌉ (> N), so the answer **Yes** is correct.
+ 2. *No case*: target = ⌊N × threshold⌋ (< N), so **No** is correct.
-class ObjectsInLine(Question):
- def __init__(self) -> None:
+ This mirrors the more-than version and maintains balanced answer keys while
+ providing a tolerance band for detector errors.
+ """
+
+ def __init__(self, threshold: float = 0.5):
+ if not (0.0 < threshold < 1.0):
+ raise ValueError(
+ "threshold must be between 0 and 1 for 'less than'")
+
+ self.threshold: float = threshold
super().__init__(
- question="What objects are arranged in a row?",
- variables=[],
+ question="Are there less than {target} {object_1}(s) in this image? Respond Yes/No.",
+ variables=["object_1", "target"],
predicates=[
- # TODO: at least 3 detections
- lambda image, detections: ObjectDetectionPredicates.at_least_x_detections(
- image, detections, 3
- ),
lambda image, detections: ObjectDetectionPredicates.at_least_x_many_class_detections(
image, detections, 1
),
- lambda image, detections: ObjectsInRow().apply(image, detections)[0][1]
- == "Yes",
],
)
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
- bboxes = [detection.as_xyxy().squeeze(0) for detection in detections]
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+
+ counts: dict[str, int] = {}
+ for det in detections:
+ lbl = det.label
+ if isinstance(lbl, torch.Tensor):
+ for l in lbl:
+ counts[str(l)] = counts.get(str(l), 0) + 1
+ else:
+ counts[str(lbl)] = counts.get(str(lbl), 0) + 1
+
+ qa_pairs: list[tuple[str, str]] = []
+ for cls, n in counts.items():
+ if n == 0:
+ continue
- detections_sorted_by_x = sorted(
- detections, key=lambda detection: detection.as_xyxy().squeeze(0)[0]
+ # Question that should be answered "Yes" (target above n)
+ target_yes = math.ceil(n / self.threshold)
+ if target_yes == n:
+ target_yes += 1
+
+ q_yes = self.question.format(object_1=cls, target=target_yes)
+ qa_pairs.append((q_yes, "Yes"))
+
+ # Question that should be answered "No" (target well below n)
+ target_no = max(1, math.floor(n * self.threshold))
+ if target_no == n:
+ target_no = max(1, target_no - 1)
+
+ q_no = self.question.format(object_1=cls, target=target_no)
+ qa_pairs.append((q_no, "No"))
+
+ return qa_pairs
+
+
+class MultiChoiceHowMany(Question):
+ """Noise-tolerant *How Many* as a 3-way multiple-choice question.
+
+ Workflow per detected object class with count *N*:
+
+ 1. Build **contiguous** numeric buckets based on *N* (and confidence variance):
+ • *low*  : `0 – ⌊α · N⌋`
+ • *mid*  : `⌈α · N⌉ – ⌊β · N⌋`
+ • *high* : `⌈β · N⌉ – ⌈β · N⌉+w`  (finite width so all three look alike)
+ where `(Ξ±, Ξ²) = (0.5, 1.5)` by default or `(0.4, 1.8)` when per-class
+ confidence variance > 0.05, and *w* equals the width of the mid bucket.
+
+ 2. Randomly **shuffle** which bucket is labelled A, B, or C. This removes
+ the positional/letter bias while the LLM still sees all ranges.
+
+ 3. The correct answer letter is determined after the shuffle so that the
+ dataset remains balanced across A/B/C over time.
+
+ 4. A fourth option **D) Unsure / Not Visible** is always listed to allow a
+ graceful fallback when the model feels uncertain.
+
+ Questions are only generated when `N ≥ 4`; for very small counts, the
+ buckets become too narrow to be useful.
+ """
+
+ def __init__(self):
+ super().__init__(
+ question="How many {object_1}(s) are in the image? Choose one: "
+ "A) {range_a}, B) {range_b}, C) {range_c}, D) Unsure / Not Visible. "
+ "Respond with the letter only.",
+ variables=["object_1", "range_a", "range_b", "range_c"],
+ predicates=[
+ lambda image, detections: ObjectDetectionPredicates.at_least_x_many_class_detections(
+ image, detections, 1
+ ),
+ ],
)
- bboxes_sorted_by_x = [
- detection.as_xyxy().squeeze(0) for detection in detections_sorted_by_x
- ]
- def y_overlap(min_y1, max_y1, min_y2, max_y2):
- inter = max(0, min(max_y1, max_y2) - max(min_y1, min_y2))
- len1 = max_y1 - min_y1
- len2 = max_y2 - min_y2
- min_len = min(len1, len2)
-
- return inter >= 0.5 * min_len
-
- def find_rows(bboxes_sorted) -> List[List[int]]:
- rows = []
- i = 0
- while i < len(bboxes_sorted) - 2:
- current_row_indices = [i]
- for j in range(i + 1, len(bboxes_sorted)):
- if y_overlap(
- bboxes_sorted[j - 1][1],
- bboxes_sorted[j - 1][3],
- bboxes_sorted[j][1],
- bboxes_sorted[j][3],
- ):
- current_row_indices.append(j)
- else:
- break
- if len(current_row_indices) >= 3:
- rows.append(current_row_indices)
- i += len(current_row_indices)
+ def _bucket_ranges(self, n: int, var: float) -> tuple[dict[str, str], str]:
+ """Return bucket description dict and the *semantic* correct bucket key.
+
+ Keys: "low", "mid", "high" β string description "xβy" (inclusive).
+ Also returns which *bucket key* contains ``n`` so we can map it to the
+ shuffled letter later.
+ """
+
+ # Variance-based adjustment of coefficients
+ low_coef, mid_high_coef = (0.4, 1.8) if var > 0.05 else (0.5, 1.5)
+
+ # Bucket boundaries (inclusive)
+ low_max = max(0, int((low_coef * n) - 1e-6))
+ mid_min = low_max + 1
+ mid_max = int(mid_high_coef * n)
+ high_min = mid_max + 1
+
+ # Make the high bucket a finite *range* with similar width to mid bucket
+ mid_width = mid_max - mid_min
+ high_max = high_min + max(2, mid_width) # ensure non-zero width
+
+ buckets: dict[str, str] = {
+ "low": f"0-{low_max}" if low_max > 0 else "0-{mid_min-1}",
+ "mid": f"{mid_min}-{mid_max}",
+ "high": f"{high_min}-{high_max}",
+ }
+
+ # With fixed Ξ±/Ξ² the detected count N always lands in the mid bucket,
+ # so we can simply hard-code it instead of checking.
+ correct_bucket = "mid"
+
+ return buckets, correct_bucket
+
+ def apply(
+ self,
+ image: Image.Image,
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+
+ counts: dict[str, int] = {}
+ for det in detections:
+ lbl = det.label
+ if isinstance(lbl, torch.Tensor):
+ for l in lbl:
+ counts[str(l)] = counts.get(str(l), 0) + 1
+ else:
+ counts[str(lbl)] = counts.get(str(lbl), 0) + 1
+
+ qa_pairs: list[tuple[str, str]] = []
+ for cls, n in counts.items():
+ if n < 4:
+ continue
+ # extract per-detection confidences for this class
+ scores: list[float] = []
+ for det in detections:
+ lbl = det.label
+ conf = getattr(det, "score", getattr(det, "confidence", 1.0))
+ if isinstance(lbl, torch.Tensor):
+ for idx in range(lbl.shape[0]):
+ if str(lbl[idx]) == cls:
+ scores.append(float(conf[idx]) if isinstance(
+ conf, torch.Tensor) else float(conf))
else:
- i += 1
- return rows
+ if str(lbl) == cls:
+ scores.append(float(conf))
- rows = find_rows(bboxes_sorted_by_x)
+ var = float(np.var(scores)) if len(scores) > 1 else 0.0
- if not rows:
- return [(self.question, "None")]
+ buckets, correct_bucket = self._bucket_ranges(n, var)
- # Collect object names per row
- row_descriptions = []
- for idx, row in enumerate(rows):
- object_names = [detections_sorted_by_x[r]._label for r in row]
- row_descriptions.append(f"Row {idx+1}: {', '.join(object_names)}")
+ # Randomly permute letter β bucket mapping to avoid letter bias
+ letters = ["A", "B", "C"]
+ random.shuffle(letters)
+ bucket_keys = ["low", "mid", "high"]
- return [(self.question, " | ".join(row_descriptions))]
+ letter_to_bucket = {letter: bucket for letter,
+ bucket in zip(letters, bucket_keys)}
+ # Build question text in A/B/C order after permutation
+ q = self.question.format(
+ object_1=cls,
+ range_a=buckets[letter_to_bucket["A"].lower()],
+ range_b=buckets[letter_to_bucket["B"].lower()],
+ range_c=buckets[letter_to_bucket["C"].lower()],
+ )
-class MostClusteredObjects(Question):
- def __init__(self, threshold=100) -> None:
+ # Identify the letter assigned to the mid bucket (the correct answer)
+ correct_letter = {bkey: ltr for ltr, bkey in letter_to_bucket.items()}[
+ "mid"]
+
+ qa_pairs.append((q, correct_letter))
+
+ return qa_pairs
+
+
+class FrontOf(Question):
+ def __init__(self, margin_ratio: float = 0.1) -> None:
+ """
+ FrontOf question using depth perception and SAM segmentation.
+
+ Args:
+ margin_ratio: Required relative depth difference for reliable comparison.
+ Objects must differ by at least this fraction of the closer object's depth.
+ """
super().__init__(
- question="What group of objects are most clustered together?",
- variables=[],
+ question="Is there at least one {object_1} in front of any {object_2}?",
+ variables=["object_1", "object_2"],
predicates=[
lambda image, detections: ObjectDetectionPredicates.at_least_x_many_class_detections(
- image, detections, 2 # Need at least 2 to form a cluster
+ image, detections, 2
),
- lambda image, detections: ObjectDetectionPredicates.has_clusters(
- image, detections, threshold=threshold
+ ObjectDetectionPredicates.exists_non_overlapping_detections,
+ ],
+ )
+ if margin_ratio <= 0 or margin_ratio >= 1:
+ raise ValueError("margin_ratio must be between 0 and 1")
+ self.margin_ratio = margin_ratio
+
+ # Initialize SAM and DepthPro models lazily
+ self._sam_predictor = None
+ self._depth_model = None
+
+ def _get_sam_predictor(self):
+ """Lazy initialization of SAM predictor."""
+ if self._sam_predictor is None:
+ from graid.utilities.sam_utils import SAMPredictor
+ self._sam_predictor = SAMPredictor()
+ return self._sam_predictor
+
+ def _get_depth_model(self):
+ """Lazy initialization of DepthPro model."""
+ if self._depth_model is None:
+ from graid.models.DepthPro import DepthPro
+ self._depth_model = DepthPro()
+ return self._depth_model
+
+ def apply(
+ self,
+ image: Image.Image,
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+ """
+ Apply FrontOf question using depth-based comparison.
+
+ @precondition: at_least_x_many_class_detections(image, detections, 2) == True
+ @precondition: exists_non_overlapping_detections(image, detections) == True
+ """
+ from graid.utilities.sam_utils import compare_object_depths
+
+ try:
+ # Get depth map for the image
+ depth_model = self._get_depth_model()
+ depth_result = depth_model.predict_depth(image)
+ depth_map = depth_result.depth_prediction
+
+ # Get SAM predictor
+ sam_predictor = self._get_sam_predictor()
+
+ # Group detections by class
+ detections_by_class = {}
+ for detection in detections:
+ class_name = detection.label
+ if isinstance(class_name, torch.Tensor):
+ # Handle tensor labels - flatten to individual detections
+ flattened = detection.flatten()
+ for flat_det in flattened:
+ cls = str(flat_det.label)
+ if cls not in detections_by_class:
+ detections_by_class[cls] = []
+ detections_by_class[cls].append(flat_det)
+ else:
+ cls = str(class_name)
+ if cls not in detections_by_class:
+ detections_by_class[cls] = []
+ detections_by_class[cls].append(detection)
+
+ # Generate question-answer pairs by comparing different object classes
+ question_answer_pairs = []
+ class_names = list(detections_by_class.keys())
+
+ for i in range(len(class_names)):
+ for j in range(i + 1, len(class_names)):
+ obj1_class = class_names[i]
+ obj2_class = class_names[j]
+
+ obj_1_detections = detections_by_class[obj1_class]
+ obj_2_detections = detections_by_class[obj2_class]
+
+ # Track evidence across all non-overlapping pairs
+ found_yes_case = False # decisive evidence that obj1 is in front of obj2 (or swapped)
+ total_pairs = 0
+ decisive_pairs = 0
+ all_opposite = True # every decisive pair showed obj2 in front of obj1
+
+ # Try all combinations of objects from the two classes
+ for obj_1_det in obj_1_detections:
+ for obj_2_det in obj_2_detections:
+ # Check if objects are non-overlapping
+ iou = ObjectDetectionUtils.pairwise_iou(obj_1_det, obj_2_det)
+ if iou.max() > 0:
+ continue # Skip overlapping objects
+ total_pairs += 1
+
+ # Get refined masks using SAM
+ mask1 = sam_predictor.get_mask_from_bbox(image, obj_1_det)
+ mask2 = sam_predictor.get_mask_from_bbox(image, obj_2_det)
+
+ if mask1 is None or mask2 is None:
+ # ambiguous; do not count as decisive
+ all_opposite = False
+ continue
+
+ # Compare depths using masks
+ comparison, depth1, depth2 = compare_object_depths(
+ depth_map, obj_1_det, mask1, obj_2_det, mask2, self.margin_ratio
+ )
+
+ if comparison is None:
+ # ambiguous; do not count as decisive
+ all_opposite = False
+ continue
+
+ decisive_pairs += 1
+ if comparison == "object1_front":
+ # obj1 is in front of obj2 → ask the direct question with Yes
+ question = self.question.format(object_1=obj1_class, object_2=obj2_class)
+ question_answer_pairs.append((question, "Yes"))
+ found_yes_case = True
+ break
+ elif comparison == "object2_front":
+ # obj2 is in front of obj1 → ask the swapped question with Yes
+ question = self.question.format(object_1=obj2_class, object_2=obj1_class)
+ question_answer_pairs.append((question, "Yes"))
+ # keep scanning to ensure no conflicting evidence for direct order
+ else:
+ all_opposite = False
+ if found_yes_case:
+ break
+
+ if found_yes_case:
+ # We already emitted a decisive Yes; do not emit a No for the opposite
+ continue
+
+ # Only emit a "No" when it is decisively obvious according to the margin:
+ # - we evaluated at least one non-overlapping pair
+ # - every decisive comparison contradicted the direct question (obj2 in front)
+ # - and there were no ambiguous comparisons
+ if total_pairs > 0 and decisive_pairs == total_pairs and all_opposite:
+ question = self.question.format(object_1=obj1_class, object_2=obj2_class)
+ question_answer_pairs.append((question, "No"))
+
+ return question_answer_pairs
+
+ except Exception as e:
+ logger.debug(f"FrontOf question failed: {e}")
+ return []
+
+
+class BehindOf(Question):
+ def __init__(self, margin_ratio: float = 0.1) -> None:
+ """
+ BehindOf question using depth perception and SAM segmentation.
+
+ Args:
+ margin_ratio: Required relative depth difference for reliable comparison.
+ """
+ super().__init__(
+ question="Is there at least one {object_1} behind any {object_2}?",
+ variables=["object_1", "object_2"],
+ predicates=[
+ lambda image, detections: ObjectDetectionPredicates.at_least_x_many_class_detections(
+ image, detections, 2
),
+ ObjectDetectionPredicates.exists_non_overlapping_detections,
],
)
- self.threshold = threshold
+ if margin_ratio <= 0 or margin_ratio >= 1:
+ raise ValueError("margin_ratio must be between 0 and 1")
+ self.margin_ratio = margin_ratio
+
+ # Initialize SAM and DepthPro models lazily
+ self._sam_predictor = None
+ self._depth_model = None
+
+ def _get_sam_predictor(self):
+ """Lazy initialization of SAM predictor."""
+ if self._sam_predictor is None:
+ from graid.utilities.sam_utils import SAMPredictor
+ self._sam_predictor = SAMPredictor()
+ return self._sam_predictor
+
+ def _get_depth_model(self):
+ """Lazy initialization of DepthPro model."""
+ if self._depth_model is None:
+ from graid.models.DepthPro import DepthPro
+ self._depth_model = DepthPro()
+ return self._depth_model
def apply(
self,
image: Image.Image,
- detections: List[ObjectDetectionResultI],
- ) -> List[Tuple[str, str]]:
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+ """
+ Apply BehindOf question using depth-based comparison.
+
+ @precondition: at_least_x_many_class_detections(image, detections, 2) == True
+ @precondition: exists_non_overlapping_detections(image, detections) == True
+ """
+ from graid.utilities.sam_utils import compare_object_depths
+
+ try:
+ # Get depth map for the image
+ depth_model = self._get_depth_model()
+ depth_result = depth_model.predict_depth(image)
+ depth_map = depth_result.depth_prediction
+
+ # Get SAM predictor
+ sam_predictor = self._get_sam_predictor()
+
+ # Group detections by class
+ detections_by_class = {}
+ for detection in detections:
+ class_name = detection.label
+ if isinstance(class_name, torch.Tensor):
+ # Handle tensor labels - flatten to individual detections
+ flattened = detection.flatten()
+ for flat_det in flattened:
+ cls = str(flat_det.label)
+ if cls not in detections_by_class:
+ detections_by_class[cls] = []
+ detections_by_class[cls].append(flat_det)
+ else:
+ cls = str(class_name)
+ if cls not in detections_by_class:
+ detections_by_class[cls] = []
+ detections_by_class[cls].append(detection)
+
+ # Generate question-answer pairs by comparing different object classes
+ question_answer_pairs = []
+ class_names = list(detections_by_class.keys())
+
+ for i in range(len(class_names)):
+ for j in range(i + 1, len(class_names)):
+ obj1_class = class_names[i]
+ obj2_class = class_names[j]
+
+ obj_1_detections = detections_by_class[obj1_class]
+ obj_2_detections = detections_by_class[obj2_class]
+
+ found_yes_case = False
+ # Try all combinations of objects from the two classes
+ for obj_1_det in obj_1_detections:
+ for obj_2_det in obj_2_detections:
+ # Check if objects are non-overlapping
+ iou = ObjectDetectionUtils.pairwise_iou(obj_1_det, obj_2_det)
+ if iou.max() > 0:
+ continue # Skip overlapping objects
+
+ # Get refined masks using SAM
+ mask1 = sam_predictor.get_mask_from_bbox(image, obj_1_det)
+ mask2 = sam_predictor.get_mask_from_bbox(image, obj_2_det)
+
+ if mask1 is None or mask2 is None:
+ continue
+
+ # Compare depths using masks
+ comparison, depth1, depth2 = compare_object_depths(
+ depth_map, obj_1_det, mask1, obj_2_det, mask2, self.margin_ratio
+ )
+
+ if comparison is None:
+ continue # Depths too similar for reliable comparison
+
+ # Generate question-answer pairs (behind means farther = higher depth)
+ if comparison == "object2_front":
+ # obj1 is behind obj2 (obj2 is in front)
+ question = self.question.format(
+ object_1=obj1_class, object_2=obj2_class
+ )
+ question_answer_pairs.append((question, "Yes"))
+ found_yes_case = True
+ break
+ elif comparison == "object1_front":
+ # obj2 is behind obj1 (obj1 is in front)
+ question = self.question.format(
+ object_1=obj2_class, object_2=obj1_class
+ )
+ question_answer_pairs.append((question, "Yes"))
+ found_yes_case = True
+ break
+ if found_yes_case:
+ break
+
+ if not found_yes_case:
+ question = self.question.format(object_1=obj1_class, object_2=obj2_class)
+ question_answer_pairs.append((question, "No"))
+ question = self.question.format(object_1=obj2_class, object_2=obj1_class)
+ question_answer_pairs.append((question, "No"))
+
+ return question_answer_pairs
+
+ except Exception as e:
+ logger.debug(f"BehindOf question failed: {e}")
+ return []
- import numpy as np
- from scipy.spatial.distance import pdist, squareform
- # Get centers of all detections
- centers = []
- for detection in detections:
- bbox = detection.as_xyxy().squeeze(0)
- x_center = (bbox[0] + bbox[2]) / 2
- y_center = (bbox[1] + bbox[3]) / 2
- centers.append((x_center, y_center))
+class DepthRanking(Question):
+ """Rank the *k* object classes that are closest to the camera.
+
+ Example question (for k=3):
+
+        "Rank the 3 kinds of objects that appear the closest to the camera in the
+        image from closest to farthest. Provide your answer as a comma-separated
+        list of object names only."
+ """
- centers = np.array(centers)
+ def __init__(self, k: int, margin_ratio: float = 0.1) -> None:
+ """Create a DepthRanking question.
+
+ Args:
+ k: number of classes to rank.
+ margin_ratio: required multiplicative margin between consecutive
+ ranked depths. For class *i* to be considered closer than class
+ *i+1*, its depth must be at most `(1 - margin_ratio)` times
+ the depth of i+1. If any consecutive pair fails this criterion, the
+ question will be skipped for that image.
+ """
+ if k <= 0:
+ raise ValueError("k must be a positive integer")
+ if not (0 < margin_ratio < 1):
+ raise ValueError("margin_ratio must be between 0 and 1")
+
+ self.k: int = k
+ self.margin_ratio: float = margin_ratio
+ super().__init__(
+ question=(
+ "Rank the {k} kinds of objects that appear the closest to the camera in the "
+ "image from closest to farthest. Provide your answer as a "
+ "comma-separated list of object names only."
+ ),
+ variables=["k"],
+ predicates=[
+ # Need at least k different classes detected
+ lambda image, detections, k=k: ObjectDetectionPredicates.at_least_x_many_class_detections(
+ image, detections, k
+ ),
+ ],
+ )
- # Compute pairwise distances
- dists = squareform(pdist(centers))
+ # Initialize SAM and DepthPro models lazily
+ self._sam_predictor = None
+ self._depth_model = None
- # Simple clustering by distance threshold (e.g., 50 pixels)
- visited = set()
- clusters = []
+ def _get_sam_predictor(self):
+ """Lazy initialization of SAM predictor."""
+ if self._sam_predictor is None:
+ from graid.utilities.sam_utils import SAMPredictor
+ self._sam_predictor = SAMPredictor()
+ return self._sam_predictor
- for i in range(len(centers)):
- if i in visited:
- continue
- cluster = [i]
- visited.add(i)
- for j in range(len(centers)):
- if j not in visited and dists[i][j] < self.threshold:
- cluster.append(j)
- visited.add(j)
- if len(cluster) >= 2:
- clusters.append(cluster)
-
- def compactness(cluster_indices):
- cluster_centers = centers[cluster_indices]
- if len(cluster_centers) < 2:
- return float("inf")
- return pdist(cluster_centers).mean()
-
- clusters.sort(key=lambda c: compactness(c))
- most_compact_cluster = clusters[0]
-
- object_names = [detections[i]._label for i in most_compact_cluster]
- return [(self.question, f"{', '.join(object_names)}")]
-
-
-ALL_QUESTIONS = [
- IsObjectCentered(),
- WidthVsHeight(),
- LargestAppearance(),
- MostAppearance(),
- LeastAppearance(),
- LeftOf(),
- RightOf(),
- LeftMost(),
- RightMost(),
- HowMany(),
- MostClusteredObjects(),
- WhichMore(),
- AreMore(),
- Quadrants(2, 2),
- Quadrants(2, 3),
- Quadrants(3, 2),
- Quadrants(3, 3),
- LeftMostWidthVsHeight(),
- RightMostWidthVsHeight(),
-]
+ def _get_depth_model(self):
+ """Lazy initialization of DepthPro model."""
+ if self._depth_model is None:
+ from graid.models.DepthPro import DepthPro
+ self._depth_model = DepthPro()
+ return self._depth_model
+
+ def apply(
+ self,
+ image: Image.Image,
+ detections: list[ObjectDetectionResultI],
+ ) -> list[tuple[str, str]]:
+ if len(detections) == 0:
+ logger.debug("No detections for DepthRanking question")
+ return []
+
+ from graid.utilities.sam_utils import extract_average_depth_from_mask
+
+ try:
+ depth_model = self._get_depth_model()
+ depth_result = depth_model.predict_depth(image)
+ depth_map = depth_result.depth_prediction
+
+ sam_predictor = self._get_sam_predictor()
+
+ # Build min-depth per class dictionary
+ class_min_depth: dict[str, float] = {}
+ for detection in detections:
+ label = detection.label
+ if isinstance(label, torch.Tensor):
+ # Iterate through tensor labels (multiple boxes per detection)
+ for idx in range(label.shape[0]):
+ cls_name = str(label[idx].item())
+
+ # Create a single detection result for this box
+ # Handle both scalar and tensor score/cls
+ if isinstance(detection.score, torch.Tensor):
+ det_score = detection.score[idx]
+ else:
+ det_score = detection.score
+
+ if isinstance(detection.cls, torch.Tensor):
+ det_cls = detection.cls[idx]
+ else:
+ det_cls = detection.cls
+
+ single_detection = ObjectDetectionResultI(
+ score=det_score,
+ cls=det_cls,
+ label=cls_name,
+ bbox=detection.as_xyxy()[idx],
+ image_hw=detection.as_ultra_box.orig_shape,
+ )
+
+ mask = sam_predictor.get_mask_from_bbox(image, single_detection)
+ if mask is not None:
+ depth = extract_average_depth_from_mask(depth_map, mask)
+ if depth is not None:
+ class_min_depth[cls_name] = min(
+ class_min_depth.get(cls_name, float('inf')), depth
+ )
+ else:
+ cls_name = str(label)
+ mask = sam_predictor.get_mask_from_bbox(image, detection)
+ if mask is not None:
+ depth = extract_average_depth_from_mask(depth_map, mask)
+ if depth is not None:
+ class_min_depth[cls_name] = min(
+ class_min_depth.get(cls_name, float('inf')), depth
+ )
+
+ if len(class_min_depth) < self.k:
+ logger.debug("Not enough unique classes with depth for DepthRanking question")
+ return []
+
+ # Sort classes by their closest instance depth
+ sorted_classes = sorted(
+ class_min_depth.items(), key=lambda item: item[1]
+ )
+
+ # Verify margin criterion among top-k depths
+ top_k = sorted_classes[: self.k]
+ for i in range(len(top_k) - 1):
+ depth_i = top_k[i][1]
+ depth_next = top_k[i + 1][1]
+ if depth_i > (1 - self.margin_ratio) * depth_next:
+ logger.debug(
+ "DepthRanking margin threshold not met between %s and %s", top_k[i][0], top_k[i+1][0]
+ )
+ return []
+
+ top_k_labels = [cls for cls, _ in top_k]
+
+ question = self.question.format(k=self.k)
+ answer = ", ".join(map(str, top_k_labels))
+ return [(question, answer)]
+
+ except Exception as e:
+ logger.debug(f"DepthRanking question failed: {e}")
+ return []
+
+
+# Dynamically discover all Question classes in this module
+import inspect
+import sys
+
+def _build_all_questions():
+ """Build ALL_QUESTIONS list by discovering all Question subclasses in this module."""
+ current_module = sys.modules[__name__]
+ question_classes = {}
+
+ # Find all classes that inherit from Question
+ for name, obj in inspect.getmembers(current_module, inspect.isclass):
+ if (issubclass(obj, Question) and
+ obj != Question and # Exclude the base class
+ hasattr(obj, 'is_applicable')): # Ensure it's a concrete question class
+ question_classes[name] = obj
+
+ return question_classes
+
+# Build the dictionary of available question classes
+ALL_QUESTION_CLASSES = _build_all_questions()
+
+# ALL_QUESTIONS is kept only so older imports keep working; it is intentionally empty now that ALL_QUESTION_CLASSES is used instead
+ALL_QUESTIONS = []
diff --git a/graid/src/graid/questions/QUESTION_ROBUSTNESS.md b/graid/src/graid/questions/QUESTION_ROBUSTNESS.md
new file mode 100644
index 0000000..afb3297
--- /dev/null
+++ b/graid/src/graid/questions/QUESTION_ROBUSTNESS.md
@@ -0,0 +1,236 @@
+### General Strategies for Robustness
+
+Before diving into specific questions, two general strategies can improve robustness across the board:
+
+1. **Confidence-Based Filtering:** Almost all detector outputs include a confidence score for each bounding box. A simple and highly effective strategy is to only consider detections above a certain confidence threshold (e.g., 0.75). This directly mitigates issues from low-confidence, often incorrect, labels.
+2. **Semantic Grouping:** Create a class hierarchy (e.g., 'car', 'bus', 'truck' -> 'vehicle'; 'cat', 'dog' -> 'animal'). Performing analysis on these parent classes makes the logic robust to misclassifications within a super-category (e.g., mistaking a 'truck' for a 'car' doesn't affect the 'vehicle' count).
+
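+A minimal sketch of how the two strategies combine, assuming detections are plain dicts with `score` and `label` keys (the `robust_counts` helper, the 0.75 threshold, and the parent-class table are illustrative, not the project's API):
+
+```python
+# Illustrative sketch: confidence filtering plus semantic grouping.
+CONF_THRESHOLD = 0.75
+PARENT_CLASS = {
+    "car": "vehicle", "bus": "vehicle", "truck": "vehicle",
+    "cat": "animal", "dog": "animal",
+}
+
+def robust_counts(detections: list[dict]) -> dict[str, int]:
+    """Count confident detections per super-category."""
+    counts: dict[str, int] = {}
+    for det in detections:
+        if det["score"] < CONF_THRESHOLD:
+            continue  # drop low-confidence boxes
+        parent = PARENT_CLASS.get(det["label"], det["label"])
+        counts[parent] = counts.get(parent, 0) + 1
+    return counts
+```
+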
+---
+
+### 1. `IsObjectCentered` ✅ IMPLEMENTED
+
+* **Question:** Asks which third of the image a single-instance object is in.
+* **Error Impact:**
+ * **Low Recall:** If one of two `person` objects is missed, the system will incorrectly assume there is only one `person` and ask this question about it.
+ * **Wrong Label:** If a `chair` is mislabeled as a `person`, the question might be asked about the position of the `chair` under the `person` label, leading to a factually incorrect Q&A pair about the scene.
+
+* **IMPLEMENTED SOLUTIONS:**
+  1. ✅ **Ambiguity Buffer:** Added configurable `buffer_ratio` (default 5% of image width) that creates no-ask zones around the one-third and two-third lines. Questions are skipped if any edge of the bounding box falls within these buffer zones.
+  2. ✅ **Multiple Choice Format:** Converted to multiple choice with clear instructions: "Divide the image into thirds. In which third does the {object_1} primarily appear? Respond with the letter only: A) left third, B) middle third, C) right third."
+  3. ✅ **Letter-Only Answers:** Answers are now just "A", "B", or "C", eliminating ambiguity in response format.
+
+---
+
+### 2. `WidthVsHeight` ✅ IMPLEMENTED
+
+* **Question:** Asks if a single-instance object is wider than it is tall (or vice-versa).
+* **Error Impact:**
+ * **Low Recall:** Same as `IsObjectCentered`, can lead to mistakenly identifying an object as a single instance.
+ * **Wrong Label:** The question would be about the aspect ratio of a mislabeled object. The geometric answer might be correct for the box, but semantically wrong for the scene.
+
+* **IMPLEMENTED SOLUTIONS:**
+  1. ✅ **Increased Aspect Ratio Threshold:** Changed from 0.3 to 0.75, so questions are only asked when width is at least 1.75x height (or vice versa), making it robust to minor bounding box noise.
+  2. ✅ **Non-Articulated Classes Filter:** Added `non_articulated_classes` parameter that restricts questions to objects with fixed aspect ratios (e.g., cars, chairs, tables). Excludes pose-variant objects like people and animals.
+  3. ✅ **Clear Yes/No Answers:** Maintained simple "yes"/"no" format for consistency with other questions.
+
+---
+
+### 3. `Quadrants` ✅ IMPLEMENTED
+
+* **Question:** Asks which quadrant of an N x M grid a single-instance object is in.
+* **Error Impact:**
+ * **Low Recall / Wrong Label:** The same impact as for `IsObjectCentered`.
+
+* **IMPLEMENTED SOLUTIONS:**
+  1. ✅ **Margin-Based Filtering:** Added configurable `margin_ratio` (default 10% of quadrant size) that requires bounding boxes to be fully contained within quadrants with safety margins. Prevents questions about objects near quadrant boundaries.
+  2. ✅ **Size-Based Filtering:** Only asks questions if the object is small enough to fit within a quadrant with margins, avoiding ambiguous cases with large objects.
+  3. ✅ **Numeric Quadrant Answers:** Maintains clear numeric answers (1, 2, 3, etc.) based on left-to-right, top-to-bottom numbering.
+
+---
+
+### 4. `LargestAppearance` + NEW: `RankLargestK` ✅ IMPLEMENTED
+
+* **Question:** `LargestAppearance` asks which object class appears largest; `RankLargestK` ranks the top K classes by largest instance.
+* **Error Impact:**
+ * **Low Recall:** Highly sensitive. If the true largest object is not detected, the answer is guaranteed to be wrong.
+ * **Wrong Label:** If the largest object is mislabeled (e.g., a `bus` is labeled as a `car`), the answer will be the incorrect class.
+
+* **IMPLEMENTED SOLUTIONS:**
+  1. ✅ **New RankLargestK Question:** Created new question class that ranks top K object classes by their largest single instance. Takes `k` parameter and `margin_ratio` for robust ranking.
+  2. ✅ **Margin-Based Ranking:** RankLargestK requires significant area differences between consecutive ranks (default 30% margin) to ensure robust ordering against detection noise (see the sketch below).
+  3. ✅ **Comma-Separated List Format:** Answer format is "car, bus, person", eliminating linguistic complexity around ordinals like "first", "second", etc.
+  4. ✅ **Largest-Per-Class Logic:** Both questions now compare the largest instance of each class rather than all individual detections.
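+
+A minimal sketch of the margin check behind RankLargestK, assuming `areas` holds the largest-instance area per class, already sorted in descending order (the helper name is illustrative; only the 30% default comes from the description above):
+
+```python
+def ranking_is_robust(areas: list[float], margin_ratio: float = 0.3) -> bool:
+    """Accept a ranking only if each area beats the next by at least margin_ratio."""
+    return all(a >= (1 + margin_ratio) * b for a, b in zip(areas, areas[1:]))
+```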
+
+---
+
+### 5. `MostAppearance` & 6. `LeastAppearance` ✅ IMPLEMENTED
+
+* **Question:** Asks which object class appears most or least frequently.
+* **Error Impact:**
+ * **Low Recall:** Very sensitive. Missing a few instances of a class can easily change its rank from most to least frequent.
+ * **Wrong Label:** Very sensitive. A single mislabeling of a `car` to a `truck` affects two counts simultaneously.
+
+* **IMPLEMENTED SOLUTIONS:**
+  1. ✅ **Margin-Based Count Filtering:** Added configurable `margin_ratio` (default 20%) that requires the winning class count to exceed the runner-up by the specified margin. For MostAppearance: `top_count > (1 + margin) * second_count`.
+  2. ✅ **Robust Count Comparison:** Questions are only asked when there's a clear winner, providing a robustness buffer against small counting errors.
+  3. ✅ **Consistent Answer Format:** Maintains simple class name answers for easy parsing and evaluation.
+
+---
+
+### 7. `LeftOf` & 8. `RightOf`
+
+* **Question:** Asks if an instance of `{object_1}` is to the left/right of an instance of `{object_2}`.
+* **Error Impact:**
+ * **Low Recall:** Can cause false negatives. If the only `person` to the left of a `tree` is missed, the answer will be incorrectly "No".
+ * **Wrong Label:** Can cause false positives. If a `lamppost` to the left of a `tree` is mislabeled as a `person`, a factually incorrect Q&A will be generated.
+
+* **Proposed Solutions:**
+ 1. **Require Unambiguous Separation:** Strengthen the condition. Instead of just one instance being to the left of another, require that *all* instances of `{object_1}` are to the left of *all* instances of `{object_2}`. This can be checked by verifying `max_x(all_obj1) < min_x(all_obj2)`. This asks the question only in the clearest, most unambiguous scenarios.
+ 2. **Aggregate Position:** Base the decision on the average position (centroid) of each class. Ask the question only if the centroid of all `{object_1}` instances is significantly to the left of the centroid of all `{object_2}` instances. This is robust to one or two outliers.
+ 3. **Answer with "Sometimes":** If the condition is not absolute (i.e., some `obj1` are left of `obj2`, but some are not), introduce a "Sometimes" or "In some cases" answer. This more accurately reflects complex scenes.
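+
+A minimal sketch of the unambiguous-separation check from solution 1 above, assuming xyxy boxes given as `(x_min, y_min, x_max, y_max)` tuples:
+
+```python
+def all_left_of(obj1_boxes: list[tuple], obj2_boxes: list[tuple]) -> bool:
+    """True only when every object_1 box lies strictly left of every object_2 box."""
+    return max(b[2] for b in obj1_boxes) < min(b[0] for b in obj2_boxes)
+```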
+
+---
+
+### 9. `LeftMost` & 10. `RightMost`
+
+* **Question:** Asks for the class label of the leftmost/rightmost object.
+* **Error Impact:**
+ * **Low Recall:** Highly sensitive. If the true leftmost object is not detected, the question is being answered about the wrong object, and the answer is guaranteed to be incorrect.
+ * **Wrong Label:** The system correctly identifies the leftmost box, but gives it the wrong label.
+
+* **Proposed Solutions:**
+ 1. **Class-Agnostic Questioning:** Rephrase the question to be about the *properties* of the leftmost object, sidestepping the label issue. For example, "Does the leftmost object appear wider than it is tall?". This is the approach taken by `LeftMostWidthVsHeight` and is very robust to label error.
+ 2. **Check for Ambiguity:** Before asking, check if the second-leftmost object is very close to the leftmost one. If their positions are nearly identical, the title of "leftmost" is ambiguous and sensitive to error. In this case, either avoid the question or mention both objects in the answer.
+ 3. **"Set-of-Mark" Verification:** As the code comments suggest, this is a prime candidate for Set-of-Mark prompting. Generate the image with all detected boxes drawn on it. Feed this to a VQA model and ask, "What is the label of the object in the leftmost box?". The VQA may be able to correct the detector's label error.
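+
+A minimal sketch of the ambiguity check from solution 2 above; the 5% gap ratio is an assumed default, not a value taken from the codebase:
+
+```python
+def leftmost_is_unambiguous(x_mins: list[float], image_width: float,
+                            min_gap_ratio: float = 0.05) -> bool:
+    """Ask a LeftMost question only when the leftmost box clearly leads the runner-up."""
+    ordered = sorted(x_mins)
+    return len(ordered) >= 2 and (ordered[1] - ordered[0]) >= min_gap_ratio * image_width
+```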
+
+---
+
+### 11. `HowMany` → REPLACED WITH 3 NEW QUESTIONS ✅ IMPLEMENTED
+
+* **Original Question:** Asked for exact count of a specific object class.
+* **Error Impact:** Direct report of detector output, highly sensitive to both recall and precision errors.
+
+* **IMPLEMENTED REPLACEMENTS:**
+
+#### 11a. `MoreThanThresholdHowMany` ✅ IMPLEMENTED
+* **Question:** "Are there more than {target} {object_1}(s) in this image? Respond Yes/No."
+* **Robustness:** Uses multiplicative thresholds to create buffer zones (see the sketch after 11b below). For detected count N, generates two questions:
+  - Yes case: target = ⌊N / threshold⌋ (answer: "Yes")
+  - No case: target = ⌈N × threshold⌉ (answer: "No")
+* **Benefits:** Balanced Yes/No distribution, tolerant to counting errors
+
+#### 11b. `LessThanThresholdHowMany` ✅ IMPLEMENTED
+* **Question:** "Are there less than {target} {object_1}(s) in this image? Respond Yes/No."
+* **Robustness:** Symmetric logic to MoreThanThresholdHowMany with inverse thresholds
+* **Benefits:** Provides complementary threshold-based questions for balanced evaluation
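+
+A minimal sketch of the target construction shared by 11a and 11b for a detected count N; the 1.5 threshold is an assumed value, not necessarily the project default:
+
+```python
+import math
+
+def threshold_targets(n_detected: int, threshold: float = 1.5) -> tuple[int, int]:
+    yes_target = math.floor(n_detected / threshold)  # "more than yes_target?" -> "Yes"
+    no_target = math.ceil(n_detected * threshold)    # "more than no_target?"  -> "No"
+    return yes_target, no_target
+
+# e.g. threshold_targets(6) == (4, 9): 6 > 4 is "Yes", 6 > 9 is "No"
+```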
+
+#### 11c. `MultiChoiceHowMany` ✅ IMPLEMENTED
+* **Question:** "How many {object_1}(s) are in the image? Choose one: A) {range_a}, B) {range_b}, C) {range_c}, D) Unsure / Not Visible."
+* **Robustness Features:**
+ - Contiguous range buckets based on detected count (low/mid/high)
+ - Confidence variance adaptation: wider ranges for uncertain detections
+ - Random A/B/C shuffling to prevent positional bias
+ - Detector count always falls in middle bucket by design
+* **Benefits:** Tolerates ±1-2 counting errors while maintaining clear boundaries (sketched below)
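+
+A minimal sketch of the bucket construction; the `width` parameter controls the tolerance and is an assumption rather than the exact implementation:
+
+```python
+import random
+
+def build_count_buckets(n_detected: int, width: int = 2) -> list[str]:
+    mid_lo, mid_hi = max(1, n_detected - width), n_detected + width
+    buckets = [
+        f"0-{mid_lo - 1}",                     # low range
+        f"{mid_lo}-{mid_hi}",                  # middle range (always contains the detector count)
+        f"{mid_hi + 1}-{mid_hi + 2 * width}",  # high range
+    ]
+    random.shuffle(buckets)  # randomize A/B/C assignment to avoid positional bias
+    return buckets
+```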
+
+---
+
+### 12. `AreMore`, 13. `WhichMore` ⚠️ NOT YET IMPLEMENTED
+
+* **Analysis:** These questions are comparative versions of `HowMany` and suffer from the same sensitivities. The solutions for `MostAppearance` are directly applicable here. The most important is to **require a significant difference in counts** before asking the question to ensure the comparison is robust to minor counting errors.
+
+* **PLANNED SOLUTIONS:**
+ 1. **Margin-Based Count Filtering:** Apply same margin logic as MostAppearance/LeastAppearance
+ 2. **Significant Difference Requirement:** Only ask when count differences exceed threshold to avoid tie-breaking scenarios
+
+---
+
+### 14. `LeftMostWidthVsHeight`, 15. `RightMostWidthVsHeight` ⚠️ NOT YET IMPLEMENTED
+
+* **Analysis:** These are excellent examples of robust question generation. By making the question class-agnostic ("Does the leftmost object..."), they are already immune to **Wrong Label** errors. The primary weakness is **Low Recall** (if the true leftmost object is missed).
+* **PLANNED SOLUTIONS:**
+ 1. **Confirm Subject Identity in Answer:** The question can be class-agnostic, but the answer can reveal the label. Q: "Does the leftmost object appear to be wider than it is tall?" A: "Yes. The object, identified as a car, is wider than it is tall." This makes any label error transparent.
+ 2. **Ensure Spatial Stability:** Before asking, confirm the identified leftmost object is significantly farther to the left than the next contender. This prevents small box errors from changing the subject of the question.
+
+---
+
+### 16. `ObjectsInRow`, 17. `ObjectsInLine`, 18. `MostClusteredObjects` ✅ IMPLEMENTED
+
+* **Analysis:** These questions rely on the spatial relationships between multiple objects.
+* **Error Impact:**
+ * **Low Recall:** Missing an object can break a row or cluster. Conversely, missing objects that *aren't* in a row can create the illusion of one among the remaining detections.
+ * **Wrong Label:** This primarily affects `ObjectsInLine` and `MostClusteredObjects`, which report the labels of the grouped objects.
+
+* **IMPLEMENTED SOLUTIONS:**
+
+#### 16. `ObjectsInRow` ✅ IMPLEMENTED
+* **Linear Regression Approach:** Replaced y-overlap heuristic with linear regression on y-centers. Uses normalized variance threshold (default 0.1 of image height²) for row detection (see the sketch below).
+* **Sliding Window Analysis:** Tests multiple window sizes and positions to find the best linear fit.
+* **Robust Yes/No Answers:** Simple binary response format avoiding ambiguous spatial descriptions.
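+
+A minimal sketch of the regression test (the sliding-window search over subsets of detections is omitted):
+
+```python
+import numpy as np
+
+def looks_like_row(x_centers: np.ndarray, y_centers: np.ndarray,
+                   image_height: float, var_ratio: float = 0.1) -> bool:
+    """Fit y = a*x + b to box centers and accept a row when residual variance is small."""
+    if len(x_centers) < 3:
+        return False
+    slope, intercept = np.polyfit(x_centers, y_centers, deg=1)
+    residuals = y_centers - (slope * x_centers + intercept)
+    return float(residuals.var()) < var_ratio * image_height ** 2
+```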
+
+#### 17. `ObjectsInLine` ✅ IMPLEMENTED
+* **Multiple Choice Format:** "Which objects appear to be arranged in a row? A) {option_a}, B) {option_b}, C) {option_c}, D) No clear row arrangement."
+* **Same Linear Regression Logic:** Uses identical statistical approach as ObjectsInRow for consistency.
+* **Distractor Deduplication:** Implements retry logic (up to 10 attempts) to ensure distractors are unique from correct answer and each other. Skips question if duplicates persist.
+* **Random Shuffling:** Randomizes A/B/C assignment to prevent positional bias.
+
+#### 18. `MostClusteredObjects` ✅ IMPLEMENTED
+* **DBSCAN Clustering:** Replaced distance-based clustering with DBSCAN algorithm for robust cluster detection.
+* **Image-Relative Parameters:** Uses `eps_ratio` (default 5% of image diagonal) instead of fixed pixel distances. Automatically scales for different image sizes.
+* **Increased Requirements:** Now requires ≥9 detections (3 clusters × 3 objects) and `min_samples=3` for meaningful clusters (see the clustering sketch below).
+* **Multiple Choice Format:** Same A/B/C structure as ObjectsInLine with distractor deduplication.
+* **Cluster Quality Control:** Requires ≥2 clusters for comparative evaluation and finds the most compact cluster using variance-based scoring.
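+
+A minimal sketch of the clustering step, assuming `centers` is an (N, 2) array of box centers:
+
+```python
+import numpy as np
+from sklearn.cluster import DBSCAN
+
+def cluster_box_centers(centers: np.ndarray, image_wh: tuple[int, int],
+                        eps_ratio: float = 0.05, min_samples: int = 3) -> np.ndarray:
+    """Cluster with eps scaled to a fraction of the image diagonal; label -1 marks noise."""
+    diagonal = float(np.hypot(*image_wh))
+    return DBSCAN(eps=eps_ratio * diagonal, min_samples=min_samples).fit_predict(centers)
+```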
+
+---
+
+## IMPLEMENTATION PROGRESS SUMMARY
+
+### ✅ COMPLETED (11/18 questions)
+1. **IsObjectCentered** - Buffer zones, multiple choice A/B/C format
+2. **WidthVsHeight** - Increased threshold (0.75), non-articulated classes filter
+3. **Quadrants** - Margin-based filtering (10% quadrant size)
+4. **RankLargestK** - NEW: Ranks top-K classes with margin requirements
+5. **MostAppearance** - Margin-based count filtering (20%)
+6. **LeastAppearance** - Margin-based count filtering (20%)
+11. **MoreThanThresholdHowMany** - NEW: Threshold-based Yes/No questions
+11. **LessThanThresholdHowMany** - NEW: Inverse threshold questions
+11. **MultiChoiceHowMany** - NEW: 3-way multiple choice with ranges
+16. **ObjectsInRow** - Linear regression on y-centers
+17. **ObjectsInLine** - Multiple choice with distractor deduplication
+18. **MostClusteredObjects** - DBSCAN clustering, image-relative parameters
+
+### ⚠️ PENDING IMPLEMENTATION (7/18 questions)
+7. **LeftOf** - No changes planned (already robust)
+8. **RightOf** - No changes planned (already robust)
+9. **LeftMost** - No changes planned (already robust)
+10. **RightMost** - No changes planned (already robust)
+12. **AreMore** - Needs margin-based filtering
+13. **WhichMore** - Needs margin-based filtering
+14. **LeftMostWidthVsHeight** - Needs spatial stability checks
+15. **RightMostWidthVsHeight** - Needs spatial stability checks
+
+### 🎯 KEY ROBUSTNESS IMPROVEMENTS ACHIEVED
+- **Buffer Zones:** Prevent questions near spatial boundaries
+- **Margin Requirements:** Ensure significant differences before asking questions
+- **Multiple Choice:** Reduce answer ambiguity with A/B/C formats
+- **Statistical Methods:** Linear regression for rows, DBSCAN for clusters
+- **Image-Relative Parameters:** Scale with image size instead of fixed pixels
+- **Distractor Deduplication:** Prevent identical multiple choice options
+- **Threshold-Based Questions:** Replace exact counts with range-tolerant questions
+
+---
+
+### Appendix: Region-of-Suspicion (RoS) Verification Rules
+
+**A. LeftMost / RightMost (+ WidthVsHeight variants)**
+* Rival set = any class in our global vocabulary EXCEPT the incumbent leftmost/rightmost class.
+* A single positive label in the RoS → reject the question (a missed object lies further out).
+
+**B. LeftOf / RightOf**
+* RoS = horizontal band spanning the entire height between x_max(obj₁) and x_min(obj₂).
+* Rival set = object₁'s class for LeftOf (and vice-versa for RightOf).
+
+**C. Quadrants / IsObjectCentered**
+* Verify only if the bbox is within δ px of a grid boundary.
+* RoS = margin zone on the opposite side of the claimed quadrant.
+
+**D. ObjectsInRow / ObjectsInLine**
+* Rival set = the same class labels already in the row.
+* If SAM finds another instance of those labels within the row stripe but outside the existing boxes → reject the "row" claim.
+
+**E. MostClusteredObjects**
+* Check whether additional objects of the same candidate class exist inside the cluster centroid radius.
+* If yes → they may reinforce the cluster → keep it; if an object of another class appears → the question remains valid.
+* Accept/reject based on whether the top-cluster ranking would change.
\ No newline at end of file
diff --git a/graid/src/graid/setup.py b/graid/src/graid/setup.py
index fa5a8fa..07ea2f0 100644
--- a/graid/src/graid/setup.py
+++ b/graid/src/graid/setup.py
@@ -90,17 +90,17 @@ def install_detectron2() -> None:
if platform.system() == "Darwin":
subprocess.run(
[
- 'CC=clang CXX=clang++ ARCHFLAGS="-arch x86_64" uv python',
- "-m",
+ 'CC=clang CXX=clang++ ARCHFLAGS="-arch x86_64" uv',
"pip",
"install",
"--no-build-isolation",
"-e",
".",
]
)
else:
- subprocess.run(["uv", "python", "-m", "pip", "install", "--no-build-isolation", "-e", "."])
+        subprocess.run(["uv", "pip", "install", "--no-build-isolation", "-e", "."])
# Change back to the original directory
os.chdir("..")
diff --git a/graid/src/graid/utilities/coco.py b/graid/src/graid/utilities/coco.py
index f8967fc..07986a2 100644
--- a/graid/src/graid/utilities/coco.py
+++ b/graid/src/graid/utilities/coco.py
@@ -1,6 +1,8 @@
# Official COCO Panoptic Categories
# Source: https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json
+from typing import Any, Optional
+
# Standard COCO Detection classes (80 classes, 0-79)
coco_labels = {
-1: "undefined",
@@ -230,3 +232,79 @@
# For backward compatibility, add undefined label
coco_labels[-1] = "undefined"
inverse_coco_label["undefined"] = -1
+
+
+def validate_coco_objects(objects: list[str]) -> tuple[bool, Optional[str]]:
+ """
+ Validate that all objects in the list are valid COCO object names.
+
+ Args:
+ objects: List of object names to validate
+
+ Returns:
+ Tuple of (is_valid, error_message)
+ """
+ if not objects:
+ return True, None
+
+ valid_coco_objects = set(coco_labels.values())
+ # Remove 'undefined' from valid objects as it's not a real COCO class
+ valid_coco_objects.discard("undefined")
+
+ invalid_objects = []
+ for obj in objects:
+ if obj not in valid_coco_objects:
+ invalid_objects.append(obj)
+
+ if invalid_objects:
+ return (
+ False,
+ f"Invalid COCO objects: {invalid_objects}. Valid objects: {sorted(valid_coco_objects)}",
+ )
+
+ return True, None
+
+
+def get_valid_coco_objects() -> list[str]:
+ """
+ Get a list of valid COCO object names.
+
+ Returns:
+ List of valid COCO object names
+ """
+ valid_objects = list(coco_labels.values())
+ # Remove 'undefined' as it's not a real COCO class
+ if "undefined" in valid_objects:
+ valid_objects.remove("undefined")
+ return sorted(valid_objects)
+
+
+def filter_detections_by_allowable_set(
+ detections: list[dict[str, Any]], allowable_set: Optional[list[str]]
+) -> list[dict[str, Any]]:
+ """
+ Filter detections to only include objects in the allowable set.
+
+ Args:
+ detections: List of detection dictionaries with 'class' or 'label' field
+ allowable_set: List of allowed COCO object names, or None to allow all
+
+ Returns:
+ Filtered list of detections
+ """
+ if not allowable_set:
+ return detections
+
+ allowable_set_normalized = set(allowable_set)
+ filtered_detections = []
+
+ for detection in detections:
+ # Handle different possible keys for class name
+ class_name = (
+ detection.get("class") or detection.get("label") or detection.get("name")
+ )
+
+ if class_name and class_name in allowable_set_normalized:
+ filtered_detections.append(detection)
+
+ return filtered_detections
diff --git a/graid/src/graid/utilities/sam_utils.py b/graid/src/graid/utilities/sam_utils.py
new file mode 100644
index 0000000..4c460de
--- /dev/null
+++ b/graid/src/graid/utilities/sam_utils.py
@@ -0,0 +1,257 @@
+"""
+SAM (Segment Anything Model) utilities for object mask refinement.
+"""
+from enum import Enum
+from typing import Optional, Tuple, List
+
+import torch
+import numpy as np
+from PIL import Image
+from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
+
+from graid.interfaces.ObjectDetectionI import ObjectDetectionResultI
+from graid.utilities.common import get_default_device, project_root_dir
+
+_sam_model = None
+
+
+def get_sam_model(**kwargs) -> torch.nn.Module:
+ """
+ Get a SAM model instance, loading it if necessary.
+
+ This function ensures that the SAM model is loaded only once.
+
+ Args:
+ model_path: Path to SAM checkpoint (default: checkpoints/sam_vit_h_4b8939.pth)
+ device: Device to use for inference
+ model_type: SAM model type (default: vit_h)
+
+ Returns:
+ The loaded SAM model.
+ """
+ global _sam_model
+ if _sam_model is None:
+ model_path = kwargs.get(
+ "model_path", project_root_dir() / "checkpoints" / "sam_vit_h_4b8939.pth"
+ )
+
+ if not model_path.exists():
+ raise FileNotFoundError(
+ f"SAM checkpoint not found at {model_path}. "
+ "Please download the SAM checkpoint following the project's setup instructions."
+ )
+
+ device = kwargs.get("device", get_default_device())
+ model_type = kwargs.get("model_type", "vit_h")
+
+ _sam_model = sam_model_registry[model_type](checkpoint=str(model_path))
+ _sam_model.to(device=device)
+
+ return _sam_model
+
+
+class SAMMaskReturnType(Enum):
+ LARGEST_AREA = "largest_area"
+ HIGHEST_CONFIDENCE = "highest_confidence"
+
+
+class SAMPredictor:
+ """
+ Wrapper around SAM for getting refined object masks from bounding boxes.
+ """
+
+ def __init__(self, **kwargs):
+ """
+ Initialize SAM predictor.
+
+ Args:
+ model_path: Path to SAM checkpoint (default: checkpoints/sam_vit_h_4b8939.pth)
+ device: Device to use for inference
+ model_type: SAM model type (default: vit_h)
+ """
+ sam = get_sam_model(**kwargs)
+ self.predictor = SamPredictor(sam)
+
+ def get_mask_from_bbox(
+ self,
+ image: Image.Image,
+ detection: ObjectDetectionResultI,
+ return_type: SAMMaskReturnType = SAMMaskReturnType.LARGEST_AREA,
+ ) -> Optional[torch.Tensor]:
+ """
+ Get refined mask for an object using its bounding box as prompt.
+
+ Args:
+ image: PIL Image
+ detection: ObjectDetectionResultI containing the bounding box
+ return_type: Method to select the best mask from multiple predictions
+
+ Returns:
+ Binary mask as torch.Tensor of shape (H, W), or None if no valid mask
+ """
+ # Convert PIL Image to numpy array (RGB)
+ image_array = np.array(image)
+
+ # Set the image for SAM predictor
+ self.predictor.set_image(image_array)
+
+ # Get bounding box in XYXY format
+ bbox_xyxy = detection.as_xyxy().squeeze().cpu().numpy() # Shape: (4,)
+
+ # SAM expects bbox in [x_min, y_min, x_max, y_max] format
+ input_box = bbox_xyxy
+
+ # Predict masks using the bounding box as prompt
+ masks, scores, logits = self.predictor.predict(
+ point_coords=None,
+ point_labels=None,
+ box=input_box[None, :], # Add batch dimension
+ multimask_output=True,
+ )
+
+ if len(masks) == 0:
+ return None
+
+ if return_type == SAMMaskReturnType.LARGEST_AREA:
+ mask_areas = np.sum(masks, axis=(1, 2))
+ best_idx = np.argmax(mask_areas)
+ best_mask = masks[best_idx]
+ elif return_type == SAMMaskReturnType.HIGHEST_CONFIDENCE:
+ best_idx = np.argmax(scores)
+ best_mask = masks[best_idx]
+ else:
+ raise ValueError(f"Invalid return_type: {return_type}")
+
+ # Convert to torch tensor
+ return torch.from_numpy(best_mask).bool()
+
+ def get_masks_from_detections(
+ self,
+ image: Image.Image,
+ detections: List[ObjectDetectionResultI],
+ return_type: SAMMaskReturnType = SAMMaskReturnType.LARGEST_AREA,
+ ) -> List[Tuple[ObjectDetectionResultI, Optional[torch.Tensor]]]:
+ """
+ Get refined masks for multiple detections.
+
+ Args:
+ image: PIL Image
+ detections: List of ObjectDetectionResultI objects
+ return_type: Method to select the best mask
+
+ Returns:
+ List of (detection, mask) tuples where mask can be None if prediction failed
+ """
+ results = []
+
+ # Set image once for all predictions
+        # Pre-set the image; note that get_mask_from_bbox also sets it per detection
+ self.predictor.set_image(image_array)
+
+ for detection in detections:
+ mask = self.get_mask_from_bbox(
+ image, detection, return_type=return_type
+ )
+ results.append((detection, mask))
+
+ return results
+
+ def to(self, device: torch.device) -> "SAMPredictor":
+ """Move model to specified device."""
+ self.device = device
+ self.predictor.model.to(device)
+ return self
+
+
+class SAMMaskGenerator:
+ """
+ Wrapper around SAM's SamAutomaticMaskGenerator.
+ """
+
+ def __init__(self, **kwargs):
+ """
+ Initialize SAM mask generator.
+
+ Args:
+ **kwargs: Arguments for SamAutomaticMaskGenerator and get_sam_model.
+ """
+        sam = get_sam_model(**kwargs)
+        # Drop model-loading kwargs; SamAutomaticMaskGenerator does not accept them
+        gen_kwargs = {k: v for k, v in kwargs.items() if k not in ("model_path", "device", "model_type")}
+        self.mask_generator = SamAutomaticMaskGenerator(sam, **gen_kwargs)
+
+ def generate(self, image: np.ndarray) -> List[dict]:
+ """
+ Generate masks for the entire image.
+
+ Args:
+ image: Image as a numpy array in RGB format.
+
+ Returns:
+ A list of masks, where each mask is a dictionary containing segmentation data.
+ """
+ return self.mask_generator.generate(image)
+
+
+def extract_average_depth_from_mask(
+ depth_map: torch.Tensor, mask: torch.Tensor
+) -> Optional[float]:
+ """
+ Extract average depth value from pixels covered by the mask.
+
+ Args:
+ depth_map: Depth map tensor of shape (H, W) with depth values in meters
+ mask: Binary mask tensor of shape (H, W)
+
+ Returns:
+ Average depth in meters, or None if mask is empty
+ """
+ if mask.sum() == 0:
+ return None
+
+ # Apply mask to depth map and compute average
+ masked_depths = depth_map[mask]
+ return float(masked_depths.mean().item())
+
+
+def compare_object_depths(
+ depth_map: torch.Tensor,
+ detection1: ObjectDetectionResultI,
+ mask1: torch.Tensor,
+ detection2: ObjectDetectionResultI,
+ mask2: torch.Tensor,
+ margin_ratio: float = 0.1,
+) -> Tuple[Optional[str], float, float]:
+ """
+ Compare relative depths of two objects using their masks.
+
+ Args:
+ depth_map: Depth map tensor (H, W) with values in meters
+ detection1: First object detection
+ mask1: Mask for first object
+ detection2: Second object detection
+ mask2: Mask for second object
+ margin_ratio: Required margin for reliable comparison
+
+ Returns:
+ Tuple of:
+ - Comparison result: "object1_front", "object2_front", or None if too close
+ - Average depth of object1
+ - Average depth of object2
+ """
+ avg_depth1 = extract_average_depth_from_mask(depth_map, mask1)
+ avg_depth2 = extract_average_depth_from_mask(depth_map, mask2)
+
+ if avg_depth1 is None or avg_depth2 is None:
+ return None, avg_depth1 or 0.0, avg_depth2 or 0.0
+
+ # Smaller depth values mean closer to camera (in front)
+ depth_diff = abs(avg_depth1 - avg_depth2)
+ min_depth = min(avg_depth1, avg_depth2)
+
+ # Check if difference is significant relative to distance
+ if depth_diff < margin_ratio * min_depth:
+ return None, avg_depth1, avg_depth2
+
+ if avg_depth1 < avg_depth2:
+ return "object1_front", avg_depth1, avg_depth2
+ else:
+ return "object2_front", avg_depth1, avg_depth2
\ No newline at end of file
diff --git a/graid/src/graid/utils/profiling.py b/graid/src/graid/utils/profiling.py
new file mode 100644
index 0000000..0bffde9
--- /dev/null
+++ b/graid/src/graid/utils/profiling.py
@@ -0,0 +1,127 @@
+"""
+Profiling utilities for GRAID question generation statistics.
+
+This module provides reusable functions for logging and formatting
+profiling statistics from question generation processes.
+"""
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def log_profiling_statistics(question_stats, title="Question Processing Statistics"):
+ """
+ Log profiling statistics to console using consistent formatting.
+
+ Args:
+ question_stats: Dictionary containing detailed profiling statistics
+ title: Title for the statistics table
+ """
+ if not question_stats or 'detailed_stats' not in question_stats:
+ return
+
+    logger.info(f"📊 {title}:")
+ logger.info("=" * 95)
+
+ # Table header
+ logger.info(f"{'Question Type':<40} {'is_app(ms)':<12} {'apply(ms)':<12} {'Hit Rate':<10} {'Failed':<8} {'Total QA':<10}")
+ logger.info("-" * 95)
+
+ # Table rows
+ for qtype, stats in question_stats['detailed_stats'].items():
+ # Calculate averages
+ is_app_time, is_app_count = stats.get("is_applicable_time", (0, 1))
+ apply_time, apply_count = stats.get("apply_time", (0, 1))
+ is_app_avg = (is_app_time / max(is_app_count, 1)) * 1000
+ apply_avg = (apply_time / max(apply_count, 1)) * 1000
+
+ # Calculate success metrics
+ is_applicable_count = stats.get("is_applicable_true_count", 0)
+ empty_results = stats.get("apply_empty_results", 0)
+ total_qa_generated = stats.get("total_qa_generated", 0)
+ successful_cases = is_applicable_count - empty_results
+ hit_rate = (successful_cases / max(is_applicable_count, 1)) * 100 if is_applicable_count > 0 else 0
+
+ question_text = stats.get("question_text", qtype)
+ # Truncate for console alignment
+ question_text_short = question_text[:39]
+ logger.info(f"{question_text_short:<40} {is_app_avg:<12.2f} {apply_avg:<12.2f} {hit_rate:<10.1f}% {empty_results:<8} {total_qa_generated:<10}")
+
+ logger.info("=" * 95)
+    logger.info("Notes: Hit Rate = % of applicable cases that generated ≥1 QA pair")
+ logger.info(" Failed = cases where is_applicable=True but apply returned no QA pairs")
+
+
+def format_profiling_table(question_stats, format_type="markdown"):
+ """
+ Format profiling statistics table in either markdown or console format.
+
+ Args:
+ question_stats: Dictionary containing detailed profiling statistics
+ format_type: "markdown" for README tables, "console" for logging
+
+ Returns:
+ Formatted table as string
+ """
+ if not question_stats or 'detailed_stats' not in question_stats:
+ return ""
+
+ if format_type == "markdown":
+ # Markdown table format for README
+ table_lines = [
+ "| Question Type | is_applicable Avg (ms) | apply Avg (ms) | Predicate -> QA Hit Rate | Empty cases |",
+ "|---------------|------------------------|----------------|--------------------------|-------------|"
+ ]
+ else:
+ # Console table format for logging
+ table_lines = [
+ f"{'Question Type':<40} {'is_app(ms)':<12} {'apply(ms)':<12} {'Hit Rate':<10} {'Failed':<8} {'Total QA':<10}",
+ "-" * 95
+ ]
+
+ for qtype, stats in question_stats['detailed_stats'].items():
+ # Calculate averages
+ is_app_time, is_app_count = stats.get("is_applicable_time", (0, 1))
+ apply_time, apply_count = stats.get("apply_time", (0, 1))
+ is_app_avg = (is_app_time / max(is_app_count, 1)) * 1000
+ apply_avg = (apply_time / max(apply_count, 1)) * 1000
+
+ # Calculate success metrics
+ is_applicable_count = stats.get("is_applicable_true_count", 0)
+ empty_results = stats.get("apply_empty_results", 0)
+ total_qa_generated = stats.get("total_qa_generated", 0)
+ successful_cases = is_applicable_count - empty_results
+ hit_rate = (successful_cases / max(is_applicable_count, 1)) * 100 if is_applicable_count > 0 else 0
+
+ question_text = stats.get("question_text", qtype)
+
+ if format_type == "markdown":
+ table_lines.append(f"| {question_text} | {is_app_avg:.2f} | {apply_avg:.2f} | {hit_rate:.1f}% | {empty_results} |")
+ else:
+ # Truncate for console alignment
+ question_text_short = question_text[:39]
+ table_lines.append(f"{question_text_short:<40} {is_app_avg:<12.2f} {apply_avg:<12.2f} {hit_rate:<10.1f}% {empty_results:<8} {total_qa_generated:<10}")
+
+ return "\n".join(table_lines)
+
+
+def format_profiling_notes(format_type="markdown"):
+ """
+ Format profiling explanation notes.
+
+ Args:
+ format_type: "markdown" for README, "console" for logging
+
+ Returns:
+ Formatted notes as string
+ """
+ if format_type == "markdown":
+ return """\n**Notes:**
+- `is_applicable` checks if a question type can be applied to an image
+- `apply` generates the actual question-answer pairs
+- Predicate -> QA Hit Rate = Percentage of applicable cases that generated at least one QA pair
+- Empty cases = Number of times is_applicable=True but apply returned no QA pairs"""
+ else:
+        return """Notes: Hit Rate = % of applicable cases that generated ≥1 QA pair
+ Failed = cases where is_applicable=True but apply returned no QA pairs"""
diff --git a/graid/src/graid/verification/EXAMPLE.md b/graid/src/graid/verification/EXAMPLE.md
new file mode 100644
index 0000000..fe694de
--- /dev/null
+++ b/graid/src/graid/verification/EXAMPLE.md
@@ -0,0 +1,16 @@
+```python
+from graid.evaluator.vlms import GPT
+from graid.evaluator.prompts import SetOfMarkPrompt
+from graid.verification import PrecisionVerifier, RecallVerifier
+
+# Create VLM instance
+vlm = GPT(model_name="gpt-4o")
+
+# PrecisionVerifier - verify predicted labels
+precision_verifier = PrecisionVerifier(vlm)
+is_correct = precision_verifier.verify(cropped_image, "car")
+
+# RecallVerifier - find missed objects
+prompting_strategy = SetOfMarkPrompt()
+recall_verifier = RecallVerifier(prompting_strategy, vlm)
+no_objects_missed = recall_verifier.verify(region_image, ["car", "truck"])
+```
\ No newline at end of file
diff --git a/graid/src/graid/verification/precision_verifier.py b/graid/src/graid/verification/precision_verifier.py
new file mode 100644
index 0000000..e767c52
--- /dev/null
+++ b/graid/src/graid/verification/precision_verifier.py
@@ -0,0 +1,103 @@
+import logging
+from typing import Optional
+
+from PIL import Image
+
+from graid.evaluator.prompts import PromptingStrategy, PassthroughPrompt
+from graid.evaluator.vlms import VLM
+
+logger = logging.getLogger(__name__)
+
+
+class PrecisionVerifier:
+ """Verify that a *predicted* label matches the object in a cropped image.
+
+ This verifier focuses on **precision**; that is, confirming a detector's
+ *positive* prediction is correct. The caller is expected to supply a
+ *pre-cropped* image that contains exactly the detected object (usually a
+ detection bounding-box crop) along with the predicted class label.
+
+ Parameters
+ ----------
+ vlm : VLM
+ Instance of a VLM class that adheres to the VLM interface.
+ prompting_strategy : PromptingStrategy | None, optional
+ Strategy used to build the visual prompt passed to the VLM. Defaults
+ to a **no-op** strategy that leaves the image unchanged and only adds
+ a textual question.
+ yes_tokens : tuple[str, ...], default ("yes", "true", "y", "1")
+ Tokens considered affirmative when parsing the VLM's response.
+ no_tokens : tuple[str, ...], default ("no", "false", "n", "0")
+ Tokens considered negative when parsing the VLM's response.
+ """
+
+ def __init__(
+ self,
+ vlm: VLM,
+ prompting_strategy: Optional[PromptingStrategy] = None,
+ *,
+ yes_tokens: tuple[str, ...] = ("yes", "true", "y", "1"),
+ no_tokens: tuple[str, ...] = ("no", "false", "n", "0"),
+ ) -> None:
+ self.vlm = vlm
+ self.ps = prompting_strategy or PassthroughPrompt()
+ self._yes_tokens = tuple(token.lower() for token in yes_tokens)
+ self._no_tokens = tuple(token.lower() for token in no_tokens)
+
+ # ------------------------------------------------------------------
+ # public API
+ # ------------------------------------------------------------------
+ def verify(self, image: Image.Image, label: str) -> bool:
+ """Return **True** when the object in the image matches *label*.
+
+ The method:
+ 1. Builds a yes/no prompt for the provided label.
+ 2. Queries the VLM using its ``generate_answer`` interface.
+ 3. Parses the VLM's yes/no answer.
+ 4. Returns *True* if the answer is affirmative.
+ """
+ question = self._build_question(label)
+
+ # Generate the prompt, which may annotate the image
+ annotated_image, messages = self.ps.generate_prompt(image, question)
+
+ # Call the VLM using its standard interface; we ignore the second
+ # element (messages) in the returned tuple.
+ answer_text, _ = self.vlm.generate_answer(annotated_image, messages)
+
+ is_match = self._parse_answer(answer_text)
+ logger.debug("Label: '%s' | VLM: '%s' | Match: %s", label, answer_text, is_match)
+ return is_match
+
+ # ------------------------------------------------------------------
+ # helpers
+ # ------------------------------------------------------------------
+ @staticmethod
+ def _build_question(label: str) -> str:
+ # Request a binary response to simplify parsing.
+ return (
+ f"Is the object in this image a {label}? "
+ "Answer with either 'yes' or 'no'."
+ )
+
+ def _parse_answer(self, answer_text: str) -> bool:
+ """Interpret the VLM response as *yes* or *no*.
+
+ Any unrecognized response defaults to *False* (i.e., the label does
+ *not* match), since we only want to confirm positives we are confident
+ about.
+ """
+ cleaned = answer_text.strip().lower()
+ if "```" in cleaned:
+ cleaned = cleaned.split("```")[-2 if cleaned.endswith("```") else -1].strip()
+
+ # Extract first token (in case of longer sentences)
+ first_token = cleaned.split()[0] if cleaned else ""
+
+ if first_token in self._yes_tokens:
+ return True
+ if first_token in self._no_tokens:
+ return False
+
+ logger.warning("Unrecognized VLM precision-verification answer '%s'", answer_text)
+ return False
\ No newline at end of file
diff --git a/graid/src/graid/verification/recall_verifier.py b/graid/src/graid/verification/recall_verifier.py
new file mode 100644
index 0000000..b0774ce
--- /dev/null
+++ b/graid/src/graid/verification/recall_verifier.py
@@ -0,0 +1,116 @@
+import ast
+import logging
+from collections.abc import Sequence
+from typing import Optional
+
+from PIL import Image
+
+from graid.evaluator.prompts import PromptingStrategy
+from graid.evaluator.vlms import VLM
+
+logger = logging.getLogger(__name__)
+
+
+class RecallVerifier:
+ """Orchestrates object detection verification using a VLM and a prompting strategy.
+
+ This class coordinates the verification process by generating prompts for
+ suspicious regions, querying the VLM with annotated images, and parsing
+ the responses to determine if any objects were missed by the original
+ detector.
+
+ The prompting behavior can be controlled by providing different strategies
+ at instantiation. For example, use ``SetOfMarkPrompt`` to visually highlight
+ regions of interest, or ``PassthroughPrompt`` to send the image as-is.
+
+ Parameters
+ ----------
+ prompting_strategy : PromptingStrategy
+ A prompting strategy, e.g., ``SetOfMarkPrompt`` or
+ ``PassthroughPrompt`` from ``graid.evaluator.prompts``.
+ vlm : VLM
+ Instance of a VLM class that adheres to the VLM interface.
+ """
+
+ def __init__(
+ self,
+ prompting_strategy: PromptingStrategy,
+ vlm: VLM,
+ ) -> None:
+ self.ps = prompting_strategy
+ self.vlm = vlm
+
+ # ---------------------------------------------------------------------
+ # public API
+ # ---------------------------------------------------------------------
+ def verify(
+ self,
+ image: Image.Image,
+ possible_classes: Optional[Sequence[str]] = None,
+ ) -> bool:
+ """Return **True** if *no* objects are detected in the given region of suspicion.
+
+ Parameters
+ ----------
+ image : PIL.Image.Image
+ *Pre-cropped* region of suspicion.
+ possible_classes : Sequence[str] | None, optional
+ Candidate classes to ask the VLM about. If ``None`` we let the
+ model answer freely.
+ """
+ # STEP 1: build the textual question adapted to the chosen strategy
+ question = self._build_question(possible_classes)
+
+ # STEP 2: Generate the prompt, which may annotate the image
+ annotated_image, messages = self.ps.generate_prompt(image, question)
+
+ # STEP 3: query the VLM using the standard interface and parse the answer
+ answer_text, _ = self.vlm.generate_answer(annotated_image, messages)
+ found_labels = self._parse_answer(answer_text)
+
+ logger.debug(
+ "Possible: %s | Found: %s | Prompting: %s",
+ possible_classes,
+ found_labels,
+ self.ps.__class__.__name__,
+ )
+ return len(found_labels) == 0
+
+ # ------------------------------------------------------------------
+ # helpers
+ # ------------------------------------------------------------------
+ @staticmethod
+ def _build_question(possible_classes: Optional[Sequence[str]]) -> str:
+ if possible_classes:
+ class_list = ", ".join(possible_classes)
+ return (
+ "Which of these objects are present in the image: "
+ f"{class_list}? Provide your answer as a python list. "
+ "If none, return empty list []."
+ )
+ else:
+ return (
+ "Are there any objects present in the image? "
+ "Provide your answer as a python list of object names. "
+ "If none, return empty list []."
+ )
+
+ @staticmethod
+ def _parse_answer(answer_text: str) -> list[str]:
+ """Extract a Python list from raw answer text.
+
+ The model may wrap the list in triple back-ticks; we strip those out
+ and fall back to empty list on any parsing error.
+ """
+ cleaned = answer_text.strip()
+ if "```" in cleaned:
+ cleaned = cleaned.split("```")[-2 if cleaned.endswith("```") else -1]
+ try:
+ parsed = ast.literal_eval(cleaned)
+ if isinstance(parsed, list):
+ return [str(x) for x in parsed]
+ # If VLM returned single token instead of list
+ return [str(parsed)]
+ except Exception as e: # noqa: BLE001
+ logger.warning("Failed to parse VLM answer '%s': %s", answer_text, e)
+ return []
\ No newline at end of file
diff --git a/graid/tests/__init__.py b/graid/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/pyproject.toml b/pyproject.toml
index f59dd38..35d32f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "graid"
version = "0.1.0"
-description = "GRAID: Generating Reasoning questions from Analysis of Images via Discriminative artificial intelligence - A framework for VLM scene understanding in robotics and autonomous driving"
+description = "GRAID: Generating Reasoning questions from Analysis of Images via Discriminative artificial intelligence - A framework for VQA generation from real world images"
authors = [
{ name = "Karim Elmaaroufi", email = "elmaaroufi@berkeley.edu" },
{ name = "Liheng Lai", email = "liheng@berkeley.edu" },
@@ -13,7 +13,6 @@ keywords = [
"vision-language-models",
"VLM",
"robotics",
- "autonomous-driving",
"scene-understanding",
"depth-estimation",
"object-detection",
@@ -45,6 +44,9 @@ dependencies = [
"clip",
"wandb>=0.20.1,<0.21",
"sentence-transformers>=4.1.0,<5",
+ "datasets (>=4.0.0,<5.0.0)",
+ "ensemble-boxes (>=1.0.9,<2.0.0)",
+ "supervision (>=0.26.1,<0.27.0)",
"pycocotools>=2.0.10",
]
@@ -79,6 +81,7 @@ dev = [
"ipykernel>=6.29.5,<7",
"gdown>=5.2.0,<6",
"pytest>=8.4.0,<9",
+ "pytest-xdist>=3.8.0",
]
mmdetection = ["openmim>=0.3.9,<0.4"]
@@ -105,5 +108,11 @@ include = ["graid/src/graid"]
requires = [ "setuptools", "wheel", "torch"]
build-backend = "setuptools.build_meta:__legacy__"
+[tool.setuptools.packages.find]
+where = ["graid/src"]
+
+[tool.setuptools.package-dir]
+"" = "graid/src"
+
[tool.isort]
profile = "black"
diff --git a/test_depth_questions.py b/test_depth_questions.py
new file mode 100644
index 0000000..0519ecb
--- /dev/null
+++ b/test_depth_questions.py
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/test_depth_questions_stacked.py b/test_depth_questions_stacked.py
new file mode 100644
index 0000000..73ed2fa
--- /dev/null
+++ b/test_depth_questions_stacked.py
@@ -0,0 +1,343 @@
+#!/usr/bin/env python3
+"""
+Test script to verify the 3 depth questions (FrontOf, BehindOf, DepthRanking) work correctly.
+Creates a stacked visualization with all images vertically arranged.
+"""
+
+import sys
+import os
+sys.path.insert(0, 'graid/src')
+
+try:
+ import torch
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from PIL import Image, ImageDraw, ImageFont
+ import numpy as np
+ from pathlib import Path
+
+ # Try importing GRAID modules with fallback handling
+ try:
+ from graid.data.ImageLoader import Bdd100kDataset
+ DATASET_AVAILABLE = True
+ except ImportError as e:
+ print(f"Warning: Could not import dataset loader: {e}")
+ DATASET_AVAILABLE = False
+
+ try:
+ from graid.questions.ObjectDetectionQ import FrontOf, BehindOf, DepthRanking
+ from graid.interfaces.ObjectDetectionI import ObjectDetectionResultI, ObjectDetectionUtils
+ QUESTIONS_AVAILABLE = True
+ except ImportError as e:
+ print(f"Error: Could not import question modules: {e}")
+ QUESTIONS_AVAILABLE = False
+ sys.exit(1)
+
+ try:
+ from graid.models.DepthPro import DepthPro
+ DEPTH_MODEL_AVAILABLE = True
+ except ImportError as e:
+ print(f"Warning: Could not import DepthPro: {e}")
+ DEPTH_MODEL_AVAILABLE = False
+
+except ImportError as e:
+ print(f"Error: Missing required dependencies: {e}")
+ sys.exit(1)
+
+def draw_boxes(
+ image: np.ndarray,
+ detections: list[ObjectDetectionResultI],
+ alpha: float = 1.0,
+) -> np.ndarray:
+ """Overlay detections (xyxy) on an RGB image and return the visualised image."""
+ colours: list[tuple[int, int, int]] = [
+ (255, 0, 0),
+ (0, 255, 0),
+ (0, 0, 255),
+ (255, 255, 0),
+ (255, 0, 255),
+ (0, 255, 255),
+ (255, 128, 0),
+ (128, 0, 255),
+ (255, 192, 203),
+ (128, 128, 0),
+ ]
+
+ pil_img = Image.fromarray(image)
+ draw = ImageDraw.Draw(pil_img, "RGBA")
+
+ for i, det in enumerate(detections):
+ try:
+ colour = colours[i % len(colours)] + (255,)
+ xyxy = det.as_xyxy()
+ # Support tensor/ndarray/list
+ if hasattr(xyxy, 'shape') and len(xyxy.shape) > 1:
+ for j in range(xyxy.shape[0]):
+ x1, y1, x2, y2 = xyxy[j][:4]
+ draw.rectangle([x1, y1, x2, y2], outline=colour, width=3)
+ # Add label
+ label = str(det.label[j].item()) if isinstance(det.label, torch.Tensor) else str(det.label)
+ draw.text((x1, y1-15), label, fill=colour)
+ else:
+ x1, y1, x2, y2 = xyxy[:4]
+ draw.rectangle([x1, y1, x2, y2], outline=colour, width=3)
+ # Add label
+ label = str(det.label)
+ draw.text((x1, y1-15), label, fill=colour)
+ except Exception as e:
+ print(f"Error drawing detection {i}: {e}")
+ continue
+
+ return np.array(pil_img)
+
+def create_mock_data():
+ """Create mock detection data for testing when dataset is not available"""
+ # Create a simple test image (640x480)
+ image = Image.new('RGB', (640, 480), color=(135, 206, 235)) # Sky blue background
+
+ # Convert to numpy for drawing simple shapes
+ img_array = np.array(image)
+
+ # Draw some simple rectangles to represent objects
+ # Traffic light (closer, left side)
+ img_array[100:200, 150:200] = [255, 255, 0] # Yellow rectangle
+
+ # Traffic sign (further, right side)
+ img_array[120:180, 450:500] = [255, 0, 0] # Red rectangle
+
+ # Person (middle depth, center)
+ img_array[250:400, 300:350] = [139, 69, 19] # Brown rectangle
+
+ # Car (closest, bottom)
+ img_array[350:450, 200:400] = [0, 0, 255] # Blue rectangle
+
+ image = Image.fromarray(img_array)
+
+ # Create mock detections
+ detections = [
+ ObjectDetectionResultI(
+ score=0.9, cls=9, label="traffic light",
+ bbox=[150, 100, 200, 200], image_hw=(480, 640)
+ ),
+ ObjectDetectionResultI(
+ score=0.85, cls=11, label="traffic sign",
+ bbox=[450, 120, 500, 180], image_hw=(480, 640)
+ ),
+ ObjectDetectionResultI(
+ score=0.8, cls=0, label="person",
+ bbox=[300, 250, 350, 400], image_hw=(480, 640)
+ ),
+ ObjectDetectionResultI(
+ score=0.95, cls=2, label="car",
+ bbox=[200, 350, 400, 450], image_hw=(480, 640)
+ ),
+ ]
+
+ return image, detections
+
+def test_depth_questions():
+ """Test depth questions and create stacked visualization"""
+ print("=== Testing Depth Questions ===")
+
+ # Initialize results storage
+ all_results = []
+
+ if DATASET_AVAILABLE:
+ print("Loading dataset...")
+ try:
+ # Load dataset
+ dataset = Bdd100kDataset(split="val", n_images=3)
+ images_and_detections = [dataset[i] for i in range(min(3, len(dataset)))]
+ except Exception as e:
+ print(f"Error loading dataset: {e}")
+ print("Falling back to mock data...")
+            images_and_detections = [create_mock_data()]
+ else:
+ print("Using mock data...")
+ images_and_detections = [create_mock_data()]
+
+ # Initialize question classes
+ try:
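+        # margin_ratio is assumed to add a small depth tolerance so objects at nearly
+        # equal depth are not treated as strictly in front of / behind one another.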
+ front_question = FrontOf(margin_ratio=0.1)
+ behind_question = BehindOf(margin_ratio=0.1)
+ depth_ranking = DepthRanking(k=3, margin_ratio=0.1)
+
+        print("✅ Successfully initialized depth questions")
+ except Exception as e:
+        print(f"❌ Error initializing questions: {e}")
+ return
+
+ # Process each image
+ for idx, (pil_image, detections) in enumerate(images_and_detections):
+ print(f"\nProcessing image {idx+1}...")
+
+ # Filter detections to only include relevant classes
+ relevant_classes = ["traffic light", "traffic sign", "person", "car", "bus", "truck"]
+ filtered_detections = []
+
+ for detection in detections:
+ if isinstance(detection.label, str):
+ if detection.label in relevant_classes:
+ filtered_detections.append(detection)
+ elif isinstance(detection.label, torch.Tensor):
+ # Handle tensor labels - flatten and filter
+ flattened = detection.flatten()
+ for flat_det in flattened:
+ if str(flat_det.label) in relevant_classes:
+ filtered_detections.append(flat_det)
+
+ print(f"Found {len(filtered_detections)} relevant detections")
+
+ if len(filtered_detections) < 2:
+ print("Not enough detections for depth questions, skipping...")
+ continue
+
+ # Set up question context for proper predicate evaluation
+ try:
+ ObjectDetectionUtils.set_current_context(
+ ObjectDetectionUtils.build_question_context(pil_image, filtered_detections)
+ )
+ except Exception as e:
+ print(f"Warning: Could not set question context: {e}")
+
+ # Test each question type
+ questions_to_test = [
+ ("FrontOf", front_question),
+ ("BehindOf", behind_question),
+ ("DepthRanking", depth_ranking)
+ ]
+
+ image_results = {
+ 'image': pil_image,
+ 'detections': filtered_detections,
+ 'questions': {}
+ }
+
+ for question_name, question in questions_to_test:
+ try:
+ print(f" Testing {question_name}...")
+
+ # Check if question is applicable
+ if question.is_applicable(pil_image, filtered_detections):
+ # Apply the question
+ qa_pairs = question.apply(pil_image, filtered_detections)
+ print(f" Generated {len(qa_pairs)} question-answer pairs")
+
+ for q, a in qa_pairs:
+ print(f" Q: {q}")
+ print(f" A: {a}")
+
+ image_results['questions'][question_name] = qa_pairs
+ else:
+ print(f" {question_name} not applicable to this image")
+ image_results['questions'][question_name] = []
+
+ except Exception as e:
+            print(f"  ❌ Error with {question_name}: {e}")
+ image_results['questions'][question_name] = []
+
+ all_results.append(image_results)
+
+ # Create stacked visualization
+ create_stacked_visualization(all_results)
+
+def create_stacked_visualization(all_results):
+ """Create a single stacked image with all test results"""
+ if not all_results:
+ print("No results to visualize")
+ return
+
+ print(f"\nCreating stacked visualization with {len(all_results)} images...")
+
+ # Calculate dimensions for stacked image
+ img_width = 640
+ img_height = 480
+ text_height = 200 # Space for question/answer text
+ total_height_per_image = img_height + text_height
+ total_height = total_height_per_image * len(all_results)
+
+ # Create the stacked image
+ stacked_image = Image.new('RGB', (img_width, total_height), color=(255, 255, 255))
+
+ current_y = 0
+
+ for idx, result in enumerate(all_results):
+ # Get image with detections overlay
+ image_with_boxes = draw_boxes(np.array(result['image']), result['detections'])
+ image_with_boxes = Image.fromarray(image_with_boxes)
+
+ # Resize if necessary
+ if image_with_boxes.size != (img_width, img_height):
+ image_with_boxes = image_with_boxes.resize((img_width, img_height))
+
+ # Paste the image
+ stacked_image.paste(image_with_boxes, (0, current_y))
+ current_y += img_height
+
+ # Add text with questions and answers
+ text_img = Image.new('RGB', (img_width, text_height), color=(240, 240, 240))
+ draw = ImageDraw.Draw(text_img)
+
+ # Try to use a font, fallback to default if not available
+ try:
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
+        except Exception:
+ font = ImageFont.load_default()
+
+ y_offset = 10
+ line_height = 15
+
+ # Image header
+ draw.text((10, y_offset), f"Image {idx+1} - Depth Questions Test Results:",
+ fill=(0, 0, 0), font=font)
+ y_offset += line_height * 2
+
+ # Add questions and answers
+ for question_type, qa_pairs in result['questions'].items():
+ if qa_pairs:
+ draw.text((10, y_offset), f"{question_type}:", fill=(0, 0, 255), font=font)
+ y_offset += line_height
+
+ for q, a in qa_pairs[:3]: # Limit to first 3 to fit in space
+ # Wrap long questions
+ q_short = q[:80] + "..." if len(q) > 80 else q
+ draw.text((20, y_offset), f"Q: {q_short}", fill=(0, 100, 0), font=font)
+ y_offset += line_height
+ draw.text((20, y_offset), f"A: {a}", fill=(100, 0, 0), font=font)
+ y_offset += line_height
+
+ if len(qa_pairs) > 3:
+ draw.text((20, y_offset), f"... and {len(qa_pairs)-3} more",
+ fill=(100, 100, 100), font=font)
+ y_offset += line_height
+ else:
+ draw.text((10, y_offset), f"{question_type}: No applicable questions",
+ fill=(150, 150, 150), font=font)
+ y_offset += line_height
+
+ y_offset += line_height // 2
+
+ # Paste the text section
+ stacked_image.paste(text_img, (0, current_y))
+ current_y += text_height
+
+ # Save the stacked visualization
+ output_path = "depth_questions_test_results_stacked.png"
+ stacked_image.save(output_path)
+    print(f"✅ Stacked visualization saved to: {output_path}")
+
+ # Also display using matplotlib if available
+ try:
+ plt.figure(figsize=(12, len(all_results) * 6))
+ plt.imshow(np.array(stacked_image))
+ plt.axis('off')
+ plt.title('Depth Questions Test Results (All Images Stacked)')
+ plt.tight_layout()
+ plt.savefig("depth_questions_matplotlib.png", dpi=150, bbox_inches='tight')
+ plt.show()
+        print("✅ Also displayed with matplotlib")
+ except Exception as e:
+ print(f"Note: Could not display with matplotlib: {e}")
+
+if __name__ == "__main__":
+ test_depth_questions()
\ No newline at end of file
diff --git a/tests/manual_tests/test_detectron_seg_visual.py b/tests/manual_tests/test_detectron_seg_visual.py
new file mode 100644
index 0000000..1c40a90
--- /dev/null
+++ b/tests/manual_tests/test_detectron_seg_visual.py
@@ -0,0 +1,207 @@
+"""
+Manual test for Detectron2 instance segmentation with visual output.
+This test loads images, runs segmentation, and displays results with overlaid masks
+showing class names and confidence scores.
+"""
+
+from itertools import islice
+from pathlib import Path
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from graid.data.ImageLoader import Bdd100kDataset
+from graid.models.Detectron import Detectron_seg
+from graid.utilities.common import get_default_device
+
+
+def draw_masks_with_labels(image, results, alpha=0.5):
+ """
+ Draw segmentation masks on image with class labels and confidence scores.
+
+ Args:
+ image: Original image as numpy array (H, W, C)
+ results: List of InstanceSegmentationResultI objects
+ alpha: Transparency factor for mask overlay
+
+ Returns:
+ Image with overlaid masks and labels
+ """
+ # Convert to RGB if needed
+ if image.shape[-1] == 3:
+ display_image = image.copy()
+ else:
+ display_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ # Colors for different instances (using distinct colors)
+ colors = [
+ (255, 0, 0), # Red
+ (0, 255, 0), # Green
+ (0, 0, 255), # Blue
+ (255, 255, 0), # Yellow
+ (255, 0, 255), # Magenta
+ (0, 255, 255), # Cyan
+ (255, 128, 0), # Orange
+ (128, 0, 255), # Purple
+ (255, 192, 203), # Pink
+ (128, 128, 0), # Olive
+ ]
+
+ # Create overlay for masks
+ overlay = display_image.copy()
+
+ for i, result in enumerate(results):
+ # Get mask as numpy array
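+        # as_tensor() is assumed to yield a (1, H, W) mask tensor here; index 0 selects
+        # the single-channel binary mask.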
+ mask = result.as_tensor()[0].cpu().numpy().astype(bool)
+
+ # Get color for this instance
+ color = colors[i % len(colors)]
+
+ # Apply colored mask
+ overlay[mask] = color
+
+ # Find bounding box for label placement
+ mask_indices = np.where(mask)
+ if len(mask_indices[0]) > 0:
+ y_min, y_max = mask_indices[0].min(), mask_indices[0].max()
+ x_min, x_max = mask_indices[1].min(), mask_indices[1].max()
+
+ # Create label text
+ label_text = f"{result.label}: {result.score:.1%}"
+
+ # Calculate text position (top-left of bounding box)
+ text_x = max(0, x_min)
+ text_y = max(20, y_min)
+
+ # Draw text background
+ text_size = cv2.getTextSize(
+ label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
+ cv2.rectangle(
+ overlay,
+ (text_x, text_y - text_size[1] - 5),
+ (text_x + text_size[0] + 10, text_y + 5),
+ (0, 0, 0),
+ -1,
+ )
+
+ # Draw text
+ cv2.putText(
+ overlay,
+ label_text,
+ (text_x + 5, text_y),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.7,
+ (255, 255, 255),
+ 2,
+ )
+
+ # Blend original image with overlay
+ result_image = cv2.addWeighted(display_image, 1 - alpha, overlay, alpha, 0)
+
+ return result_image
+
+
+def main():
+ """Run manual segmentation visualization test."""
+
+ print("=== Manual Detectron2 Segmentation Visualization Test ===")
+
+ # Configuration
+ NUM_IMAGES = 3
+ SAVE_RESULTS = True
+ SHOW_PLOTS = True
+
+ # Initialize dataset
+ print("Loading BDD100K dataset...")
+ bdd = Bdd100kDataset(
+ split="val", use_original_categories=False, use_extended_annotations=False)
+
+ # Initialize model with standard Detectron2 Mask R-CNN
+ print("Loading Detectron2 segmentation model...")
+ model = Detectron_seg(
+ config_file="COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
+ weights_file="COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
+ threshold=0.3, # Lower threshold to see more detections
+ device=get_default_device(),
+ )
+
+ print(f"Model loaded on device: {get_default_device()}")
+
+ # Create output directory
+ output_dir = Path("segmentation_results")
+ if SAVE_RESULTS:
+ output_dir.mkdir(exist_ok=True)
+ print(f"Results will be saved to: {output_dir}")
+
+ # Process images
+ print(f"\nProcessing {NUM_IMAGES} images...")
+
+ for i, data in enumerate(islice(bdd, NUM_IMAGES)):
+ print(f"\n--- Image {i+1}/{NUM_IMAGES} ---")
+
+ # Extract image and metadata
+ image_tensor = data["image"] # Shape: [3, H, W]
+ image_name = data["name"]
+
+ # Convert tensor to numpy for visualization (CHW -> HWC)
+ image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
+
+ # Ensure values are in [0, 255] range
+ if image_np.max() <= 1.0:
+ image_np = (image_np * 255).astype(np.uint8)
+ else:
+ image_np = image_np.astype(np.uint8)
+
+ print(f"Image: {image_name}")
+ print(f"Shape: {image_np.shape}")
+
+ # Run segmentation
+ print("Running segmentation...")
+ results = model.identify_for_image(image_tensor)
+
+ print(f"Found {len(results)} instances:")
+ for j, result in enumerate(results):
+ print(f" {j+1}. {result.label}: {result.score:.1%}")
+
+ # Create visualization
+ if len(results) > 0:
+ print("Creating visualization...")
+ vis_image = draw_masks_with_labels(image_np, results)
+
+ # Display results
+ if SHOW_PLOTS:
+ plt.figure(figsize=(15, 10))
+
+ # Original image
+ plt.subplot(1, 2, 1)
+ plt.imshow(image_np)
+ plt.title(f"Original Image: {image_name}")
+ plt.axis("off")
+
+ # Segmentation results
+ plt.subplot(1, 2, 2)
+ plt.imshow(vis_image)
+ plt.title(f"Segmentation Results ({len(results)} instances)")
+ plt.axis("off")
+
+ plt.tight_layout()
+
+ if SAVE_RESULTS:
+ save_path = output_dir / \
+ f"segmentation_{i+1}_{image_name}.png"
+ plt.savefig(save_path, dpi=150, bbox_inches="tight")
+ print(f"Saved: {save_path}")
+
+ plt.show()
+ else:
+ print("No instances detected in this image.")
+
+ print(f"\n=== Test Complete ===")
+ if SAVE_RESULTS:
+ print(f"Results saved to: {output_dir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/manual_tests/test_mask2former_seg.py b/tests/manual_tests/test_mask2former_seg.py
index 7d99493..6256889 100644
--- a/tests/manual_tests/test_mask2former_seg.py
+++ b/tests/manual_tests/test_mask2former_seg.py
@@ -1,13 +1,4 @@
-#!/usr/bin/env python3
-"""
-Manual test for Mask2Former instance/panoptic segmentation with visual output.
-This test loads images, runs segmentation, and displays results with overlaid masks
-showing class names and confidence scores. Automatically uses panoptic segmentation
-if available, otherwise falls back to instance segmentation.
-"""
-
import sys
-from itertools import islice
from pathlib import Path
import cv2
@@ -15,9 +6,9 @@
import numpy as np
import torch
-from graid.data.ImageLoader import Bdd100kDataset
+from graid.data.ImageLoader import Bdd10kDataset
from graid.models.MMDetection import MMdetection_seg
-from graid.utilities.common import yolo_bdd_transform
+from graid.utilities.common import yolo_bdd_seg_transform
sys.path.append("/work/ke/research/scenic-reasoning/graid/src")
@@ -81,7 +72,8 @@ def draw_masks_with_labels(image, results, alpha=0.5):
text_y = max(20, y_min)
# Draw text background
- text_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
+ text_size = cv2.getTextSize(
+ label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
cv2.rectangle(
overlay,
(text_x, text_y - text_size[1] - 5),
@@ -118,12 +110,10 @@ def main():
SHOW_PLOTS = True
# Initialize dataset
- print("Loading BDD100K dataset...")
- dataset = Bdd100kDataset(
+ print("Loading BDD10K dataset...")
+ dataset = Bdd10kDataset(
split="val",
- transform=lambda i, l: yolo_bdd_transform(i, l, new_shape=(768, 1280)),
- use_original_categories=False,
- use_extended_annotations=False,
+ # transform=lambda i, l: yolo_bdd_seg_transform(i, l, new_shape=(768, 1280)),
)
# Initialize model with Mask2Former Swin-L
@@ -241,7 +231,8 @@ def main():
plt.tight_layout()
if SAVE_RESULTS:
- save_path = output_dir / f"mask2former_{i+1}_{image_name}.png"
+ save_path = output_dir / \
+ f"mask2former_{i+1}_{image_name}.png"
plt.savefig(save_path, dpi=150, bbox_inches="tight")
print(f"Saved: {save_path}")
diff --git a/tests/manual_tests/test_wbf_obj_visual.py b/tests/manual_tests/test_wbf_obj_visual.py
new file mode 100644
index 0000000..acf6159
--- /dev/null
+++ b/tests/manual_tests/test_wbf_obj_visual.py
@@ -0,0 +1,225 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from itertools import islice
+from pathlib import Path
+from typing import Any
+
+from graid.data.ImageLoader import Bdd100kDataset, NuImagesDataset, WaymoDataset
+from graid.interfaces.ObjectDetectionI import ObjectDetectionResultI
+from graid.models.MMDetection import MMdetection_obj
+from graid.models.Ultralytics import RT_DETR, Yolo
+from graid.models.WBF import WBF
+from graid.utilities.common import get_default_device, project_root_dir, yolo_nuscene_transform, yolo_waymo_transform
+from graid.models.Detectron import Detectron_obj
+
+from PIL import Image, ImageDraw
+
+
+def filter_detections_by_score(
+ detections: list[ObjectDetectionResultI], min_score: float = 0.4
+) -> list[ObjectDetectionResultI]:
+ """Filter out detections with scores below the minimum threshold."""
+ return [det for det in detections if det.score >= min_score]
+
+
+def draw_boxes(
+ image: np.ndarray[Any, np.dtype[np.uint8]], detections: list[ObjectDetectionResultI], alpha: float = 1.0
+) -> np.ndarray[Any, np.dtype[np.uint8]]:
+ """Overlay detections on an RGB image and return the visualised image."""
+
+ colours: list[tuple[int, int, int]] = [
+ (255, 0, 0),
+ (0, 255, 0),
+ (0, 0, 255),
+ (255, 255, 0),
+ (255, 0, 255),
+ (0, 255, 255),
+ (255, 128, 0),
+ (128, 0, 255),
+ (255, 192, 203),
+ (128, 128, 0),
+ ]
+
+ # Pillow path ---------------------------------------------------
+ pil_img = Image.fromarray(image)
+ draw = ImageDraw.Draw(pil_img, "RGBA")
+ for i, det in enumerate(detections):
+ colour = colours[i % len(colours)] + (255,)
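+        # as_xyxy() may return a (1, 4) tensor or a flat (4,) box; squeeze()[:4] handles both.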
+ x1, y1, x2, y2 = map(int, det.as_xyxy().squeeze()[:4].tolist())
+ label = f"{det.label}: {det.score:.1%}"
+ draw.rectangle([x1, y1, x2, y2], outline=colour, width=2)
+ text_size = draw.textlength(label)
+ draw.rectangle([x1, y1 - 15, x1 + text_size + 4, y1], fill=colour)
+ draw.text((x1 + 2, y1 - 14), label, fill=(0, 0, 0, int(255 * alpha)))
+ print(
+ f"Found {det.label}: {det.score:.1%} at {x1/image.shape[1]}, {y1/image.shape[0]}, {x2/image.shape[1]}, {y2/image.shape[0]}")
+ return np.array(pil_img)
+
+
+# ----------------------------------------------------------------------------
+# Model loading helpers
+# ----------------------------------------------------------------------------
+
+def load_dino() -> MMdetection_obj:
+ mmdet = project_root_dir() / "install" / "mmdetection"
+ cfg = str(mmdet / "configs/dino/dino-5scale_swin-l_8xb2-12e_coco.py")
+ ckpt = (
+ "https://download.openmmlab.com/mmdetection/v3.0/dino/"
+ "dino-5scale_swin-l_8xb2-12e_coco/"
+ "dino-5scale_swin-l_8xb2-12e_coco_20230228_072924-a654145f.pth"
+ )
+ return MMdetection_obj(cfg, ckpt, device=get_default_device())
+
+
+def load_codetr() -> MMdetection_obj:
+ mmdet = project_root_dir() / "install" / "mmdetection"
+ cfg = str(
+ mmdet
+ / "projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py"
+ )
+ ckpt = (
+ "https://download.openmmlab.com/mmdetection/v3.0/codetr/"
+ "co_dino_5scale_lsj_swin_large_1x_coco-3af73af2.pth"
+ )
+ return MMdetection_obj(cfg, ckpt, device=get_default_device())
+
+
+def load_rtdetr() -> RT_DETR:
+ return RT_DETR("rtdetr-x.pt")
+
+
+def load_yolo_v10x() -> Yolo:
+ return Yolo("yolov10x.pt")
+
+
+def load_mask_rcnn_detectron() -> Detectron_obj:
+ """Load Detectron2 Mask R-CNN R-50 FPN model."""
+ # Use a simpler, well-supported model config
+ cfg = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
+ ckpt = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
+ return Detectron_obj(cfg, ckpt, device=get_default_device())
+
+
+def main():
+ print("=== WBF Ensemble Bounding-Box Visualisation ===")
+
+ NUM_IMAGES = 3
+ SAVE = True
+ SHOW = True
+ SCORE_THRESHOLD = 0.25 # Minimum confidence score for detections
+
+ datasets = []
+
+ # BDD100K dataset
+    print("Loading BDD100K validation images …")
+ bdd_dataset = Bdd100kDataset(split="val")
+ datasets.append(("BDD100K", "val", bdd_dataset))
+    print(f"✓ BDD100K loaded successfully ({len(bdd_dataset)} images)")
+
+ # NuImages dataset
+    print("Loading NuImages validation images …")
+ nuimages_dataset = NuImagesDataset(
+ split="val",
+ transform=lambda i, l: yolo_nuscene_transform(
+ i, l, new_shape=(896, 1600))
+ )
+ datasets.append(("NuImages", "val", nuimages_dataset))
+ print(
+        f"✓ NuImages loaded successfully ({len(nuimages_dataset)} images)")
+
+ # Waymo dataset
+    print("Loading Waymo validation images …")
+ waymo_dataset = WaymoDataset(
+ split="validation",
+ transform=lambda i, l: yolo_waymo_transform(i, l, (1280, 1920))
+ )
+ datasets.append(("Waymo", "validation", waymo_dataset))
+    print(f"✓ Waymo loaded successfully ({len(waymo_dataset)} images)")
+
+ if not datasets:
+ raise RuntimeError("No datasets could be loaded successfully!")
+
+ print(f"\nSuccessfully loaded {len(datasets)} dataset(s)")
+
+ print("Initialising base models β¦")
+ dino = load_dino()
+ codetr = load_codetr()
+ rtdetr = load_rtdetr()
+ yolo10x = load_yolo_v10x()
+ mask_rcnn = load_mask_rcnn_detectron()
+
+ # Assemble WBF
+ ensemble = WBF(
+ detectron2_models=[mask_rcnn],
+ mmdet_models=[dino, codetr],
+ ultralytics_models=[rtdetr, yolo10x],
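+        # One weight per model, assumed to follow the order the models are listed above.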
+ model_weights=[0.8, 0.8, 1.0, 0.9, 0.8],
+ iou_threshold=0.55,
+ skip_box_threshold=0.01,
+ )
+ print("WBF ensemble ready!")
+ print(f"Using score threshold: {SCORE_THRESHOLD}")
+
+ out_dir = Path("wbf_results")
+ if SAVE:
+ out_dir.mkdir(exist_ok=True)
+ print(f"Saving results to {out_dir.resolve()}")
+
+ # Process images from each dataset
+ for dataset_name, split, dataset in datasets:
+ print(f"\n--- Processing {dataset_name} dataset ---")
+
+ for idx, data in enumerate(islice(dataset, NUM_IMAGES)):
+ img_tensor = data["image"]
+ filename = data["name"]
+
+ # Sanitize and shorten filename
+ filename = filename.replace("/", "_").replace("\\", "_")
+ short_filename = Path(filename).stem
+
+ # Create helpful filename: dataset_split_shortname_wbf.png
+ output_filename = f"{dataset_name.lower()}_{split}_{short_filename}.png"
+
+            # Convert CHW tensor → HWC numpy in the correct value range
+ img_np = img_tensor.permute(1, 2, 0).cpu().numpy()
+
+            # Dataset values may already be in [0, 255]; scale up only if they look normalised to [0, 1]
+ if img_np.max() <= 1.0:
+ img_np = (img_np * 255).astype(np.uint8)
+ else:
+ img_np = img_np.astype(np.uint8)
+
+ print(f"[{idx+1}/{NUM_IMAGES}] {dataset_name} - {filename}")
+ detections = ensemble.identify_for_image(img_tensor)
+            print(f"  → {len(detections)} raw detections")
+
+ # Apply post-processing filter
+ filtered_detections = filter_detections_by_score(
+ detections, SCORE_THRESHOLD)
+ print(
+                f"  → {len(filtered_detections)} detections after filtering (score >= {SCORE_THRESHOLD})")
+
+ if len(filtered_detections) == 0:
+                print("  ⚠ No detections remain after filtering, skipping visualization")
+ continue
+
+ vis = draw_boxes(img_np, filtered_detections)
+
+ if SHOW:
+ plt.figure(figsize=(10, 6))
+ plt.imshow(vis)
+ plt.title(
+                    f"WBF fused detections (score >= {SCORE_THRESHOLD}) – {dataset_name} - {filename}")
+ plt.axis("off")
+ plt.show()
+
+ if SAVE:
+ save_path = out_dir / output_filename
+ Image.fromarray(vis).save(save_path)
+                print(f"  Saved → {save_path.relative_to(out_dir.parent)}")
+
+ print("\n=== Done. ===")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/graid/tests/count_questions.py b/tests/unit_tests/count_questions.py
similarity index 95%
rename from graid/tests/count_questions.py
rename to tests/unit_tests/count_questions.py
index d764ea8..480bca1 100644
--- a/graid/tests/count_questions.py
+++ b/tests/unit_tests/count_questions.py
@@ -1,11 +1,8 @@
import time
import torchvision.transforms as transforms
-from graid.data.ImageLoader import (
- Bdd100kDataset,
- NuImagesDataset,
- WaymoDataset,
-)
+
+from graid.data.ImageLoader import Bdd100kDataset, NuImagesDataset, WaymoDataset
from graid.interfaces.ObjectDetectionI import ObjectDetectionUtils
from graid.questions.ObjectDetectionQ import (
HowMany,
diff --git a/graid/tests/detection_inst_seg.py b/tests/unit_tests/detection_inst_seg.py
similarity index 100%
rename from graid/tests/detection_inst_seg.py
rename to tests/unit_tests/detection_inst_seg.py
diff --git a/tests/unit_tests/test_coco_utilities.py b/tests/unit_tests/test_coco_utilities.py
new file mode 100644
index 0000000..feed246
--- /dev/null
+++ b/tests/unit_tests/test_coco_utilities.py
@@ -0,0 +1,189 @@
+"""
+Unit tests for COCO utility functions.
+
+This module tests the COCO validation and filtering functions that handle
+allowable sets for object detection.
+"""
+
+import unittest
+from unittest.mock import patch
+
+from graid.utilities.coco import (
+ validate_coco_objects,
+ get_valid_coco_objects,
+ filter_detections_by_allowable_set,
+)
+
+
+class TestCocoUtilities(unittest.TestCase):
+ """Test cases for COCO utility functions."""
+
+ def test_validate_coco_objects_empty_list(self):
+ """Test validation with empty list."""
+ result = validate_coco_objects([])
+ self.assertEqual(result, (True, None))
+
+ def test_validate_coco_objects_none(self):
+ """Test validation with None (should treat as empty)."""
+ # Note: The function signature expects list[str], but we handle None gracefully
+ result = validate_coco_objects([]) # Use empty list instead of None
+ self.assertEqual(result, (True, None))
+
+ def test_validate_coco_objects_valid_objects(self):
+ """Test validation with valid COCO objects."""
+ valid_objects = ["person", "car", "truck", "bicycle"]
+ result = validate_coco_objects(valid_objects)
+ self.assertEqual(result, (True, None))
+
+ def test_validate_coco_objects_invalid_objects(self):
+ """Test validation with invalid COCO objects."""
+ invalid_objects = ["person", "invalid_object", "car"]
+ is_valid, error_msg = validate_coco_objects(invalid_objects)
+ self.assertFalse(is_valid)
+ self.assertIsNotNone(error_msg)
+ assert error_msg is not None # Type assertion for linter
+ self.assertIn("Invalid COCO objects", error_msg)
+ self.assertIn("invalid_object", error_msg)
+
+ def test_validate_coco_objects_all_invalid(self):
+ """Test validation with all invalid objects."""
+ invalid_objects = ["invalid1", "invalid2", "invalid3"]
+ is_valid, error_msg = validate_coco_objects(invalid_objects)
+ self.assertFalse(is_valid)
+ self.assertIsNotNone(error_msg)
+ assert error_msg is not None # Type assertion for linter
+ self.assertIn("Invalid COCO objects", error_msg)
+ self.assertIn("invalid1", error_msg)
+ self.assertIn("invalid2", error_msg)
+ self.assertIn("invalid3", error_msg)
+
+ def test_get_valid_coco_objects_returns_sorted_list(self):
+ """Test that get_valid_coco_objects returns a sorted list."""
+ valid_objects = get_valid_coco_objects()
+ self.assertIsInstance(valid_objects, list)
+ self.assertEqual(len(valid_objects), 80) # Standard COCO has 80 classes
+ self.assertEqual(valid_objects, sorted(valid_objects))
+
+ def test_get_valid_coco_objects_excludes_undefined(self):
+ """Test that get_valid_coco_objects excludes 'undefined'."""
+ valid_objects = get_valid_coco_objects()
+ self.assertNotIn("undefined", valid_objects)
+
+ def test_get_valid_coco_objects_contains_common_objects(self):
+ """Test that get_valid_coco_objects contains common objects."""
+ valid_objects = get_valid_coco_objects()
+ common_objects = ["person", "car", "truck", "bicycle", "dog", "cat"]
+ for obj in common_objects:
+ self.assertIn(obj, valid_objects)
+
+ def test_filter_detections_by_allowable_set_none_allowable_set(self):
+ """Test filtering with None allowable set (should return all)."""
+ detections = [
+ {"class": "person", "confidence": 0.9, "bbox": [0, 0, 10, 10]},
+ {"class": "car", "confidence": 0.8, "bbox": [20, 20, 30, 30]},
+ {"class": "bicycle", "confidence": 0.7, "bbox": [40, 40, 50, 50]},
+ ]
+
+ filtered = filter_detections_by_allowable_set(detections, None)
+ self.assertEqual(len(filtered), 3)
+ self.assertEqual(filtered, detections)
+
+ def test_filter_detections_by_allowable_set_empty_allowable_set(self):
+ """Test filtering with empty allowable set (should return all)."""
+ detections = [
+ {"class": "person", "confidence": 0.9, "bbox": [0, 0, 10, 10]},
+ {"class": "car", "confidence": 0.8, "bbox": [20, 20, 30, 30]},
+ ]
+
+ filtered = filter_detections_by_allowable_set(detections, [])
+ self.assertEqual(len(filtered), 2)
+ self.assertEqual(filtered, detections)
+
+ def test_filter_detections_by_allowable_set_filters_correctly(self):
+ """Test filtering with specific allowable set."""
+ detections = [
+ {"class": "person", "confidence": 0.9, "bbox": [0, 0, 10, 10]},
+ {"class": "car", "confidence": 0.8, "bbox": [20, 20, 30, 30]},
+ {"class": "bicycle", "confidence": 0.7, "bbox": [40, 40, 50, 50]},
+ {"class": "dog", "confidence": 0.6, "bbox": [60, 60, 70, 70]},
+ ]
+
+ allowable_set = ["person", "car"]
+ filtered = filter_detections_by_allowable_set(detections, allowable_set)
+
+ self.assertEqual(len(filtered), 2)
+ self.assertEqual(filtered[0]["class"], "person")
+ self.assertEqual(filtered[1]["class"], "car")
+
+ def test_filter_detections_by_allowable_set_different_class_keys(self):
+ """Test filtering with different class key names."""
+ detections = [
+ {"label": "person", "confidence": 0.9, "bbox": [0, 0, 10, 10]},
+ {"name": "car", "confidence": 0.8, "bbox": [20, 20, 30, 30]},
+ {"class": "bicycle", "confidence": 0.7, "bbox": [40, 40, 50, 50]},
+ ]
+
+ allowable_set = ["person", "car"]
+ filtered = filter_detections_by_allowable_set(detections, allowable_set)
+
+ self.assertEqual(len(filtered), 2)
+ self.assertEqual(filtered[0]["label"], "person")
+ self.assertEqual(filtered[1]["name"], "car")
+
+ def test_filter_detections_by_allowable_set_no_matches(self):
+ """Test filtering when no detections match allowable set."""
+ detections = [
+ {"class": "bicycle", "confidence": 0.7, "bbox": [40, 40, 50, 50]},
+ {"class": "dog", "confidence": 0.6, "bbox": [60, 60, 70, 70]},
+ ]
+
+ allowable_set = ["person", "car"]
+ filtered = filter_detections_by_allowable_set(detections, allowable_set)
+
+ self.assertEqual(len(filtered), 0)
+
+ def test_filter_detections_by_allowable_set_missing_class_key(self):
+ """Test filtering when detections missing class key."""
+ detections = [
+ {"confidence": 0.9, "bbox": [0, 0, 10, 10]}, # Missing class key
+ {"class": "car", "confidence": 0.8, "bbox": [20, 20, 30, 30]},
+ ]
+
+ allowable_set = ["person", "car"]
+ filtered = filter_detections_by_allowable_set(detections, allowable_set)
+
+ self.assertEqual(len(filtered), 1)
+ self.assertEqual(filtered[0]["class"], "car")
+
+ def test_filter_detections_by_allowable_set_empty_detections(self):
+ """Test filtering with empty detections list."""
+ detections = []
+ allowable_set = ["person", "car"]
+ filtered = filter_detections_by_allowable_set(detections, allowable_set)
+
+ self.assertEqual(len(filtered), 0)
+
+ def test_integration_validate_and_filter(self):
+ """Test integration of validation and filtering."""
+ # First validate allowable set
+ allowable_set = ["person", "car", "truck"]
+ is_valid, error_msg = validate_coco_objects(allowable_set)
+ self.assertTrue(is_valid)
+ self.assertIsNone(error_msg)
+
+ # Then filter detections
+ detections = [
+ {"class": "person", "confidence": 0.9, "bbox": [0, 0, 10, 10]},
+ {"class": "car", "confidence": 0.8, "bbox": [20, 20, 30, 30]},
+ {"class": "bicycle", "confidence": 0.7, "bbox": [
+ 40, 40, 50, 50]}, # Should be filtered out
+ ]
+
+ filtered = filter_detections_by_allowable_set(detections, allowable_set)
+ self.assertEqual(len(filtered), 2)
+ self.assertEqual(filtered[0]["class"], "person")
+ self.assertEqual(filtered[1]["class"], "car")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unit_tests/test_config_support.py b/tests/unit_tests/test_config_support.py
new file mode 100644
index 0000000..a01849a
--- /dev/null
+++ b/tests/unit_tests/test_config_support.py
@@ -0,0 +1,484 @@
+"""
+Unit tests for configuration support module.
+
+This module tests the configuration classes and functions for dataset generation
+including ModelConfig, WBFConfig, and DatasetGenerationConfig.
+"""
+
+import json
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+from graid.data.config_support import (
+ ConfigurationError,
+ DatasetGenerationConfig,
+ ModelConfig,
+ WBFConfig,
+ create_example_config,
+ load_config_from_dict,
+ load_config_from_file,
+ save_example_config,
+ validate_config_file,
+)
+
+
+class TestModelConfig(unittest.TestCase):
+ """Test cases for ModelConfig class."""
+
+ def test_model_config_creation_valid(self):
+ """Test creating a valid ModelConfig."""
+ config = ModelConfig(
+ backend="detectron",
+ model_name="faster_rcnn_R_50_FPN_3x",
+ confidence_threshold=0.5,
+ device="cpu",
+ custom_config={
+ "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
+ "weights": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
+ }
+ )
+ self.assertEqual(config.backend, "detectron")
+ self.assertEqual(config.model_name, "faster_rcnn_R_50_FPN_3x")
+ self.assertEqual(config.confidence_threshold, 0.5)
+ self.assertEqual(config.device, "cpu")
+
+ def test_model_config_invalid_backend(self):
+ """Test creating ModelConfig with invalid backend."""
+ with self.assertRaises(ConfigurationError) as context:
+ ModelConfig(
+ backend="invalid_backend",
+ model_name="some_model",
+ )
+ self.assertIn("Unsupported backend", str(context.exception))
+
+ def test_model_config_invalid_model_name(self):
+ """Test creating ModelConfig with invalid model name."""
+ with self.assertRaises(ConfigurationError) as context:
+ ModelConfig(
+ backend="detectron",
+ model_name="invalid_model",
+ )
+ self.assertIn(
+ "Custom config is required for detectron backend", str(context.exception))
+
+ def test_model_config_custom_config_detectron(self):
+ """Test ModelConfig with custom Detectron config."""
+ custom_config = {
+ "config": "path/to/config.yaml",
+ "weights": "path/to/weights.pth"
+ }
+ config = ModelConfig(
+ backend="detectron",
+ model_name="custom_model",
+ custom_config=custom_config
+ )
+ self.assertEqual(config.custom_config, custom_config)
+
+ def test_model_config_custom_config_invalid_detectron(self):
+ """Test ModelConfig with invalid custom Detectron config."""
+ custom_config = {"config": "path/to/config.yaml"} # Missing weights
+ with self.assertRaises(ConfigurationError) as context:
+ ModelConfig(
+ backend="detectron",
+ model_name="custom_model",
+ custom_config=custom_config
+ )
+ self.assertIn("must have 'config' and 'weights' keys",
+ str(context.exception))
+
+ def test_model_config_custom_config_mmdetection(self):
+ """Test ModelConfig with custom MMDetection config."""
+ custom_config = {
+ "config": "path/to/config.py",
+ "checkpoint": "path/to/checkpoint.pth"
+ }
+ config = ModelConfig(
+ backend="mmdetection",
+ model_name="custom_model",
+ custom_config=custom_config
+ )
+ self.assertEqual(config.custom_config, custom_config)
+
+ def test_model_config_custom_config_invalid_mmdetection(self):
+ """Test ModelConfig with invalid custom MMDetection config."""
+ custom_config = {"config": "path/to/config.py"} # Missing checkpoint
+ with self.assertRaises(ConfigurationError) as context:
+ ModelConfig(
+ backend="mmdetection",
+ model_name="custom_model",
+ custom_config=custom_config
+ )
+ self.assertIn("must have 'config' and 'checkpoint' keys",
+ str(context.exception))
+
+ def test_model_config_to_dict(self):
+ """Test ModelConfig to_dict method."""
+ custom_config = {
+ "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
+ "weights": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
+ }
+ config = ModelConfig(
+ backend="detectron",
+ model_name="faster_rcnn_R_50_FPN_3x",
+ confidence_threshold=0.7,
+ device="cuda:0",
+ custom_config=custom_config
+ )
+ result = config.to_dict()
+ expected = {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "custom_config": custom_config,
+ "confidence_threshold": 0.7,
+ "device": "cuda:0"
+ }
+ self.assertEqual(result, expected)
+
+
+class TestWBFConfig(unittest.TestCase):
+ """Test cases for WBFConfig class."""
+
+ def test_wbf_config_creation_default(self):
+ """Test creating WBFConfig with default values."""
+ config = WBFConfig()
+ self.assertEqual(config.iou_threshold, 0.55)
+ self.assertEqual(config.skip_box_threshold, 0.0)
+ self.assertIsNone(config.model_weights)
+
+ def test_wbf_config_creation_custom(self):
+ """Test creating WBFConfig with custom values."""
+ config = WBFConfig(
+ iou_threshold=0.6,
+ skip_box_threshold=0.1,
+ model_weights=[1.0, 2.0, 0.5]
+ )
+ self.assertEqual(config.iou_threshold, 0.6)
+ self.assertEqual(config.skip_box_threshold, 0.1)
+ self.assertEqual(config.model_weights, [1.0, 2.0, 0.5])
+
+ def test_wbf_config_invalid_iou_threshold(self):
+ """Test WBFConfig with invalid iou_threshold."""
+ with self.assertRaises(ConfigurationError) as context:
+ WBFConfig(iou_threshold=1.5)
+ self.assertIn("iou_threshold must be between 0.0 and 1.0",
+ str(context.exception))
+
+ with self.assertRaises(ConfigurationError) as context:
+ WBFConfig(iou_threshold=-0.1)
+ self.assertIn("iou_threshold must be between 0.0 and 1.0",
+ str(context.exception))
+
+ def test_wbf_config_invalid_skip_box_threshold(self):
+ """Test WBFConfig with invalid skip_box_threshold."""
+ with self.assertRaises(ConfigurationError) as context:
+ WBFConfig(skip_box_threshold=1.5)
+ self.assertIn(
+ "skip_box_threshold must be between 0.0 and 1.0", str(context.exception))
+
+ def test_wbf_config_invalid_model_weights(self):
+ """Test WBFConfig with invalid model_weights."""
+ with self.assertRaises(ConfigurationError) as context:
+ WBFConfig(model_weights=[1.0, -0.5, 2.0])
+ self.assertIn("All model weights must be positive",
+ str(context.exception))
+
+ def test_wbf_config_to_dict(self):
+ """Test WBFConfig to_dict method."""
+ config = WBFConfig(
+ iou_threshold=0.7,
+ skip_box_threshold=0.05,
+ model_weights=[1.0, 1.5]
+ )
+ result = config.to_dict()
+ expected = {
+ "iou_threshold": 0.7,
+ "skip_box_threshold": 0.05,
+ "model_weights": [1.0, 1.5]
+ }
+ self.assertEqual(result, expected)
+
+
+class TestDatasetGenerationConfig(unittest.TestCase):
+ """Test cases for DatasetGenerationConfig class."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ self.model_config: ModelConfig = self._create_model_config()
+
+ def _create_model_config(self) -> ModelConfig:
+ """Create a model config for testing."""
+ return ModelConfig(
+ backend="detectron",
+ model_name="faster_rcnn_R_50_FPN_3x",
+ custom_config={
+ "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
+ "weights": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
+ }
+ )
+
+ def test_dataset_generation_config_creation_valid(self):
+ """Test creating a valid DatasetGenerationConfig."""
+ config = DatasetGenerationConfig(
+ dataset_name="bdd",
+ split="val",
+ models=[self.model_config],
+ confidence_threshold=0.5,
+ batch_size=2,
+ device="cpu"
+ )
+ self.assertEqual(config.dataset_name, "bdd")
+ self.assertEqual(config.split, "val")
+ self.assertEqual(len(config.models), 1)
+ self.assertEqual(config.confidence_threshold, 0.5)
+ self.assertEqual(config.batch_size, 2)
+
+ def test_dataset_generation_config_invalid_dataset(self):
+ """Test creating DatasetGenerationConfig with invalid dataset."""
+ with self.assertRaises(ConfigurationError) as context:
+ DatasetGenerationConfig(
+ dataset_name="invalid_dataset",
+ split="val",
+ models=[self.model_config]
+ )
+ self.assertIn("Unsupported dataset", str(context.exception))
+
+ def test_dataset_generation_config_invalid_split(self):
+ """Test creating DatasetGenerationConfig with invalid split."""
+ with self.assertRaises(ConfigurationError) as context:
+ DatasetGenerationConfig(
+ dataset_name="bdd",
+ split="invalid_split",
+ models=[self.model_config]
+ )
+ self.assertIn("Invalid split", str(context.exception))
+
+ def test_dataset_generation_config_wbf_insufficient_models(self):
+ """Test WBF configuration with insufficient models."""
+ with self.assertRaises(ConfigurationError) as context:
+ DatasetGenerationConfig(
+ dataset_name="bdd",
+ split="val",
+ models=[self.model_config],
+ use_wbf=True
+ )
+ self.assertIn("WBF requires at least 2 models", str(context.exception))
+
+ def test_dataset_generation_config_wbf_weight_mismatch(self):
+ """Test WBF configuration with weight count mismatch."""
+ model_config_2 = ModelConfig(
+ backend="detectron",
+ model_name="retinanet_R_101_FPN_3x",
+ custom_config={
+ "config": "COCO-Detection/retinanet_R_101_FPN_3x.yaml",
+ "weights": "COCO-Detection/retinanet_R_101_FPN_3x.yaml"
+ }
+ )
+ # Only 1 weight for 2 models
+ wbf_config = WBFConfig(model_weights=[1.0])
+
+ with self.assertRaises(ConfigurationError) as context:
+ DatasetGenerationConfig(
+ dataset_name="bdd",
+ split="val",
+ models=[self.model_config, model_config_2],
+ use_wbf=True,
+ wbf_config=wbf_config
+ )
+ self.assertIn("Number of model weights", str(context.exception))
+
+ def test_dataset_generation_config_allowable_set_valid(self):
+ """Test DatasetGenerationConfig with valid allowable set."""
+ config = DatasetGenerationConfig(
+ dataset_name="bdd",
+ split="val",
+ models=[self.model_config],
+ allowable_set=["person", "car", "truck"]
+ )
+ self.assertEqual(config.allowable_set, ["person", "car", "truck"])
+
+ def test_dataset_generation_config_allowable_set_invalid(self):
+ """Test DatasetGenerationConfig with invalid allowable set."""
+ with self.assertRaises(ConfigurationError) as context:
+ DatasetGenerationConfig(
+ dataset_name="bdd",
+ split="val",
+ models=[self.model_config],
+ allowable_set=["person", "invalid_object"]
+ )
+ self.assertIn("Invalid COCO objects in allowable_set",
+ str(context.exception))
+
+ def test_dataset_generation_config_to_dict(self):
+ """Test DatasetGenerationConfig to_dict method."""
+ config = DatasetGenerationConfig(
+ dataset_name="bdd",
+ split="val",
+ models=[self.model_config],
+ confidence_threshold=0.6,
+ allowable_set=["person", "car"]
+ )
+ result = config.to_dict()
+
+ self.assertEqual(result["dataset_name"], "bdd")
+ self.assertEqual(result["split"], "val")
+ self.assertEqual(len(result["models"]), 1)
+ self.assertEqual(result["confidence_threshold"], 0.6)
+ self.assertEqual(result["allowable_set"], ["person", "car"])
+
+
+class TestConfigurationFunctions(unittest.TestCase):
+ """Test cases for configuration utility functions."""
+
+ def test_create_example_config(self):
+ """Test creating example configuration."""
+ example = create_example_config()
+
+ self.assertIn("dataset_name", example)
+ self.assertIn("split", example)
+ self.assertIn("models", example)
+ self.assertIn("allowable_set", example)
+ self.assertIsInstance(example["models"], list)
+ self.assertGreater(len(example["models"]), 0)
+
+ def test_load_config_from_dict_valid(self):
+ """Test loading configuration from valid dictionary."""
+ config_dict = {
+ "dataset_name": "bdd",
+ "split": "val",
+ "models": [
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "confidence_threshold": 0.5,
+ "custom_config": {
+ "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
+ "weights": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
+ }
+ }
+ ],
+ "allowable_set": ["person", "car"]
+ }
+
+ config = load_config_from_dict(config_dict)
+ self.assertEqual(config.dataset_name, "bdd")
+ self.assertEqual(config.split, "val")
+ self.assertEqual(len(config.models), 1)
+ self.assertEqual(config.allowable_set, ["person", "car"])
+
+ def test_load_config_from_file_valid(self):
+ """Test loading configuration from valid file."""
+ config_dict = {
+ "dataset_name": "bdd",
+ "split": "val",
+ "models": [
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "custom_config": {
+ "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
+ "weights": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
+ }
+ }
+ ]
+ }
+
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+ json.dump(config_dict, f)
+ temp_path = f.name
+
+ try:
+ config = load_config_from_file(temp_path)
+ self.assertEqual(config.dataset_name, "bdd")
+ self.assertEqual(config.split, "val")
+ self.assertEqual(len(config.models), 1)
+ finally:
+ Path(temp_path).unlink()
+
+ def test_load_config_from_file_invalid_json(self):
+ """Test loading configuration from invalid JSON file."""
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+ f.write("invalid json content")
+ temp_path = f.name
+
+ try:
+ with self.assertRaises(ConfigurationError) as context:
+ load_config_from_file(temp_path)
+ self.assertIn("Invalid JSON", str(context.exception))
+ finally:
+ Path(temp_path).unlink()
+
+ def test_load_config_from_file_nonexistent(self):
+ """Test loading configuration from non-existent file."""
+ with self.assertRaises(ConfigurationError):
+ load_config_from_file("nonexistent_file.json")
+
+ def test_save_example_config(self):
+ """Test saving example configuration to file."""
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+ temp_path = f.name
+
+ try:
+ save_example_config(temp_path)
+
+ # Verify file was created and contains valid JSON
+ with open(temp_path, 'r') as f:
+ loaded_config = json.load(f)
+
+ self.assertIn("dataset_name", loaded_config)
+ self.assertIn("models", loaded_config)
+ self.assertIsInstance(loaded_config["models"], list)
+ finally:
+ Path(temp_path).unlink()
+
+ def test_validate_config_file_valid(self):
+ """Test validating a valid configuration file."""
+ config_dict = {
+ "dataset_name": "bdd",
+ "split": "val",
+ "models": [
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "custom_config": {
+ "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
+ "weights": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
+ }
+ }
+ ]
+ }
+
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+ json.dump(config_dict, f)
+ temp_path = f.name
+
+ try:
+ is_valid, error_msg = validate_config_file(temp_path)
+ self.assertTrue(is_valid)
+ self.assertIsNone(error_msg)
+ finally:
+ Path(temp_path).unlink()
+
+ def test_validate_config_file_invalid(self):
+ """Test validating an invalid configuration file."""
+ config_dict = {
+ "dataset_name": "invalid_dataset",
+ "split": "val",
+ "models": []
+ }
+
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+ json.dump(config_dict, f)
+ temp_path = f.name
+
+ try:
+ is_valid, error_msg = validate_config_file(temp_path)
+ self.assertFalse(is_valid)
+ self.assertIsNotNone(error_msg)
+ finally:
+ Path(temp_path).unlink()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unit_tests/test_dataset_generation.py b/tests/unit_tests/test_dataset_generation.py
new file mode 100644
index 0000000..19d1bc2
--- /dev/null
+++ b/tests/unit_tests/test_dataset_generation.py
@@ -0,0 +1,391 @@
+"""
+Unit tests for dataset generation functionality.
+
+This module tests the core dataset generation functionality including
+the HuggingFaceDatasetBuilder class and generate_dataset function.
+"""
+
+import unittest
+from unittest.mock import Mock, patch, MagicMock
+import torch
+from PIL import Image
+import numpy as np
+
+from graid.data.generate_dataset import (
+ HuggingFaceDatasetBuilder,
+ generate_dataset,
+ validate_model_config,
+ validate_models_batch,
+ validate_wbf_compatibility,
+ list_available_models,
+)
+
+
+class TestHuggingFaceDatasetBuilder(unittest.TestCase):
+ """Test cases for HuggingFaceDatasetBuilder class."""
+
+ def test_init_valid_parameters(self):
+ """Test initialization with valid parameters."""
+ builder = HuggingFaceDatasetBuilder(
+ dataset_name="bdd",
+ split="val",
+ models=[],
+ conf_threshold=0.5,
+ batch_size=1,
+ device="cpu"
+ )
+ self.assertEqual(builder.dataset_name, "bdd")
+ self.assertEqual(builder.split, "val")
+ self.assertEqual(builder.conf_threshold, 0.5)
+ self.assertEqual(builder.batch_size, 1)
+
+ def test_init_invalid_dataset(self):
+ """Test initialization with invalid dataset name."""
+ with self.assertRaises(ValueError) as context:
+ HuggingFaceDatasetBuilder(
+ dataset_name="invalid_dataset",
+ split="val",
+ models=[]
+ )
+ self.assertIn("Unsupported dataset", str(context.exception))
+
+ def test_init_with_allowable_set_valid(self):
+ """Test initialization with valid allowable set."""
+ builder = HuggingFaceDatasetBuilder(
+ dataset_name="bdd",
+ split="val",
+ models=[],
+ allowable_set=["person", "car", "truck"]
+ )
+ self.assertEqual(builder.allowable_set, ["person", "car", "truck"])
+
+ @patch('graid.data.generate_dataset.validate_coco_objects')
+ def test_init_with_allowable_set_invalid(self, mock_validate):
+ """Test initialization with invalid allowable set."""
+ mock_validate.return_value = (False, "Invalid objects")
+
+ with self.assertRaises(ValueError) as context:
+ HuggingFaceDatasetBuilder(
+ dataset_name="bdd",
+ split="val",
+ models=[],
+ allowable_set=["person", "invalid_object"]
+ )
+ self.assertIn("Invalid allowable_set", str(context.exception))
+
+ def test_convert_image_to_pil_tensor(self):
+ """Test converting tensor to PIL image."""
+ builder = HuggingFaceDatasetBuilder(
+ dataset_name="bdd",
+ split="val",
+ models=[]
+ )
+
+ # Create a dummy tensor (3, 224, 224) with values 0-1
+ tensor = torch.rand(3, 224, 224)
+
+ pil_image = builder._convert_image_to_pil(tensor)
+
+ self.assertIsInstance(pil_image, Image.Image)
+ self.assertEqual(pil_image.size, (224, 224))
+
+ def test_convert_image_to_pil_numpy(self):
+ """Test converting numpy array to PIL image."""
+ builder = HuggingFaceDatasetBuilder(
+ dataset_name="bdd",
+ split="val",
+ models=[]
+ )
+
+ # Create a dummy numpy array (224, 224, 3) with values 0-255
+ numpy_array = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
+
+ pil_image = builder._convert_image_to_pil(numpy_array)
+
+ self.assertIsInstance(pil_image, Image.Image)
+ self.assertEqual(pil_image.size, (224, 224))
+
+ def test_create_metadata(self):
+ """Test metadata creation."""
+ builder = HuggingFaceDatasetBuilder(
+ dataset_name="bdd",
+ split="val",
+ models=[],
+ conf_threshold=0.3,
+ batch_size=2,
+ device="cpu"
+ )
+
+ metadata = builder._create_metadata()
+
+ self.assertEqual(metadata["dataset_name"], "bdd")
+ self.assertEqual(metadata["split"], "val")
+ self.assertEqual(metadata["confidence_threshold"], 0.3)
+ self.assertEqual(metadata["batch_size"], 2)
+ self.assertEqual(metadata["device"], "cpu")
+ self.assertIn("questions", metadata)
+ self.assertIn("models", metadata)
+
+
+class TestGenerateDataset(unittest.TestCase):
+ """Test cases for generate_dataset function."""
+
+ @patch('graid.data.generate_dataset.HuggingFaceDatasetBuilder')
+ def test_generate_dataset_basic(self, mock_builder_class):
+ """Test basic dataset generation."""
+ mock_builder = Mock()
+ mock_dataset = Mock()
+ mock_builder.build.return_value = mock_dataset
+ mock_builder_class.return_value = mock_builder
+
+ result = generate_dataset(
+ dataset_name="bdd",
+ split="val",
+ models=[],
+ conf_threshold=0.5
+ )
+
+ # Verify builder was created with correct parameters
+ mock_builder_class.assert_called_once()
+ call_args = mock_builder_class.call_args
+ self.assertEqual(call_args[1]["dataset_name"], "bdd")
+ self.assertEqual(call_args[1]["split"], "val")
+ self.assertEqual(call_args[1]["conf_threshold"], 0.5)
+
+ # Verify build was called
+ mock_builder.build.assert_called_once()
+ self.assertEqual(result, mock_dataset)
+
+ @patch('graid.data.generate_dataset.HuggingFaceDatasetBuilder')
+ def test_generate_dataset_with_allowable_set(self, mock_builder_class):
+ """Test dataset generation with allowable set."""
+ mock_builder = Mock()
+ mock_dataset = Mock()
+ mock_builder.build.return_value = mock_dataset
+ mock_builder_class.return_value = mock_builder
+
+ allowable_set = ["person", "car", "truck"]
+ result = generate_dataset(
+ dataset_name="bdd",
+ split="val",
+ models=[],
+ allowable_set=allowable_set
+ )
+
+ # Verify builder was created with allowable_set
+ call_args = mock_builder_class.call_args
+ self.assertEqual(call_args[1]["allowable_set"], allowable_set)
+
+
+class TestValidationFunctions(unittest.TestCase):
+ """Test cases for model validation functions."""
+
+ def test_validate_model_config_valid_detectron(self):
+ """Test model validation with valid Detectron model."""
+ with patch('graid.data.generate_dataset.create_model') as mock_create:
+ mock_model = Mock()
+ mock_create.return_value = mock_model
+
+ is_valid, error_msg = validate_model_config(
+ backend="detectron",
+ model_name="faster_rcnn_R_50_FPN_3x",
+ device="cpu"
+ )
+
+ self.assertTrue(is_valid)
+ self.assertIsNone(error_msg)
+ mock_create.assert_called_once()
+
+ def test_validate_model_config_invalid_backend(self):
+ """Test model validation with invalid backend."""
+ is_valid, error_msg = validate_model_config(
+ backend="invalid_backend",
+ model_name="some_model",
+ device="cpu"
+ )
+
+ self.assertFalse(is_valid)
+ self.assertIsNotNone(error_msg)
+ self.assertIn("Unsupported backend", error_msg)
+
+ def test_validate_model_config_invalid_model_name(self):
+ """Test model validation with invalid model name."""
+ is_valid, error_msg = validate_model_config(
+ backend="detectron",
+ model_name="invalid_model",
+ device="cpu"
+ )
+
+ self.assertFalse(is_valid)
+ self.assertIsNotNone(error_msg)
+ self.assertIn("Detectron backend requires custom_config", error_msg)
+
+ def test_validate_models_batch_valid(self):
+ """Test batch validation with valid models."""
+ model_configs = [
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "config": None
+ },
+ {
+ "backend": "detectron",
+ "model_name": "retinanet_R_101_FPN_3x",
+ "config": None
+ }
+ ]
+
+ with patch('graid.data.generate_dataset.validate_model_config') as mock_validate:
+ mock_validate.return_value = (True, None)
+
+ results = validate_models_batch(model_configs, device="cpu")
+
+ self.assertEqual(len(results), 2)
+ for key, (is_valid, error_msg) in results.items():
+ self.assertTrue(is_valid)
+ self.assertIsNone(error_msg)
+
+ def test_validate_models_batch_mixed_results(self):
+ """Test batch validation with mixed results."""
+ model_configs = [
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "config": None
+ },
+ {
+ "backend": "invalid_backend",
+ "model_name": "some_model",
+ "config": None
+ }
+ ]
+
+ def mock_validate_side_effect(backend, model_name, config=None, device=None):
+ if backend == "detectron":
+ return (True, None)
+ else:
+ return (False, "Invalid backend")
+
+ with patch('graid.data.generate_dataset.validate_model_config') as mock_validate:
+ mock_validate.side_effect = mock_validate_side_effect
+
+ results = validate_models_batch(model_configs, device="cpu")
+
+ self.assertEqual(len(results), 2)
+ # Check that we have both valid and invalid results
+ valid_count = sum(1 for is_valid, _ in results.values() if is_valid)
+ invalid_count = sum(
+ 1 for is_valid, _ in results.values() if not is_valid)
+ self.assertEqual(valid_count, 1)
+ self.assertEqual(invalid_count, 1)
+
+ def test_validate_wbf_compatibility_valid(self):
+ """Test WBF compatibility validation with valid models."""
+ model_configs = [
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "custom_config": {
+ "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
+ "weights": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
+ }
+ },
+ {
+ "backend": "detectron",
+ "model_name": "retinanet_R_101_FPN_3x",
+ "custom_config": {
+ "config": "COCO-Detection/retinanet_R_101_FPN_3x.yaml",
+ "weights": "COCO-Detection/retinanet_R_101_FPN_3x.yaml"
+ }
+ }
+ ]
+
+ with patch('graid.data.generate_dataset.validate_models_batch') as mock_validate:
+ with patch('graid.data.generate_dataset.create_model') as mock_create_model:
+ with patch('graid.data.generate_dataset.WBF') as mock_wbf:
+ # Mock all the components
+ mock_validate.return_value = {
+ "model_0": (True, None),
+ "model_1": (True, None)
+ }
+
+ mock_model1 = Mock()
+ mock_model2 = Mock()
+ mock_create_model.side_effect = [mock_model1, mock_model2]
+
+ mock_wbf_instance = Mock()
+ mock_wbf.return_value = mock_wbf_instance
+ mock_wbf_instance.identify_for_image_batch.return_value = []
+
+ is_valid, error_msg = validate_wbf_compatibility(
+ model_configs, device="cpu")
+
+ self.assertTrue(is_valid)
+ self.assertIsNone(error_msg)
+
+ def test_validate_wbf_compatibility_insufficient_models(self):
+ """Test WBF compatibility validation with insufficient models."""
+ model_configs = [
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "config": None
+ }
+ ]
+
+ is_valid, error_msg = validate_wbf_compatibility(
+ model_configs, device="cpu")
+
+ self.assertFalse(is_valid)
+ self.assertIsNotNone(error_msg)
+ self.assertIn("at least 2 models", error_msg)
+
+ def test_validate_wbf_compatibility_invalid_models(self):
+ """Test WBF compatibility validation with invalid models."""
+ model_configs = [
+ {
+ "backend": "detectron",
+ "model_name": "faster_rcnn_R_50_FPN_3x",
+ "config": None
+ },
+ {
+ "backend": "invalid_backend",
+ "model_name": "some_model",
+ "config": None
+ }
+ ]
+
+ with patch('graid.data.generate_dataset.validate_models_batch') as mock_validate:
+ mock_validate.return_value = {
+ "model_0": (True, None),
+ "model_1": (False, "Invalid backend")
+ }
+
+ is_valid, error_msg = validate_wbf_compatibility(
+ model_configs, device="cpu")
+
+ self.assertFalse(is_valid)
+ self.assertIsNotNone(error_msg)
+ self.assertIn("Some models failed validation", error_msg)
+
+
+class TestUtilityFunctions(unittest.TestCase):
+ """Test cases for utility functions."""
+
+ def test_list_available_models(self):
+ """Test listing available models."""
+ models = list_available_models()
+
+ self.assertIsInstance(models, dict)
+ self.assertIn("detectron", models)
+ self.assertIn("mmdetection", models)
+ self.assertIn("ultralytics", models)
+
+ # Check that each backend has a list of models
+ for backend, model_list in models.items():
+ self.assertIsInstance(model_list, list)
+ self.assertGreater(len(model_list), 0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/graid/tests/test_detectron_seg_batch.py b/tests/unit_tests/test_detectron_seg_batch.py
similarity index 100%
rename from graid/tests/test_detectron_seg_batch.py
rename to tests/unit_tests/test_detectron_seg_batch.py
diff --git a/tests/unit_tests/test_generate_db.py b/tests/unit_tests/test_generate_db.py
index 0c67fbf..eec908b 100644
--- a/tests/unit_tests/test_generate_db.py
+++ b/tests/unit_tests/test_generate_db.py
@@ -4,6 +4,11 @@
Tests the import fixes and functionality of the database generation system.
"""
+from graid.data.generate_db import (
+ create_model,
+ generate_db,
+ list_available_models,
+)
import sys
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
@@ -16,13 +21,6 @@
project_root = Path(__file__).parent.parent / "graid"
sys.path.insert(0, str(project_root))
-from graid.data.generate_db import (
- MODEL_CONFIGS,
- create_model,
- generate_db,
- list_available_models,
-)
-
class TestGenerateDbImportFix:
"""Test the generate_db import and functionality fixes."""
@@ -112,27 +110,6 @@ def test_list_available_models(self):
assert isinstance(models[backend], list)
assert len(models[backend]) > 0
- def test_model_configs_structure(self):
- """Test that MODEL_CONFIGS has proper structure."""
- # Test structure
- assert isinstance(MODEL_CONFIGS, dict)
-
- for backend, models in MODEL_CONFIGS.items():
- assert isinstance(models, dict)
- for model_name, config in models.items():
- # Different backends have different config structures
- if backend == "detectron":
- # Should have config and weights
- assert "config" in config
- assert "weights" in config
- elif backend == "mmdetection":
- # Should have config and checkpoint
- assert "config" in config
- assert "checkpoint" in config
- elif backend == "ultralytics":
- # Should be a string path to model file
- assert isinstance(config, str)
-
def test_create_model_function_exists(self):
"""Test that create_model function can be imported and called."""
# Mock dependencies for create_model
@@ -141,8 +118,13 @@ def test_create_model_function_exists(self):
mock_model = Mock()
mock_detectron.return_value = mock_model
- # Test create_model
- result = create_model("detectron", "faster_rcnn_R_50_FPN_3x", "cpu")
+ # Test create_model with custom_config
+ custom_config = {
+ "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
+ "weights": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
+ }
+ result = create_model(
+ "detectron", "faster_rcnn_R_50_FPN_3x", "cpu", custom_config=custom_config)
assert result == mock_model
@@ -167,7 +149,8 @@ def mock_build(**kwargs):
mock_builder_class.return_value = mock_builder
# Test database generation
- result = generate_db(dataset_name="bdd", split="val", conf=0.2, batch_size=50)
+ result = generate_db(dataset_name="bdd", split="val",
+ conf=0.2, batch_size=50)
# Verify successful completion
assert result == "bdd_val_gt"
diff --git a/tests/unit_tests/test_imageloader.py b/tests/unit_tests/test_imageloader.py
index 1ac96d3..f50c896 100644
--- a/tests/unit_tests/test_imageloader.py
+++ b/tests/unit_tests/test_imageloader.py
@@ -4,6 +4,7 @@
Tests the transform function signature fix where transform should receive both image and labels.
"""
+from graid.data.ImageLoader import ImageDataset
import sys
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
@@ -16,8 +17,6 @@
project_root = Path(__file__).parent.parent / "graid"
sys.path.insert(0, str(project_root))
-from graid.data.ImageLoader import ImageDataset
-
class TestImageLoaderTransformFix:
"""Test the ImageLoader transform bug fix."""
@@ -31,7 +30,8 @@ def test_transform_receives_both_image_and_labels(self):
with patch("graid.data.ImageLoader.ImageDataset.__init__", return_value=None):
dataset = ImageDataset.__new__(ImageDataset)
dataset.transform = mock_transform
- dataset.data = [{"image_path": "test.jpg", "labels": [{"category": "car"}]}]
+ dataset.data = [{"image_path": "test.jpg",
+ "labels": [{"category": "car"}]}]
dataset.dataset_name = "test"
# Mock image loading
@@ -85,42 +85,11 @@ def mock_transform(image, labels):
# Test passes if no exception is raised and labels are converted to empty list
assert len(result) == 2
assert isinstance(result[0], np.ndarray)
- assert result[1] == [] # None labels should be converted to empty list
+ # None labels should be converted to empty list
+ assert result[1] == []
except TypeError as e:
if "'NoneType' object is not iterable" in str(e):
pytest.fail("Transform should handle None labels properly")
else:
# Re-raise if it's a different TypeError
raise
-
- def test_dataset_transforms_lambda_functions(self):
- """Test that dataset has proper transform lambda functions defined."""
- from graid.data.generate_db import MODEL_CONFIGS
-
- # Test that all model configs have proper transform functions
- for backend, models in MODEL_CONFIGS.items():
- for model_name, config in models.items():
- if "transforms" in config:
- transforms = config["transforms"]
-
- # Test that transform functions accept both image and labels
- mock_image = np.zeros((100, 100, 3), dtype=np.uint8)
- mock_labels = [{"category": "car", "bbox": [10, 10, 50, 50]}]
-
- try:
- if "train" in transforms:
- result = transforms["train"](mock_image, mock_labels)
- assert (
- len(result) == 2
- ), f"Train transform for {backend}.{model_name} should return (image, labels)"
-
- if "val" in transforms:
- result = transforms["val"](mock_image, mock_labels)
- assert (
- len(result) == 2
- ), f"Val transform for {backend}.{model_name} should return (image, labels)"
-
- except Exception as e:
- pytest.fail(
- f"Transform for {backend}.{model_name} failed: {str(e)}"
- )
diff --git a/graid/tests/test_integration.py b/tests/unit_tests/test_integration.py
similarity index 97%
rename from graid/tests/test_integration.py
rename to tests/unit_tests/test_integration.py
index 0976182..86cb399 100644
--- a/graid/tests/test_integration.py
+++ b/tests/unit_tests/test_integration.py
@@ -5,6 +5,16 @@
system, including end-to-end workflows and external integrations like WandB.
"""
+from graid.interfaces.ObjectDetectionI import ObjectDetectionResultI
+from graid.data.validation.human_supervised_filter import (
+ HumanSupervisedClassifier,
+ HumanSupervisedFilter,
+)
+from graid.data.validation import (
+ ComprehensiveDetectionValidator,
+ ValidationConfig,
+ ValidationStage,
+)
import json
import os
import shutil
@@ -18,18 +28,8 @@
# Add graid to path
sys.path.append(str(Path(__file__).parent.parent / "src"))
-from graid.data.validation import (
- ComprehensiveDetectionValidator,
- ValidationConfig,
- ValidationStage,
-)
-from graid.data.validation.human_supervised_filter import (
- HumanSupervisedClassifier,
- HumanSupervisedFilter,
-)
-from graid.interfaces.ObjectDetectionI import ObjectDetectionResultI
-
+@unittest.skip("Skipping integration tests")
class TestEndToEndValidation(unittest.TestCase):
"""Test end-to-end validation workflows."""
@@ -168,6 +168,7 @@ def test_error_handling_in_pipeline(self):
self.assertIsInstance(e, Exception)
+@unittest.skip("Skipping integration tests")
class TestHumanSupervisedIntegration(unittest.TestCase):
"""Test integration of human supervised validation (Phase 6)."""
@@ -263,6 +264,7 @@ def test_human_supervised_filter_integration(self):
self.assertLessEqual(result.confidence, 1.0)
+@unittest.skip("Skipping integration tests")
class TestConfigurationManagement(unittest.TestCase):
"""Test configuration management and persistence."""
@@ -303,7 +305,8 @@ def test_config_serialization(self):
def test_config_validation(self):
"""Test configuration validation."""
# Test valid config
- valid_config = ValidationConfig(min_detection_confidence=0.5, device="cpu")
+ valid_config = ValidationConfig(
+ min_detection_confidence=0.5, device="cpu")
self.assertIsNotNone(valid_config)
# Test edge cases
@@ -336,14 +339,17 @@ def test_multiple_validators(self):
)
fast_validator = ComprehensiveDetectionValidator(fast_config)
- comprehensive_validator = ComprehensiveDetectionValidator(comprehensive_config)
+ comprehensive_validator = ComprehensiveDetectionValidator(
+ comprehensive_config)
# Verify they're independent
self.assertNotEqual(id(fast_validator), id(comprehensive_validator))
self.assertEqual(fast_validator.config.require_all_stages, False)
- self.assertEqual(comprehensive_validator.config.require_all_stages, True)
+ self.assertEqual(
+ comprehensive_validator.config.require_all_stages, True)
+@unittest.skip("Skipping integration tests")
class TestPerformanceAndScaling(unittest.TestCase):
"""Test performance characteristics and scaling behavior."""
diff --git a/graid/tests/test_neural_classifiers.py b/tests/unit_tests/test_neural_classifiers.py
similarity index 86%
rename from graid/tests/test_neural_classifiers.py
rename to tests/unit_tests/test_neural_classifiers.py
index 534513d..11d9c46 100644
--- a/graid/tests/test_neural_classifiers.py
+++ b/tests/unit_tests/test_neural_classifiers.py
@@ -5,56 +5,57 @@
of the validation pipeline, including multimodal models combining vision and text.
"""
-import unittest
+import shutil
import sys
-from pathlib import Path
import tempfile
-import shutil
-from PIL import Image
+import unittest
+from pathlib import Path
+
import numpy as np
+from PIL import Image
# Add graid to path
sys.path.append(str(Path(__file__).parent.parent / "src"))
-from graid.data.validation.neural_classifiers import MultimodalValidationClassifier
from graid.data.validation.human_supervised_filter import (
+ HumanEvaluationSample,
HumanSupervisedClassifier,
HumanSupervisionDataLoader,
- HumanEvaluationSample
)
+from graid.data.validation.neural_classifiers import MultimodalValidationClassifier
class TestMultimodalValidationClassifier(unittest.TestCase):
"""Test cases for the multimodal validation classifier."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.sample_image_path = self._create_sample_image()
-
+
def tearDown(self):
"""Clean up test fixtures."""
shutil.rmtree(self.temp_dir)
-
+
def _create_sample_image(self):
"""Create a sample image for testing."""
- image = Image.new('RGB', (224, 224), color='blue')
+ image = Image.new("RGB", (224, 224), color="blue")
image_path = Path(self.temp_dir) / "test_image.jpg"
image.save(image_path)
return str(image_path)
-
+
def test_classifier_initialization(self):
"""Test classifier initialization with default parameters."""
classifier = MultimodalValidationClassifier(
vit_model_name="google/vit-base-patch16-224",
text_model_name="sentence-transformers/all-MiniLM-L6-v2",
- device="cpu"
+ device="cpu",
)
-
+
self.assertIsNotNone(classifier)
self.assertEqual(classifier.device, "cpu")
self.assertFalse(classifier.is_trained)
-
+
def test_classifier_initialization_custom_params(self):
"""Test classifier initialization with custom parameters."""
classifier = MultimodalValidationClassifier(
@@ -63,127 +64,132 @@ def test_classifier_initialization_custom_params(self):
learning_rate=0.001,
n_epochs=10,
batch_size=8,
- device="cpu"
+ device="cpu",
)
-
+
self.assertEqual(classifier.hidden_dims, [128, 64])
self.assertEqual(classifier.dropout_rate, 0.5)
self.assertEqual(classifier.learning_rate, 0.001)
self.assertEqual(classifier.n_epochs, 10)
self.assertEqual(classifier.batch_size, 8)
-
+
def test_training_interface(self):
"""Test training interface without actual training."""
classifier = MultimodalValidationClassifier(device="cpu")
-
+
# Mock training data
image_paths = [self.sample_image_path] * 4
- questions = ["How many cars?", "What color is the car?", "Is there a person?", "Any traffic lights?"]
+ questions = [
+ "How many cars?",
+ "What color is the car?",
+ "Is there a person?",
+ "Any traffic lights?",
+ ]
answers = ["2", "blue", "yes", "no"]
labels = [1, 1, 0, 1] # 1 = valid, 0 = invalid
-
+
# Test training call
result = classifier.fit(
image_paths=image_paths,
questions=questions,
answers=answers,
labels=labels,
- validation_split=0.0
+ validation_split=0.0,
)
-
+
# Check training completed
self.assertTrue(classifier.is_trained)
- self.assertIn('final_train_accuracy', result)
- self.assertIn('train_history', result)
-
+ self.assertIn("final_train_accuracy", result)
+ self.assertIn("train_history", result)
+
def test_prediction_interface(self):
"""Test prediction interface."""
classifier = MultimodalValidationClassifier(device="cpu")
-
+
# Train first
image_paths = [self.sample_image_path] * 2
questions = ["How many cars?", "What color?"]
answers = ["2", "blue"]
labels = [1, 0]
-
+
classifier.fit(image_paths, questions, answers, labels)
-
+
# Test prediction
prediction = classifier.predict_single(
image_path=self.sample_image_path,
question="How many cars are there?",
- answer="3"
+ answer="3",
)
-
+
self.assertIsInstance(prediction, float)
self.assertGreaterEqual(prediction, 0.0)
self.assertLessEqual(prediction, 1.0)
-
+
def test_prediction_without_training(self):
"""Test that prediction fails without training."""
classifier = MultimodalValidationClassifier(device="cpu")
-
+
with self.assertRaises(ValueError):
classifier.predict_single(
image_path=self.sample_image_path,
question="Test question",
- answer="Test answer"
+ answer="Test answer",
)
-
+
def test_model_save_load_interface(self):
"""Test model save/load interface."""
classifier = MultimodalValidationClassifier(device="cpu")
-
+
# Train first
image_paths = [self.sample_image_path] * 2
questions = ["Question 1", "Question 2"]
answers = ["Answer 1", "Answer 2"]
labels = [1, 0]
-
+
classifier.fit(image_paths, questions, answers, labels)
-
+
# Test save
model_path = Path(self.temp_dir) / "test_model.pth"
classifier.save_model(model_path)
-
+
# Test load
new_classifier = MultimodalValidationClassifier(device="cpu")
new_classifier.load_model(model_path)
-
+
self.assertTrue(new_classifier.is_trained)
-
+
def test_save_without_training(self):
"""Test that save fails without training."""
classifier = MultimodalValidationClassifier(device="cpu")
-
+
with self.assertRaises(ValueError):
classifier.save_model("dummy_path.pth")
class TestHumanSupervisedClassifier(unittest.TestCase):
"""Test cases for the human supervised classifier."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.sample_image_path = self._create_sample_image()
self._create_sample_labels_file()
-
+
def tearDown(self):
"""Clean up test fixtures."""
shutil.rmtree(self.temp_dir)
-
+
def _create_sample_image(self):
"""Create a sample image for testing."""
- image = Image.new('RGB', (224, 224), color='red')
+ image = Image.new("RGB", (224, 224), color="red")
image_path = Path(self.temp_dir) / "sample.jpg"
image.save(image_path)
return str(image_path)
-
+
def _create_sample_labels_file(self):
"""Create a sample labels JSON file."""
import json
-
+
labels_data = {
"samples": [
{
@@ -192,7 +198,7 @@ def _create_sample_labels_file(self):
"answer": "3",
"human_evaluation": "yes",
"evaluator_id": "test_evaluator",
- "confidence": 0.9
+ "confidence": 0.9,
},
{
"image_path": self.sample_image_path,
@@ -200,87 +206,80 @@ def _create_sample_labels_file(self):
"answer": "yes",
"human_evaluation": "no", # This should be invalid
"evaluator_id": "test_evaluator",
- "confidence": 0.8
- }
+ "confidence": 0.8,
+ },
]
}
-
+
labels_file = Path(self.temp_dir) / "test_labels.json"
- with open(labels_file, 'w') as f:
+ with open(labels_file, "w") as f:
json.dump(labels_data, f)
-
+
return str(labels_file)
-
+
def test_classifier_initialization(self):
"""Test human supervised classifier initialization."""
classifier = HumanSupervisedClassifier()
-
+
self.assertIsNotNone(classifier)
self.assertFalse(classifier.is_trained)
-
+
def test_classifier_with_model_config(self):
"""Test classifier initialization with model config."""
- model_config = {
- "learning_rate": 0.001,
- "batch_size": 32
- }
-
- classifier = HumanSupervisedClassifier(
- model_config=model_config
- )
-
+ model_config = {"learning_rate": 0.001, "batch_size": 32}
+
+ classifier = HumanSupervisedClassifier(model_config=model_config)
+
self.assertEqual(classifier.model_config["learning_rate"], 0.001)
self.assertEqual(classifier.model_config["batch_size"], 32)
-
+
def test_data_loading_and_training(self):
"""Test loading data and training the classifier."""
classifier = HumanSupervisedClassifier()
-
+
# Mock the training process with minimal data
try:
results = classifier.load_and_train(
labels_dir=self.temp_dir,
test_size=0.0, # No test split for small dataset
- val_size=0.0, # No validation split
+ val_size=0.0, # No validation split
min_samples_per_class=1, # Reduce requirement
- image_base_path=None
+ image_base_path=None,
)
-
+
self.assertTrue(classifier.is_trained)
- self.assertIn('training_results', results)
-
+ self.assertIn("training_results", results)
+
except ValueError as e:
# Expected for insufficient data
self.assertIn("samples", str(e).lower())
-
+
def test_prediction_interface(self):
"""Test prediction interface."""
classifier = HumanSupervisedClassifier()
-
+
# Mock training
classifier.is_trained = True
-
+
# Test prediction
prediction = classifier.predict_validity(
- image=self.sample_image_path,
- question="Test question",
- answer="Test answer"
+ image=self.sample_image_path, question="Test question", answer="Test answer"
)
-
+
self.assertIsInstance(prediction, float)
self.assertGreaterEqual(prediction, 0.0)
self.assertLessEqual(prediction, 1.0)
-
+
def test_model_persistence(self):
"""Test model save/load functionality."""
classifier = HumanSupervisedClassifier()
classifier.is_trained = True # Mock training
-
+
# Test save
model_path = Path(self.temp_dir) / "human_model.pkl"
classifier.save_model(model_path)
self.assertTrue(model_path.exists())
-
+
# Test load
new_classifier = HumanSupervisedClassifier()
new_classifier.load_model(model_path)
@@ -289,20 +288,20 @@ def test_model_persistence(self):
class TestHumanSupervisionDataLoader(unittest.TestCase):
"""Test cases for the human supervision data loader."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self._create_test_data_files()
-
+
def tearDown(self):
"""Clean up test fixtures."""
shutil.rmtree(self.temp_dir)
-
+
def _create_test_data_files(self):
"""Create test data files in different formats."""
import json
-
+
# New format
new_format_data = {
"samples": [
@@ -311,15 +310,15 @@ def _create_test_data_files(self):
"question": "Question 1",
"answer": "Answer 1",
"human_evaluation": "yes",
- "evaluator_id": "evaluator1"
+ "evaluator_id": "evaluator1",
}
]
}
-
+
new_format_file = Path(self.temp_dir) / "new_format.json"
- with open(new_format_file, 'w') as f:
+ with open(new_format_file, "w") as f:
json.dump(new_format_data, f)
-
+
# Legacy format
legacy_format_data = {
"evaluations": {
@@ -328,28 +327,28 @@ def _create_test_data_files(self):
"question": "Question 2",
"answer": "Answer 2",
"human_evaluation": "no",
- "evaluator": "evaluator2"
+ "evaluator": "evaluator2",
}
}
}
-
+
legacy_format_file = Path(self.temp_dir) / "legacy_format.json"
- with open(legacy_format_file, 'w') as f:
+ with open(legacy_format_file, "w") as f:
json.dump(legacy_format_data, f)
-
+
def test_data_loader_initialization(self):
"""Test data loader initialization."""
loader = HumanSupervisionDataLoader(self.temp_dir)
self.assertEqual(loader.labels_dir, Path(self.temp_dir))
-
+
def test_load_manual_evaluations(self):
"""Test loading manual evaluations from JSON files."""
loader = HumanSupervisionDataLoader(self.temp_dir)
samples = loader.load_manual_evaluations()
-
+
# Should load from both new and legacy format files
self.assertGreater(len(samples), 0)
-
+
# Check sample structure
for sample in samples:
self.assertIsInstance(sample, HumanEvaluationSample)
@@ -357,29 +356,29 @@ def test_load_manual_evaluations(self):
self.assertIsInstance(sample.question, str)
self.assertIsInstance(sample.answer, str)
self.assertIn(sample.human_evaluation, ["yes", "no"])
-
+
def test_load_specific_datasets(self):
"""Test loading specific datasets."""
loader = HumanSupervisionDataLoader(self.temp_dir)
-
+
# Test loading specific dataset
samples = loader.load_manual_evaluations(datasets=["new_format"])
-
+
# Should only load from files containing "new_format"
self.assertGreater(len(samples), 0)
-
+
def test_load_from_nonexistent_directory(self):
"""Test loading from non-existent directory."""
loader = HumanSupervisionDataLoader("/nonexistent/path")
samples = loader.load_manual_evaluations()
-
+
# Should return empty list
self.assertEqual(len(samples), 0)
class TestHumanEvaluationSample(unittest.TestCase):
"""Test cases for the HumanEvaluationSample data structure."""
-
+
def test_sample_creation(self):
"""Test creating human evaluation samples."""
sample = HumanEvaluationSample(
@@ -389,9 +388,9 @@ def test_sample_creation(self):
human_evaluation="yes",
evaluator_id="test_evaluator",
confidence=0.9,
- metadata={"source": "test"}
+ metadata={"source": "test"},
)
-
+
self.assertEqual(sample.image_path, "test.jpg")
self.assertEqual(sample.question, "Test question")
self.assertEqual(sample.answer, "Test answer")
@@ -399,16 +398,16 @@ def test_sample_creation(self):
self.assertEqual(sample.evaluator_id, "test_evaluator")
self.assertEqual(sample.confidence, 0.9)
self.assertEqual(sample.metadata["source"], "test")
-
+
def test_sample_with_minimal_data(self):
"""Test creating samples with minimal required data."""
sample = HumanEvaluationSample(
image_path="minimal.jpg",
question="Minimal question",
answer="Minimal answer",
- human_evaluation="no"
+ human_evaluation="no",
)
-
+
self.assertEqual(sample.image_path, "minimal.jpg")
self.assertEqual(sample.human_evaluation, "no")
self.assertIsNone(sample.evaluator_id)
@@ -416,6 +415,6 @@ def test_sample_with_minimal_data(self):
self.assertIsNone(sample.metadata)
-if __name__ == '__main__':
+if __name__ == "__main__":
# Run tests
- unittest.main(verbosity=2)
\ No newline at end of file
+ unittest.main(verbosity=2)
diff --git a/graid/tests/test_threshold_functionality.py b/tests/unit_tests/test_threshold_functionality.py
similarity index 94%
rename from graid/tests/test_threshold_functionality.py
rename to tests/unit_tests/test_threshold_functionality.py
index 7561254..ffe2ad1 100644
--- a/graid/tests/test_threshold_functionality.py
+++ b/tests/unit_tests/test_threshold_functionality.py
@@ -114,7 +114,8 @@ def test_threshold_functionality():
and isinstance(detections_default[0], list)
):
detections_default = detections_default[0]
- print(f"YOLO with default threshold (0.0): {len(detections_default)} detections")
+ print(
+ f"YOLO with default threshold (0.0): {len(detections_default)} detections")
# Test with low threshold
yolo_model.set_threshold(0.1)
@@ -154,10 +155,12 @@ def test_threshold_functionality():
else:
detections_to_check = detections_high
+ # Check that all detections are above threshold
for det in detections_to_check:
if hasattr(det, "score") and det.score < 0.8:
- print(f"β Found detection with score {det.score} below threshold 0.8")
- return False
+ print(
+ f"β Found detection with score {det.score} below threshold 0.8")
+ assert False, f"Detection with score {det.score} below threshold 0.8"
print("β
All confidence scores are above threshold!")
diff --git a/graid/tests/test_validation_pipeline.py b/tests/unit_tests/test_validation_pipeline.py
similarity index 74%
rename from graid/tests/test_validation_pipeline.py
rename to tests/unit_tests/test_validation_pipeline.py
index ad89967..4e4abb8 100644
--- a/graid/tests/test_validation_pipeline.py
+++ b/tests/unit_tests/test_validation_pipeline.py
@@ -5,12 +5,13 @@
and error handling of the detection validation system.
"""
-import unittest
+import logging
import sys
+import unittest
from pathlib import Path
-from PIL import Image, ImageDraw
+
import torch
-import logging
+from PIL import Image, ImageDraw
# Add graid to path
sys.path.append(str(Path(__file__).parent.parent / "src"))
@@ -18,7 +19,7 @@
from graid.data.validation import (
ComprehensiveDetectionValidator,
ValidationConfig,
- ValidationStage
+ ValidationStage,
)
from graid.interfaces.ObjectDetectionI import ObjectDetectionResultI
@@ -28,56 +29,76 @@
class TestValidationPipeline(unittest.TestCase):
"""Test cases for the comprehensive detection validation pipeline."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.image_size = (640, 480)
self.sample_image = self._create_sample_image()
self.sample_detections = self._create_sample_detections()
-
+
def _create_sample_detections(self):
"""Create sample detections for testing."""
h, w = self.image_size
-
+
return [
# Reasonable detections for street scene
ObjectDetectionResultI(
- score=0.85, cls=2, label="car",
- bbox=[100, 200, 300, 350], image_hw=self.image_size
+ score=0.85,
+ cls=2,
+ label="car",
+ bbox=[100, 200, 300, 350],
+ image_hw=self.image_size,
),
ObjectDetectionResultI(
- score=0.92, cls=0, label="person",
- bbox=[320, 150, 380, 340], image_hw=self.image_size
+ score=0.92,
+ cls=0,
+ label="person",
+ bbox=[320, 150, 380, 340],
+ image_hw=self.image_size,
),
ObjectDetectionResultI(
- score=0.78, cls=9, label="traffic_light",
- bbox=[450, 50, 480, 120], image_hw=self.image_size
+ score=0.78,
+ cls=9,
+ label="traffic_light",
+ bbox=[450, 50, 480, 120],
+ image_hw=self.image_size,
),
# Unreasonable detections (should be filtered)
ObjectDetectionResultI(
- score=0.65, cls=21, label="elephant",
- bbox=[200, 180, 400, 380], image_hw=self.image_size
+ score=0.65,
+ cls=21,
+ label="elephant",
+ bbox=[200, 180, 400, 380],
+ image_hw=self.image_size,
),
ObjectDetectionResultI(
- score=0.55, cls=4, label="airplane",
- bbox=[50, 30, 250, 150], image_hw=self.image_size
- )
+ score=0.55,
+ cls=4,
+ label="airplane",
+ bbox=[50, 30, 250, 150],
+ image_hw=self.image_size,
+ ),
]
-
+
def _create_sample_image(self):
"""Create a simple street scene image for testing."""
- image = Image.new('RGB', self.image_size, color='lightblue')
+ image = Image.new("RGB", self.image_size, color="lightblue")
draw = ImageDraw.Draw(image)
-
+
# Ground/road
- draw.rectangle([0, self.image_size[1]//2, self.image_size[0], self.image_size[1]], fill='gray')
-
+ draw.rectangle(
+ [0, self.image_size[1] // 2, self.image_size[0], self.image_size[1]],
+ fill="gray",
+ )
+
# Buildings
- draw.rectangle([0, 100, 150, self.image_size[1]//2], fill='darkred')
- draw.rectangle([500, 80, self.image_size[0], self.image_size[1]//2], fill='darkblue')
-
+ draw.rectangle([0, 100, 150, self.image_size[1] // 2], fill="darkred")
+ draw.rectangle(
+ [500, 80, self.image_size[0], self.image_size[1] // 2], fill="darkblue"
+ )
+
return image
-
+
def test_basic_validation_config(self):
"""Test basic validation configuration creation."""
config = ValidationConfig(
@@ -87,14 +108,14 @@ def test_basic_validation_config(self):
enable_ensemble=False,
enable_segmentation=False,
min_detection_confidence=0.3,
- device="cpu"
+ device="cpu",
)
-
+
self.assertTrue(config.enable_cooccurrence)
self.assertFalse(config.enable_clip_relationships)
self.assertEqual(config.min_detection_confidence, 0.3)
self.assertEqual(config.device, "cpu")
-
+
def test_validator_initialization(self):
"""Test that validator initializes correctly."""
config = ValidationConfig(
@@ -103,26 +124,26 @@ def test_validator_initialization(self):
enable_scene_consistency=False,
enable_ensemble=False,
enable_segmentation=False,
- device="cpu"
+ device="cpu",
)
-
+
validator = ComprehensiveDetectionValidator(config)
-
+
self.assertIsNotNone(validator)
self.assertEqual(validator.config, config)
self.assertIn(ValidationStage.COOCCURRENCE, validator.filters)
-
+
def test_cooccurrence_filtering(self):
"""Test co-occurrence filtering stage."""
from graid.data.validation.cooccurrence_filter import CooccurrenceFilter
-
+
filter_stage = CooccurrenceFilter()
results = filter_stage.validate_detection_set(self.sample_detections)
-
+
# Results should be a list, but length depends on implementation
self.assertIsInstance(results, list)
self.assertGreater(len(results), 0)
-
+
# Check that results have proper structure
for result in results:
self.assertIsInstance(result.confidence, float)
@@ -130,76 +151,80 @@ def test_cooccurrence_filtering(self):
# Confidence should be between 0 and 1
self.assertGreaterEqual(result.confidence, 0.0)
self.assertLessEqual(result.confidence, 1.0)
-
+
def test_basic_validation_pipeline(self):
"""Test basic validation pipeline without external dependencies."""
config = ValidationConfig(
enable_cooccurrence=True,
enable_clip_relationships=False, # Skip to avoid CLIP dependency
- enable_scene_consistency=False, # Skip to avoid OpenAI API
- enable_ensemble=False, # No ensemble models
- enable_segmentation=False, # Skip to avoid SAM2 dependency
+ enable_scene_consistency=False, # Skip to avoid OpenAI API
+ enable_ensemble=False, # No ensemble models
+ enable_segmentation=False, # Skip to avoid SAM2 dependency
require_all_stages=False,
min_detection_confidence=0.3,
- device="cpu"
+ device="cpu",
)
-
+
validator = ComprehensiveDetectionValidator(config)
-
+
# Test validation
valid_detections, validation_records = validator.filter_detections(
self.sample_detections, self.sample_image, debug=False
)
-
+
# Verify results structure
self.assertIsInstance(valid_detections, list)
self.assertIsInstance(validation_records, list)
self.assertEqual(len(validation_records), len(self.sample_detections))
-
+
# Check that we get some filtering (exact results depend on implementation)
self.assertLessEqual(len(valid_detections), len(self.sample_detections))
-
+
# Verify validation records structure
for record in validation_records:
self.assertIsNotNone(record.detection)
self.assertIsInstance(record.final_valid, bool)
self.assertIsInstance(record.stage_results, dict)
-
+
def test_strict_vs_lenient_configuration(self):
"""Test strict vs lenient validation configurations."""
-
+
# Strict configuration
strict_config = ValidationConfig(
enable_cooccurrence=True,
enable_clip_relationships=False,
enable_scene_consistency=False,
enable_segmentation=False,
- require_all_stages=True, # All stages must pass
+ require_all_stages=True, # All stages must pass
min_detection_confidence=0.7, # High threshold
- cooccurrence_threshold=0.01 # Strict co-occurrence
+ cooccurrence_threshold=0.01, # Strict co-occurrence
)
-
- # Lenient configuration
+
+ # Lenient configuration
lenient_config = ValidationConfig(
enable_cooccurrence=True,
enable_clip_relationships=False,
enable_scene_consistency=False,
enable_segmentation=False,
- require_all_stages=False, # Majority vote
+ require_all_stages=False, # Majority vote
min_detection_confidence=0.2, # Low threshold
- cooccurrence_threshold=0.0001 # Permissive co-occurrence
+ cooccurrence_threshold=0.0001, # Permissive co-occurrence
)
-
+
strict_validator = ComprehensiveDetectionValidator(strict_config)
lenient_validator = ComprehensiveDetectionValidator(lenient_config)
-
- strict_valid, _ = strict_validator.filter_detections(self.sample_detections, self.sample_image)
- lenient_valid, _ = lenient_validator.filter_detections(self.sample_detections, self.sample_image)
-
+
+ strict_valid, _ = strict_validator.filter_detections(
+ self.sample_detections, self.sample_image
+ )
+ lenient_valid, _ = lenient_validator.filter_detections(
+ self.sample_detections, self.sample_image
+ )
+
# Both should return some results (exact comparison depends on implementation)
self.assertIsInstance(strict_valid, list)
self.assertIsInstance(lenient_valid, list)
-
+
def test_metrics_collection(self):
"""Test that validation metrics are collected properly."""
config = ValidationConfig(
@@ -208,40 +233,42 @@ def test_metrics_collection(self):
enable_scene_consistency=False,
enable_ensemble=False,
enable_segmentation=False,
- device="cpu"
+ device="cpu",
)
-
+
validator = ComprehensiveDetectionValidator(config)
validator.filter_detections(self.sample_detections, self.sample_image)
-
+
metrics = validator.get_metrics_summary()
-
+
# Check metrics structure
- self.assertIn('total_detections', metrics)
- self.assertIn('final_valid_detections', metrics)
- self.assertIn('overall_pass_rate', metrics)
- self.assertIn('stage_pass_rates', metrics)
-
+ self.assertIn("total_detections", metrics)
+ self.assertIn("final_valid_detections", metrics)
+ self.assertIn("overall_pass_rate", metrics)
+ self.assertIn("stage_pass_rates", metrics)
+
# Check metrics values
- self.assertEqual(metrics['total_detections'], len(self.sample_detections))
- self.assertGreaterEqual(metrics['overall_pass_rate'], 0.0)
- self.assertLessEqual(metrics['overall_pass_rate'], 1.0)
-
+ self.assertEqual(metrics["total_detections"], len(self.sample_detections))
+ self.assertGreaterEqual(metrics["overall_pass_rate"], 0.0)
+ self.assertLessEqual(metrics["overall_pass_rate"], 1.0)
+
def test_detection_input_validation(self):
"""Test input validation for detections."""
config = ValidationConfig(enable_cooccurrence=True, device="cpu")
validator = ComprehensiveDetectionValidator(config)
-
+
# Test empty detections
valid_detections, records = validator.filter_detections([], self.sample_image)
self.assertEqual(len(valid_detections), 0)
self.assertEqual(len(records), 0)
-
+
# Test single detection
single_detection = [self.sample_detections[0]]
- valid_detections, records = validator.filter_detections(single_detection, self.sample_image)
+ valid_detections, records = validator.filter_detections(
+ single_detection, self.sample_image
+ )
self.assertEqual(len(records), 1)
-
+
def test_confidence_filtering(self):
"""Test minimum confidence filtering."""
config = ValidationConfig(
@@ -250,16 +277,18 @@ def test_confidence_filtering(self):
enable_scene_consistency=False,
enable_segmentation=False,
min_detection_confidence=0.8, # High threshold
- device="cpu"
+ device="cpu",
)
-
+
validator = ComprehensiveDetectionValidator(config)
- valid_detections, _ = validator.filter_detections(self.sample_detections, self.sample_image)
-
+ valid_detections, _ = validator.filter_detections(
+ self.sample_detections, self.sample_image
+ )
+
# All valid detections should meet minimum confidence
for detection in valid_detections:
self.assertGreaterEqual(detection.score, 0.8)
-
+
def test_stage_info_retrieval(self):
"""Test stage information retrieval."""
config = ValidationConfig(
@@ -267,54 +296,58 @@ def test_stage_info_retrieval(self):
enable_clip_relationships=False,
enable_scene_consistency=False,
enable_segmentation=False,
- device="cpu"
+ device="cpu",
)
-
+
validator = ComprehensiveDetectionValidator(config)
stage_info = validator.get_stage_info()
-
+
self.assertIsInstance(stage_info, dict)
- self.assertIn('cooccurrence', stage_info)
-
+ self.assertIn("cooccurrence", stage_info)
+
# Each stage should have description
for stage, info in stage_info.items():
- self.assertIn('description', info)
- self.assertIsInstance(info['description'], str)
+ self.assertIn("description", info)
+ self.assertIsInstance(info["description"], str)
class TestValidationComponents(unittest.TestCase):
"""Test individual validation components."""
-
+
def test_validation_result_structure(self):
"""Test ValidationResult data structure."""
from graid.utilities.validation import ValidationResult
-
+
result = ValidationResult(
passed=True,
confidence=0.85,
reason="Test reason",
- metadata={"test": "data"}
+ metadata={"test": "data"},
)
-
+
self.assertTrue(result.passed)
self.assertEqual(result.confidence, 0.85)
self.assertEqual(result.reason, "Test reason")
self.assertEqual(result.metadata["test"], "data")
-
+
def test_validation_stage_enum(self):
"""Test ValidationStage enumeration."""
from graid.utilities.validation import ValidationStage
-
+
# Check that all expected stages exist
expected_stages = {
- "COOCCURRENCE", "CLIP_RELATIONSHIPS", "SCENE_CONSISTENCY",
- "ENSEMBLE_AGREEMENT", "SEGMENTATION", "HUMAN_SUPERVISED"
+ "COOCCURRENCE",
+ "CLIP_RELATIONSHIPS",
+ "SCENE_CONSISTENCY",
+ "ENSEMBLE_AGREEMENT",
+ "SEGMENTATION",
+ "HUMAN_SUPERVISED",
}
-
+
actual_stages = {stage.name for stage in ValidationStage}
self.assertEqual(expected_stages, actual_stages)
-if __name__ == '__main__':
+if __name__ == "__main__":
# Run tests
- unittest.main(verbosity=2)
\ No newline at end of file
+ unittest.main(verbosity=2)
diff --git a/train_yolo_bdd.py b/train_yolo_bdd.py
new file mode 100644
index 0000000..4ec3ea6
--- /dev/null
+++ b/train_yolo_bdd.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+import torch
+from ultralytics import YOLO
+
+
+def get_available_gpus() -> list[int]:
+ """Return a list of GPU indices visible to the current process.
+
+ When launched under Slurm, CUDA_VISIBLE_DEVICES is set to the GPUs
+ allocated for the job (e.g. "0,1,3,4"). Ultralytics expects a list of
+ integer indices. If the env-var is missing we fall back to
+ torch.cuda.device_count().
+ """
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
+ if cvd:
+ return [int(x) for x in cvd.split(",") if x.strip()]
+ return list(range(torch.cuda.device_count()))
+
+
+def main() -> None:
+ project_root = Path(__file__).resolve().parent
+ print(f"Project root: {project_root}")
+
+ gpus = get_available_gpus()
+ if not gpus:
+ raise RuntimeError("No CUDA devices available; cannot train model.")
+ print(f"Using GPUs: {gpus}")
+
+ # -------------------------
+ # Weights & Biases logging
+ # -------------------------
+ # Ultralytics automatically logs to wandb if the package is installed and
+ # WANDB_MODE is not set to "disabled". Setting WANDB_PROJECT (and, optionally,
+ # WANDB_NAME/WANDB_ENTITY) here ensures the run is grouped correctly.
+ os.environ.setdefault("WANDB_PROJECT", "yolo_bdd")
+ os.environ.setdefault("WANDB_NAME", "yolo_bdd_train")
+
+ # Initialize model (downloads weights if necessary)
+ model = YOLO("yolov9e.pt")
+
+ # Start training
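+ # Assumption: "bdd_ultra.yaml" is a standard Ultralytics dataset config
+ # (train/val image paths plus class names) prepared for BDD100K.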
+ model.train(
+ data="bdd_ultra.yaml",
+ epochs=10,
+ imgsz=1280,
+ batch=32,
+ device=gpus,
+ project="runs", # local directory for artifacts
+ name="yolo_bdd", # run name inside WandB & runs/
+ deterministic=True, # reproducibility
+ workers=8, # dataloader workers
+ )
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/yolo_wandb_tune.py b/yolo_wandb_tune.py
new file mode 100644
index 0000000..5599247
--- /dev/null
+++ b/yolo_wandb_tune.py
@@ -0,0 +1,102 @@
+# import wandb
+# from ultralytics import YOLO
+# from wandb.integration.ultralytics import add_wandb_callback
+
+# wandb.login()
+
+# model_gpu_map = {
+# "yolov5m": [3],
+# "yolov5m6": [2],
+# "yolov6x": [1],
+# "yolov6l": [0],
+# }
+
+# config = {
+# "epochs": 30,
+# "iterations": 100,
+# "imgsz": 640,
+# "batch": 32,
+# "save": True,
+# "plots": True,
+# "optimizer": "AdamW"
+# }
+
+# wandb.init(project="bdd_ultra", name="yolov5m", job_type="tuning", config=config)
+
+# for model_name, device in model_gpu_map.items():
+# model = YOLO(f"{model_name}.pt")
+
+# add_wandb_callback(model, enable_model_checkpointing=True)
+
+
+# model.tune(
+# data="bdd_ultra.yaml",
+# epochs=config["epochs"],
+# iterations=config["iterations"],
+# imgsz=config["imgsz"],
+# device=config["device"],
+# batch=config["batch"],
+# save=config["save"],
+# plots=config["plots"],
+# optimizer=config["optimizer"],
+# )
+
+# wandb.finish()
+
+
+import wandb
+from ultralytics import YOLO
+from wandb.integration.ultralytics import add_wandb_callback
+import threading
+
+wandb.login()
+
+model_gpu_map = {
+ "yolov5mu.pt": [3],
+ "yolov5m6u.pt": [2],
+ "yolov6l.yaml": [1],
+ "yolov10b.pt": [0],
+}
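+# Each key is an Ultralytics checkpoint (.pt) or model definition (.yaml); the
+# value is the list of GPU indices that model's tuning run is pinned to.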
+
+config = {
+ "data": "bdd_ultra.yaml",
+ "epochs": 30,
+ "iterations": 100,
+ "imgsz": 736,
+ "batch": 16,
+ "save": True,
+ "plots": True,
+ "optimizer": "AdamW",
+ "amp": False,
+}
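+# All keys above are forwarded verbatim to model.tune via **config; the
+# per-model `device` list comes from model_gpu_map inside the loop below.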
+
+threads = []
+
+for model_name, device in model_gpu_map.items():
+
+ def tune_model(model_name, device, config):
+ wandb.init(project="bdd_ultra", name=model_name, job_type="tuning", config=config)
+
+ model = YOLO(model_name)
+
+ add_wandb_callback(model, enable_model_checkpointing=True)
+
+ model.tune(
+ device=device,
+ **config,
+ )
+
+ threads.append(
+ threading.Thread(
+ target=tune_model,
+ args=(model_name, device, config),
+ )
+ )
+
+ threads[-1].start()
+
+# wait for all threads to finish
+for thread in threads:
+ thread.join()
+
+wandb.finish()
\ No newline at end of file