diff --git a/training/object-detection/.dvc/.gitignore b/training/object-detection/.dvc/.gitignore new file mode 100644 index 0000000..528f30c --- /dev/null +++ b/training/object-detection/.dvc/.gitignore @@ -0,0 +1,3 @@ +/config.local +/tmp +/cache diff --git a/training/object-detection/.dvc/config b/training/object-detection/.dvc/config new file mode 100644 index 0000000..8b006d6 --- /dev/null +++ b/training/object-detection/.dvc/config @@ -0,0 +1,4 @@ +[core] + remote = storage +['remote "storage"'] + url = s3://salmonvision-dvc/rgb_object_detection diff --git a/training/object-detection/.dvcignore b/training/object-detection/.dvcignore new file mode 100644 index 0000000..5197305 --- /dev/null +++ b/training/object-detection/.dvcignore @@ -0,0 +1,3 @@ +# Add patterns of files dvc should ignore, which could improve +# the performance. Learn more at +# https://dvc.org/doc/user-guide/dvcignore diff --git a/training/object-detection/.gitignore b/training/object-detection/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/training/object-detection/.python-version b/training/object-detection/.python-version new file mode 100644 index 0000000..cc1923a --- /dev/null +++ b/training/object-detection/.python-version @@ -0,0 +1 @@ +3.8 diff --git a/training/object-detection/README.md b/training/object-detection/README.md new file mode 100644 index 0000000..21ab44c --- /dev/null +++ b/training/object-detection/README.md @@ -0,0 +1,35 @@ +# Object Detection + +Training pipeline to train the SalmonVision object detection model. + +Install uv: +```bash +curl -LsSf https://astral.sh/uv/install.sh | sh +``` + +Install DVC: +```bash +uv tool install dvc +``` + +Install the module: +```bash +uv pip install -e . +``` + +Check dvc.yaml for the full pipeline. 
+ +Run the following to run specific stages of the pipeline: +```bash +dvc repro stage_name +``` + +For example, building the model input annotations: +```bash +dvc repro build_model_input +``` + +Run tests with +``` +uv run pytest +``` diff --git a/training/object-detection/config/salmon_yolo.yaml b/training/object-detection/config/salmon_yolo.yaml new file mode 100644 index 0000000..d1601dc --- /dev/null +++ b/training/object-detection/config/salmon_yolo.yaml @@ -0,0 +1,28 @@ +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +# Classes updated on 2026-02-12 +path: /training/export_combined_bear_kitwanga_yolo # dataset root dir +train: train.txt +val: val.txt +test: test.txt + +# Classes +names: + 0: Coho + 1: Bull + 2: Rainbow + 3: Sockeye + 4: Pink + 5: Whitefish + 6: Chinook + 7: Shiner + 8: Pikeminnow + 9: Chum + 10: Steelhead + 11: Lamprey + 12: Cutthroat + 13: Stickleback + 14: Sculpin + 15: Jack_Coho + 16: Jack_Chinook + 17: Otter + 18: Sucker diff --git a/training/object-detection/data/01_raw/.gitignore b/training/object-detection/data/01_raw/.gitignore new file mode 100644 index 0000000..767a4b6 --- /dev/null +++ b/training/object-detection/data/01_raw/.gitignore @@ -0,0 +1,4 @@ +/labelstudio_annos +/salmon_vid_counts.csv +/salmon_vid_counts_summary.csv +/sv_water_conditions diff --git a/training/object-detection/data/01_raw/salmon_vid_counts.csv.dvc b/training/object-detection/data/01_raw/salmon_vid_counts.csv.dvc new file mode 100644 index 0000000..1023678 --- /dev/null +++ b/training/object-detection/data/01_raw/salmon_vid_counts.csv.dvc @@ -0,0 +1,5 @@ +outs: +- md5: b6136eee30a358d7473a160197bc91be + size: 6465051 + hash: md5 + path: salmon_vid_counts.csv diff --git a/training/object-detection/data/01_raw/salmon_vid_counts_summary.csv.dvc b/training/object-detection/data/01_raw/salmon_vid_counts_summary.csv.dvc new file mode 100644 index 0000000..4f4cbb3 --- /dev/null +++ 
b/training/object-detection/data/01_raw/salmon_vid_counts_summary.csv.dvc @@ -0,0 +1,5 @@ +outs: +- md5: c65c2ba8214d26905ccb9608e9071b93 + size: 1031 + hash: md5 + path: salmon_vid_counts_summary.csv diff --git a/training/object-detection/data/01_raw/sv_water_conditions.dvc b/training/object-detection/data/01_raw/sv_water_conditions.dvc new file mode 100644 index 0000000..f4cce12 --- /dev/null +++ b/training/object-detection/data/01_raw/sv_water_conditions.dvc @@ -0,0 +1,6 @@ +outs: +- md5: 717ade5c2fd57a763609ab55d66bf351.dir + size: 199437 + nfiles: 10 + hash: md5 + path: sv_water_conditions diff --git a/training/object-detection/data/02_interim/.gitignore b/training/object-detection/data/02_interim/.gitignore new file mode 100644 index 0000000..c412ac4 --- /dev/null +++ b/training/object-detection/data/02_interim/.gitignore @@ -0,0 +1,3 @@ +/yolo_annos +/yolo_annos_unpacked +/yolo_condition_negatives diff --git a/training/object-detection/data/03_processed/.gitignore b/training/object-detection/data/03_processed/.gitignore new file mode 100644 index 0000000..15e8ec8 --- /dev/null +++ b/training/object-detection/data/03_processed/.gitignore @@ -0,0 +1 @@ +/splits_baseline diff --git a/training/object-detection/dvc.lock b/training/object-detection/dvc.lock new file mode 100644 index 0000000..76dccb0 --- /dev/null +++ b/training/object-detection/dvc.lock @@ -0,0 +1,143 @@ +schema: '2.0' +stages: + update_raw: + cmd: rclone sync -P --filter "- /salm_dataset*/**" --filter "+ *.json" + --filter "- *.zip" aws:salmonvision-ml-datasets/rgb/raw/ + data/01_raw/labelstudio_annos + deps: + - path: s3://salmonvision-ml-datasets/rgb/raw + hash: md5 + md5: 9aab20758f57f39971a38c9a676dad27.dir + size: 720238505 + nfiles: 65 + outs: + - path: data/01_raw/labelstudio_annos + hash: md5 + md5: 188b71c31be326fe36d209cc52c23337.dir + size: 379352848 + nfiles: 50 + split_data: + cmd: rm -rf data/02_interim/yolo_annos_unpacked && scripts/unpack_annos.sh + data/02_interim/yolo_annos 
data/02_interim/yolo_annos_unpacked && + scripts/unpack_annos.sh data/02_interim/yolo_condition_negatives + data/02_interim/yolo_annos_unpacked && scripts/make_splits.py + --labels-root data/02_interim/yolo_annos_unpacked --out-dir + data/03_processed/splits_baseline --sites tankeeah kitwanga bear --seed 42 + --train-frac 0.8 --val-frac 0.1 --test-frac 0.1 + deps: + - path: data/01_raw/salmon_vid_counts.csv + hash: md5 + md5: b6136eee30a358d7473a160197bc91be + size: 6465051 + - path: data/01_raw/salmon_vid_counts_summary.csv + hash: md5 + md5: c65c2ba8214d26905ccb9608e9071b93 + size: 1031 + - path: data/02_interim/yolo_annos + hash: md5 + md5: 88551957fa5e19f676175f29606a613a.dir + size: 308350474 + nfiles: 5 + - path: data/02_interim/yolo_condition_negatives + hash: md5 + md5: c3505d079516f38874bdb42d2182cb46.dir + size: 338928 + nfiles: 3 + - path: scripts/make_splits.py + hash: md5 + md5: 3921729fec8a5aa6c1eb25c54196a658 + size: 650 + - path: src/object_detection/splits + hash: md5 + md5: 8ca8d1d4f2962a0e27f6fd50da90a8d2.dir + size: 35847 + nfiles: 6 + - path: src/object_detection/utils + hash: md5 + md5: 5828afbf8a25f11a0f8e5bbc6c0065bf.dir + size: 762 + nfiles: 4 + outs: + - path: data/03_processed/splits_baseline + hash: md5 + md5: e821f72b0bf42a09afafb688ce927ac4.dir + size: 16966114 + nfiles: 5 + build_model_input: + cmd: scripts/yolo_converter_ls_video.py data/01_raw/labelstudio_annos + --data-yaml config/salmon_yolo.yaml --out data/02_interim/yolo_annos_fs + --out-shards data/02_interim/yolo_annos --empty-list + data/02_interim/yolo_annos/empty_vids.txt --shard-size 100000 --pattern + '**/*.json' --include-sites tankeeah kitwanga bear --frame-stride 3 + --frame-offset-mode video_hash --include-negatives --negative-ratio 0.10 + --negatives-per-video 11 + deps: + - path: config/salmon_yolo.yaml + hash: md5 + md5: f453f5dc54f1743eaedcc3ab117d269e + size: 547 + - path: data/01_raw/labelstudio_annos + hash: md5 + md5: 188b71c31be326fe36d209cc52c23337.dir + 
size: 379352848 + nfiles: 50 + - path: scripts/yolo_converter_ls_video.py + hash: md5 + md5: 95696514ec77c20c6491c5af2ccb46a5 + size: 118 + - path: src/object_detection/utils + hash: md5 + md5: 5828afbf8a25f11a0f8e5bbc6c0065bf.dir + size: 762 + nfiles: 4 + - path: src/object_detection/yolo_ls + hash: md5 + md5: 80ca7160d1912861bbc2a013a45ac5d5.dir + size: 52757 + nfiles: 10 + outs: + - path: data/02_interim/yolo_annos + hash: md5 + md5: 88551957fa5e19f676175f29606a613a.dir + size: 308350474 + nfiles: 5 + unpack_annos: + cmd: rm -r data/02_interim/yolo_annos_unpacked || scripts/unpack_annos.sh + data/02_interim/yolo_annos data/02_interim/yolo_annos_unpacked + deps: + - path: data/02_interim/yolo_annos + hash: md5 + md5: 473a8eedf5fd4cc5d5b8b6d1f80681e7.dir + size: 293713920 + nfiles: 3 + build_condition_negatives: + cmd: scripts/create_condition_negatives.py --conditions-csv + data/01_raw/sv_water_conditions/SV_conditions_tracking_tankeeah_2025.csv + data/01_raw/sv_water_conditions/SV_conditions_tracking_bear_2025.csv + data/01_raw/sv_water_conditions/SV_conditions_tracking_kitwanga_2025.csv + --out-dir data/02_interim/yolo_condition_negatives --frames-per-video 20 + --frame-stride 3 --frame-offset-mode video_hash --shard-size 100000 + deps: + - path: data/01_raw/sv_water_conditions + hash: md5 + md5: 717ade5c2fd57a763609ab55d66bf351.dir + size: 199437 + nfiles: 10 + - path: src/object_detection/negatives/cli.py + hash: md5 + md5: d04fcc8b316529ebefeec2830b288197 + size: 2142 + - path: src/object_detection/negatives/conditions.py + hash: md5 + md5: f037be2028fead4a5a9e88f95799254a + size: 18950 + - path: src/object_detection/yolo_ls/shards.py + hash: md5 + md5: 9a1e97420b95c978e10ec236c2645337 + size: 1266 + outs: + - path: data/02_interim/yolo_condition_negatives + hash: md5 + md5: c3505d079516f38874bdb42d2182cb46.dir + size: 338928 + nfiles: 3 diff --git a/training/object-detection/dvc.yaml b/training/object-detection/dvc.yaml new file mode 100644 index 
0000000..33ccbaf --- /dev/null +++ b/training/object-detection/dvc.yaml @@ -0,0 +1,142 @@ +stages: + update_raw: + cmd: >- + rclone sync + -P + --filter "- /salm_dataset*/**" + --filter "+ *.json" + --filter "- *.zip" + aws:salmonvision-ml-datasets/rgb/raw/ + data/01_raw/labelstudio_annos + deps: + - s3://salmonvision-ml-datasets/rgb/raw + outs: + - data/01_raw/labelstudio_annos + frozen: true + build_model_input: + cmd: >- + scripts/yolo_converter_ls_video.py + data/01_raw/labelstudio_annos + --data-yaml config/salmon_yolo.yaml + --out data/02_interim/yolo_annos_fs + --out-shards data/02_interim/yolo_annos + --empty-list data/02_interim/yolo_annos/empty_vids.txt + --shard-size 100000 + --pattern '**/*.json' + --include-sites tankeeah kitwanga bear + --frame-stride 3 + --frame-offset-mode video_hash + --include-negatives + --negative-ratio 0.10 + --negatives-per-video 11 + deps: + - scripts/yolo_converter_ls_video.py + - src/object_detection/yolo_ls + - src/object_detection/utils + - data/01_raw/labelstudio_annos + - config/salmon_yolo.yaml + outs: + - data/02_interim/yolo_annos + build_condition_negatives: + cmd: >- + scripts/create_condition_negatives.py + --conditions-csv data/01_raw/sv_water_conditions/SV_conditions_tracking_tankeeah_2025.csv + data/01_raw/sv_water_conditions/SV_conditions_tracking_bear_2025.csv + data/01_raw/sv_water_conditions/SV_conditions_tracking_kitwanga_2025.csv + --out-dir data/02_interim/yolo_condition_negatives + --frames-per-video 20 + --frame-stride 3 + --frame-offset-mode video_hash + --shard-size 100000 + deps: + - src/object_detection/negatives/conditions.py + - src/object_detection/negatives/cli.py + - src/object_detection/yolo_ls/shards.py + - data/01_raw/sv_water_conditions + outs: + - data/02_interim/yolo_condition_negatives + split_data: + cmd: >- + rm -rf data/02_interim/yolo_annos_unpacked && + scripts/unpack_annos.sh + data/02_interim/yolo_annos + data/02_interim/yolo_annos_unpacked && + scripts/unpack_annos.sh + 
data/02_interim/yolo_condition_negatives + data/02_interim/yolo_annos_unpacked && + scripts/make_splits.py + --labels-root data/02_interim/yolo_annos_unpacked + --out-dir data/03_processed/splits_baseline + --sites tankeeah kitwanga bear + --seed 42 + --train-frac 0.8 --val-frac 0.1 --test-frac 0.1 + deps: + - scripts/make_splits.py + - src/object_detection/splits + - src/object_detection/utils + - data/02_interim/yolo_annos + - data/02_interim/yolo_condition_negatives + - data/01_raw/salmon_vid_counts.csv + - data/01_raw/salmon_vid_counts_summary.csv + outs: + - data/03_processed/splits_baseline + build_video_metadata_index: + cmd: >- + scripts/build_video_metadata_index.py + --json-dir data/01_raw/labelstudio_annos + --out-csv data/02_interim/video_metadata_index.csv + deps: + - src/object_detection/metadata/index.py + - src/object_detection/metadata/cli.py + - data/01_raw/labelstudio_annos + outs: + - data/02_interim/video_metadata_index.csv + pack_split_dataset: + cmd: >- + scripts/pack_split_dataset.py + --splits-dir data/03_processed/splits_baseline + --labels-root data/02_interim/yolo_annos_unpacked + --shards-root ${config.drive}/salmon_dataset/dataset_sharded/shards + --manifests-root ${config.drive}/salmon_dataset/dataset_sharded/manifests + --temp-video-dir ${config.drive}/salmon_dataset/tmp_videos + --metadata-csv data/02_interim/video_metadata_index.csv + data/02_interim/yolo_condition_negatives/condition_negative_video_metadata.csv + --data-yaml config/salmon_yolo.yaml + --bucket prod-salmonvision-edge-assets-labelstudio-source + --image-ext .jpg + --manifest-csv data/03_processed/packed_dataset_manifest.csv + --splits train val test + --shard-size 100000 + deps: + - src/object_detection/frames/parsing.py + - src/object_detection/frames/extractor.py + - src/object_detection/frames/cli.py + - src/object_detection/yolo_ls/shards.py + - data/03_processed/splits_baseline + - data/02_interim/yolo_annos_unpacked + - data/02_interim/video_metadata_index.csv 
+ - data/02_interim/yolo_condition_negatives/condition_negative_video_metadata.csv + - config/salmon_yolo.yaml + params: + - config.drive + outs: + - ${config.drive}/salmon_dataset/dataset_sharded/shards + - ${config.drive}/salmon_dataset/dataset_sharded/manifests + - data/03_processed/packed_dataset_manifest.csv + unpack_split_dataset: + cmd: >- + rm -rf ${config.drive}/salmon_dataset/yolo_workdir && + mkdir -p ${config.drive}/salmon_dataset/yolo_workdir && + scripts/unpack_annos.sh + ${config.drive}/salmon_dataset/dataset_sharded/shards + ${config.drive}/salmon_dataset/yolo_workdir && + cp ${config.drive}/salmon_dataset/dataset_sharded/manifests/train.txt ${config.drive}/salmon_dataset/yolo_workdir/train.txt && + cp ${config.drive}/salmon_dataset/dataset_sharded/manifests/val.txt ${config.drive}/salmon_dataset/yolo_workdir/val.txt && + cp ${config.drive}/salmon_dataset/dataset_sharded/manifests/test.txt ${config.drive}/salmon_dataset/yolo_workdir/test.txt && + cp ${config.drive}/salmon_dataset/dataset_sharded/manifests/data.yaml ${config.drive}/salmon_dataset/yolo_workdir/data.yaml + deps: + - ${config.drive}/salmon_dataset/dataset_sharded/shards + - ${config.drive}/salmon_dataset/dataset_sharded/manifests + params: + - config.drive + frozen: true diff --git a/training/object-detection/main.py b/training/object-detection/main.py new file mode 100644 index 0000000..f376d2d --- /dev/null +++ b/training/object-detection/main.py @@ -0,0 +1,6 @@ +def main(): + print("Hello from object-detection!") + + +if __name__ == "__main__": + main() diff --git a/training/object-detection/params.yaml b/training/object-detection/params.yaml new file mode 100644 index 0000000..4f332bb --- /dev/null +++ b/training/object-detection/params.yaml @@ -0,0 +1,2 @@ +config: + drive: /mnt/harukassd4tb/masamim diff --git a/training/object-detection/pyproject.toml b/training/object-detection/pyproject.toml new file mode 100644 index 0000000..43fb8fb --- /dev/null +++ 
b/training/object-detection/pyproject.toml @@ -0,0 +1,26 @@ +[project] +name = "object-detection" +version = "0.1.0" +description = "SalmonVision object detector training" +readme = "README.md" +requires-python = ">=3.8" +dependencies = [ + "boto3>=1.37.38", + "pyyaml>=6.0.3", +] + +[dependency-groups] +dev = [ + "pytest>=8.3.5", +] + +[build-system] +requires = ["setuptools>=61"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] diff --git a/training/object-detection/scripts/build_video_metadata_index.py b/training/object-detection/scripts/build_video_metadata_index.py new file mode 100755 index 0000000..38250cb --- /dev/null +++ b/training/object-detection/scripts/build_video_metadata_index.py @@ -0,0 +1,5 @@ +#!/usr/bin/env -S uv run python +from object_detection.metadata.cli import main + +if __name__ == "__main__": + main() diff --git a/training/object-detection/scripts/create_condition_negatives.py b/training/object-detection/scripts/create_condition_negatives.py new file mode 100755 index 0000000..547ddbb --- /dev/null +++ b/training/object-detection/scripts/create_condition_negatives.py @@ -0,0 +1,5 @@ +#!/usr/bin/env -S uv run python +from object_detection.negatives.cli import main + +if __name__ == "__main__": + main() diff --git a/training/object-detection/scripts/make_splits.py b/training/object-detection/scripts/make_splits.py new file mode 100755 index 0000000..3ca26da --- /dev/null +++ b/training/object-detection/scripts/make_splits.py @@ -0,0 +1,28 @@ +#!/usr/bin/env -S uv run python + +""" +make_splits.py + +Group-wise stratified-ish split for unpacked YOLO label files. 
+ +- Input: unpacked labels directory that looks like: + //frame_000123.txt + + where video_stem looks like: + ORG-site-device-id_YYYYMMDD_HHMMSS_M + +- Output: + out_dir/train.txt + out_dir/val.txt + out_dir/test.txt + out_dir/group_assignments.csv + out_dir/split_report.json + +Split unit: group_id = site + device + date(YYYYMMDD) +Balancing objectives (soft): class counts, time-of-day, density bins, box area bins. +""" + +from object_detection.splits.cli import main + +if __name__ == "__main__": + main() diff --git a/training/object-detection/scripts/pack_split_dataset.py b/training/object-detection/scripts/pack_split_dataset.py new file mode 100755 index 0000000..c4eb631 --- /dev/null +++ b/training/object-detection/scripts/pack_split_dataset.py @@ -0,0 +1,5 @@ +#!/usr/bin/env -S uv run python +from object_detection.frames.cli import main + +if __name__ == "__main__": + main() diff --git a/training/object-detection/scripts/unpack_annos.sh b/training/object-detection/scripts/unpack_annos.sh new file mode 100755 index 0000000..c78c455 --- /dev/null +++ b/training/object-detection/scripts/unpack_annos.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +set -e + +help_msg() { + echo "$0 [-h] in_folder out_folder" +} + +# Get the options +while getopts ":h" option; do + case $option in + h) # display Help + help_msg + exit;; + \?) 
# Invalid option + echo "Error: Invalid option" + help_msg + exit;; + esac +done + +# Check if exactly two arguments are given +if [ $# -ne 2 ]; then + help_msg + exit 1 +fi + +in_path="$1" +out_path="$2" + +mkdir -p "$out_path" +for f in "$in_path"/*.tar; do + tar -xf "$f" -C "$out_path" +done diff --git a/training/object-detection/scripts/yolo_converter_ls_video.py b/training/object-detection/scripts/yolo_converter_ls_video.py new file mode 100755 index 0000000..9c816a7 --- /dev/null +++ b/training/object-detection/scripts/yolo_converter_ls_video.py @@ -0,0 +1,6 @@ +#!/usr/bin/env -S uv run python + +from object_detection.yolo_ls.cli import main + +if __name__ == "__main__": + main() diff --git a/training/object-detection/src/object_detection/__init__.py b/training/object-detection/src/object_detection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/object-detection/src/object_detection/frames/__init__.py b/training/object-detection/src/object_detection/frames/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/object-detection/src/object_detection/frames/cli.py b/training/object-detection/src/object_detection/frames/cli.py new file mode 100644 index 0000000..6c447fe --- /dev/null +++ b/training/object-detection/src/object_detection/frames/cli.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import argparse +from pathlib import Path +import yaml + +from object_detection.frames.extractor import pack_split_dataset_shards + + +def load_class_names_from_yolo_yaml(path: Path): + data = yaml.safe_load(path.read_text(encoding="utf-8")) + names = data.get("names") + if isinstance(names, dict): + return [names[k] for k in sorted(names, key=lambda x: int(x))] + if isinstance(names, list): + return names + raise ValueError(f"Unsupported names format in {path}") + + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Pack split-aware YOLO dataset into tar shards.") + 
p.add_argument("--splits-dir", required=True, help="Directory containing train.txt / val.txt / test.txt") + p.add_argument("--labels-root", required=True, help="Root of unpacked label files referenced by split manifests") + p.add_argument("--shards-root", required=True, help="Output directory for packed tar shards") + p.add_argument("--manifests-root", required=True, help="Output directory for new image manifests + data.yaml") + p.add_argument("--temp-video-dir", required=True, help="Temporary directory for downloaded videos") + p.add_argument("--metadata-csv", nargs="+", required=True, + help="One or more metadata CSVs with columns including video_stem,fps,s3_key") + p.add_argument("--data-yaml", required=True, help="YOLO data.yaml used only to get class names") + p.add_argument("--bucket", default="prod-salmonvision-edge-assets-labelstudio-source", + help="Fallback bucket if metadata row lacks s3_key") + p.add_argument("--image-ext", default=".jpg", choices=[".jpg", ".png"]) + p.add_argument("--keep-videos", action="store_true") + p.add_argument("--manifest-csv", default=None) + p.add_argument("--splits", nargs="*", default=["train", "val", "test"]) + p.add_argument("--shard-size", type=int, default=100000) + return p + + +def main() -> None: + args = build_parser().parse_args() + class_names = load_class_names_from_yolo_yaml(Path(args.data_yaml)) + + stats = pack_split_dataset_shards( + splits_dir=Path(args.splits_dir), + labels_root=Path(args.labels_root), + shards_root=Path(args.shards_root), + manifests_root=Path(args.manifests_root), + temp_video_dir=Path(args.temp_video_dir), + metadata_csv_paths=[Path(p) for p in args.metadata_csv], + class_names=class_names, + bucket=args.bucket, + image_ext=args.image_ext, + cleanup_video=not args.keep_videos, + split_names=args.splits, + manifest_csv=Path(args.manifest_csv) if args.manifest_csv else None, + shard_size=args.shard_size, + ) + + print( + f"Done. 
splits_seen={stats.splits_seen} " + f"videos_seen={stats.videos_seen} " + f"videos_processed={stats.videos_processed} " + f"videos_failed={stats.videos_failed} " + f"frames_requested={stats.frames_requested} " + f"images_written={stats.images_written} " + f"labels_written={stats.labels_written}" + ) diff --git a/training/object-detection/src/object_detection/frames/extractor.py b/training/object-detection/src/object_detection/frames/extractor.py new file mode 100644 index 0000000..9b2860f --- /dev/null +++ b/training/object-detection/src/object_detection/frames/extractor.py @@ -0,0 +1,369 @@ +from __future__ import annotations + +import csv +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Set +import io +import tarfile + +from object_detection.yolo_ls.shards import TarShardWriter +from object_detection.frames.parsing import ( + parse_manifest_relpath, + split_label_relpath_to_packed_paths, + video_stem_to_s3_key, +) + +from object_detection.utils.utils import safe_float + + +@dataclass +class ExtractionStats: + splits_seen: int = 0 + videos_seen: int = 0 + videos_processed: int = 0 + videos_failed: int = 0 + frames_requested: int = 0 + frames_written: int = 0 + labels_written: int = 0 + +def load_video_metadata_index(path: Path) -> Dict[str, Dict[str, str]]: + out: Dict[str, Dict[str, str]] = {} + with path.open("r", newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + video_stem = (row.get("video_stem") or "").strip() + if not video_stem: + continue + out[video_stem] = dict(row) + return out + +def merge_video_metadata_csvs(paths: Iterable[Path]) -> Dict[str, Dict[str, str]]: + """ + Merge metadata CSVs by video_stem. + Later CSVs overwrite earlier CSVs on conflicts. 
+ """ + merged: Dict[str, Dict[str, str]] = {} + for path in paths: + current = load_video_metadata_index(path) + for video_stem, row in current.items(): + if video_stem in merged: + prev = merged[video_stem] + if prev.get("fps") != row.get("fps") or prev.get("s3_key") != row.get("s3_key"): + print( + f"[frames] warning: overriding metadata for {video_stem} " + f"from fps={prev.get('fps')} s3_key={prev.get('s3_key')} " + f"to fps={row.get('fps')} s3_key={row.get('s3_key')}" + ) + merged[video_stem] = row + return merged + +def ensure_dir(path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + + +def read_split_manifest(path: Path) -> List[str]: + lines: List[str] = [] + for line in path.read_text(encoding="utf-8").splitlines(): + s = line.strip() + if s: + lines.append(s) + return lines + + +def write_split_manifests( + manifests_root: Path, + split_to_image_relpaths: Dict[str, List[str]], +) -> None: + ensure_dir(manifests_root) + for split, relpaths in split_to_image_relpaths.items(): + out_path = manifests_root / f"{split}.txt" + relpaths = sorted(relpaths) + out_path.write_text("\n".join(relpaths) + ("\n" if relpaths else ""), encoding="utf-8") + + +def write_data_yaml( + manifests_root: Path, + class_names: List[str], +) -> None: + """ + Writes a YOLO-style data.yaml that uses split manifest files. 
+ """ + lines = [ + f"train: {str((manifests_root / 'train.txt').resolve())}", + f"val: {str((manifests_root / 'val.txt').resolve())}", + f"test: {str((manifests_root / 'test.txt').resolve())}", + "names:", + ] + for idx, name in enumerate(class_names): + lines.append(f" {idx}: {name}") + (manifests_root / "data.yaml").write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def load_split_requests(splits_dir: Path, split_names: Iterable[str]) -> Dict[str, Dict[str, List[int]]]: + """ + Returns: + { + "train": {video_stem: [frame_idx, ...], ...}, + "val": {...}, + "test": {...}, + } + """ + out: Dict[str, Dict[str, List[int]]] = {} + + for split in split_names: + manifest = splits_dir / f"{split}.txt" + if not manifest.exists(): + continue + + by_video: Dict[str, List[int]] = {} + for line in read_split_manifest(manifest): + video_stem, frame_idx = parse_manifest_relpath(line) + by_video.setdefault(video_stem, []).append(frame_idx) + + # dedupe + sort + by_video = {k: sorted(set(v)) for k, v in by_video.items()} + out[split] = by_video + + return out + + +def download_s3_video(bucket: str, s3_key: str, local_video_path: Path) -> None: + ensure_dir(local_video_path.parent) + cmd = [ + "aws", "s3", "cp", + f"s3://{bucket}/{s3_key}", + str(local_video_path), + ] + subprocess.run(cmd, check=True) + + +def read_label_text(labels_root: Path, relpath: str) -> str: + path = labels_root / relpath + return path.read_text(encoding="utf-8") + + +def extract_frame_bytes_ffmpeg( + video_path: Path, + frame_idx: int, + fps: float, + image_ext: str = ".jpg", +) -> bytes: + """ + Extract one frame and return the encoded image bytes. 
+ """ + timestamp = frame_idx / float(fps) + + if image_ext == ".jpg": + codec_args = ["-f", "image2", "-vcodec", "mjpeg"] + elif image_ext == ".png": + codec_args = ["-f", "image2", "-vcodec", "png"] + else: + raise ValueError(f"Unsupported image_ext: {image_ext}") + + cmd = [ + "ffmpeg", + "-hide_banner", + "-loglevel", "error", + "-ss", f"{timestamp:.6f}", + "-i", str(video_path), + "-frames:v", "1", + ] + codec_args + ["pipe:1"] + + result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE) + return result.stdout + + +def extract_frame_ffmpeg( + video_path: Path, + frame_idx: int, + fps: float, + output_path: Path, + overwrite: bool = False, +) -> bool: + """ + Extract one frame using timestamp = frame_idx / fps. + """ + if output_path.exists() and not overwrite: + return False + + ensure_dir(output_path.parent) + timestamp = frame_idx / float(fps) + + cmd = [ + "ffmpeg", + "-hide_banner", + "-loglevel", "error", + "-ss", f"{timestamp:.6f}", + "-i", str(video_path), + "-frames:v", "1", + "-q:v", "2", + "-y" if overwrite else "-n", + str(output_path), + ] + subprocess.run(cmd, check=True) + return True + + +def pack_split_dataset_shards( + splits_dir: Path, + labels_root: Path, + shards_root: Path, + manifests_root: Path, + temp_video_dir: Path, + metadata_csv_paths: Iterable[Path], + class_names: List[str], + bucket: str, + image_ext: str = ".jpg", + cleanup_video: bool = True, + split_names: Iterable[str] = ("train", "val", "test"), + manifest_csv: Optional[Path] = None, + shard_size: int = 100000, +) -> ExtractionStats: + """ + Reads split manifests containing label relpaths, e.g. + HIRMD-.../frame_000123.txt + + Produces sharded paired dataset: + train//frame_000123.jpg + train//frame_000123.txt + ... 
+ + Also writes fresh split manifests that point to image relpaths inside the packed layout: + train/HIRMD-.../frame_000123.jpg + """ + split_requests = load_split_requests(splits_dir, split_names) + stats = ExtractionStats(splits_seen=len(split_requests)) + metadata_index = merge_video_metadata_csvs(metadata_csv_paths) + + ensure_dir(shards_root) + ensure_dir(manifests_root) + + shard_writers: Dict[str, TarShardWriter] = {} + for split in split_requests.keys(): + shard_writers[split] = TarShardWriter( + shards_root, + shard_size=shard_size, + prefix=split, + ) + + split_to_image_relpaths: Dict[str, List[str]] = {split: [] for split in split_requests.keys()} + manifest_rows: List[Dict[str, str]] = [] + + for split, by_video in split_requests.items(): + writer = shard_writers[split] + + for video_stem, frame_indices in by_video.items(): + stats.videos_seen += 1 + stats.frames_requested += len(frame_indices) + + local_video = temp_video_dir / f"{video_stem}.mp4" + s3_key = "" + fps = 0.0 + + try: + meta = metadata_index.get(video_stem) + if meta is None: + raise KeyError(f"Missing metadata for video_stem={video_stem}") + + fps = safe_float(meta.get("fps", ""), 0.0) + if fps <= 0: + raise ValueError(f"Invalid fps for video_stem={video_stem}: {meta.get('fps', '')!r}") + + s3_key = (meta.get("s3_key") or "").strip() + if not s3_key: + if not bucket: + raise ValueError(f"Missing s3_key for video_stem={video_stem}") + s3_key = video_stem_to_s3_key(video_stem) + + download_s3_video(bucket=bucket, s3_key=s3_key, local_video_path=local_video) + + for frame_idx in frame_indices: + label_relpath = f"{video_stem}/frame_{frame_idx:06d}.txt" + image_relpath, packed_label_relpath = split_label_relpath_to_packed_paths( + split=split, + relpath=label_relpath, + image_ext=image_ext, + ) + + image_bytes = extract_frame_bytes_ffmpeg( + video_path=local_video, + frame_idx=frame_idx, + fps=fps, + image_ext=image_ext, + ) + label_text = read_label_text(labels_root, label_relpath) + + 
writer.write_bytes(str(image_relpath), image_bytes) + split_to_image_relpaths[split].append(str(image_relpath)) + + stats.images_written += 1 + + writer.write_text(str(packed_label_relpath), label_text) + stats.labels_written += 1 + + stats.videos_processed += 1 + + manifest_rows.append({ + "split": split, + "video_stem": video_stem, + "s3_key": s3_key, + "fps": str(fps), + "requested_frames": str(len(frame_indices)), + "images_written": str(len(frame_indices)), + "labels_written": str(len(frame_indices)), + "status": "ok", + "error": "", + }) + + except Exception as e: + stats.videos_failed += 1 + manifest_rows.append({ + "split": split, + "video_stem": video_stem, + "s3_key": s3_key, + "fps": str(fps) if fps > 0 else "", + "requested_frames": str(len(frame_indices)), + "images_written": "0", + "labels_written": "0", + "status": "error", + "error": repr(e), + }) + + finally: + if cleanup_video: + try: + if local_video.exists(): + local_video.unlink() + except Exception: + pass + + for writer in shard_writers.values(): + writer.close() + + write_split_manifests(manifests_root, split_to_image_relpaths) + write_data_yaml(manifests_root, class_names) + + if manifest_csv is not None: + ensure_dir(manifest_csv.parent) + with manifest_csv.open("w", newline="", encoding="utf-8") as f: + w = csv.DictWriter( + f, + fieldnames=[ + "split", + "video_stem", + "s3_key", + "fps", + "requested_frames", + "images_written", + "labels_written", + "status", + "error", + ], + ) + w.writeheader() + for row in manifest_rows: + w.writerow(row) + + return stats diff --git a/training/object-detection/src/object_detection/frames/parsing.py b/training/object-detection/src/object_detection/frames/parsing.py new file mode 100644 index 0000000..3685b8d --- /dev/null +++ b/training/object-detection/src/object_detection/frames/parsing.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import re +from pathlib import Path +from typing import Dict, Optional, Tuple + +from 
# NOTE(review): chunk seam — the fragment below was the tail of
# "from object_detection.utils.utils import parse_video_stem".  The import is
# now performed lazily inside video_stem_to_s3_key() so the pure path helpers
# in this module are importable without the full package installed.  (If any
# caller re-exported parse_video_stem from this module, restore the top-level
# import — TODO confirm.)
# object_detection.utils.utils import parse_video_stem


def split_label_relpath_to_packed_paths(
    split: str,
    relpath: str,
    image_ext: str = ".jpg",
) -> Tuple[Path, Path]:
    """
    Convert a split manifest label entry like:
        HIRMD-tankeeah-jetson-0_20250714_012827_M/frame_000123.txt

    into packed dataset paths:
        train/HIRMD-tankeeah-jetson-0_20250714_012827_M/frame_000123.jpg
        train/HIRMD-tankeeah-jetson-0_20250714_012827_M/frame_000123.txt

    Returns:
        (image_relpath, label_relpath), both prefixed with `split`.
    """
    p = Path(relpath.strip())
    label_rel = Path(split) / p
    image_rel = label_rel.with_suffix(image_ext)
    return image_rel, label_rel


def parse_frame_idx(label_filename: str) -> Optional[int]:
    """Return N from a 'frame_NNNNNN.txt' filename, or None if it does not match."""
    m = re.match(r"^frame_(\d+)\.txt$", label_filename)
    if not m:
        return None
    # The regex guarantees the captured group is all digits, so int() cannot
    # fail; the try/except that wrapped it was dead code and has been removed.
    return int(m.group(1))


def video_stem_to_s3_key(video_stem: str) -> str:
    """
    Map a video stem to its motion-video S3 key:
        <org>/<site>/<device>/motion_vids/<video_stem>.mp4

    Raises:
        ValueError: if the stem does not parse into org/site/device fields.
    """
    # Imported lazily; see module note above.
    from object_detection.utils.utils import parse_video_stem

    meta = parse_video_stem(video_stem)
    if meta is None:
        raise ValueError(f"Could not parse video stem: {video_stem}")
    return f"{meta['org']}/{meta['site']}/{meta['device']}/motion_vids/{video_stem}.mp4"


def parse_manifest_relpath(relpath: str) -> Tuple[str, int]:
    """
    Input line example:
        HIRMD-tankeeah-jetson-0_20250714_012827_M/frame_000123.txt
    Returns:
        (video_stem, frame_idx)
    Raises:
        ValueError: if the entry has no video-stem directory component or the
        filename is not of the form frame_NNNNNN.txt.
    """
    p = Path(relpath.strip())
    if len(p.parts) < 2:
        raise ValueError(f"Invalid manifest relpath: {relpath}")

    video_stem = p.parts[0]
    frame_idx = parse_frame_idx(p.name)
    if frame_idx is None:
        raise ValueError(f"Invalid frame filename: {p.name}")

    return video_stem, frame_idx


def label_relpath_to_image_relpath(relpath: str, image_ext: str = ".jpg") -> Path:
    """Swap a label relpath's extension (.txt) for the image extension."""
    p = Path(relpath.strip())
    return p.with_suffix(image_ext)

# [diff residue] diff --git a/training/object-detection/src/object_detection/metadata/__init__.py
# b/training/object-detection/src/object_detection/metadata/__init__.py
# new file mode 100644 index 0000000..e69de29
# [diff residue] diff --git a/training/object-detection/src/object_detection/metadata/cli.py
b/training/object-detection/src/object_detection/metadata/cli.py new file mode 100644 index 0000000..1cc3942 --- /dev/null +++ b/training/object-detection/src/object_detection/metadata/cli.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +from object_detection.metadata.index import build_video_metadata_index, write_video_metadata_index + + +def main() -> None: + p = argparse.ArgumentParser(description="Build per-video metadata index from Label Studio task JSONs.") + p.add_argument("--json-dir", required=True) + p.add_argument("--out-csv", required=True) + args = p.parse_args() + + rows = build_video_metadata_index(Path(args.json_dir)) + write_video_metadata_index(rows, Path(args.out_csv)) + + print(f"Done. indexed_videos={len(rows)} out={args.out_csv}") diff --git a/training/object-detection/src/object_detection/metadata/index.py b/training/object-detection/src/object_detection/metadata/index.py new file mode 100644 index 0000000..69d3a72 --- /dev/null +++ b/training/object-detection/src/object_detection/metadata/index.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import csv +import json +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + + +def safe_float(v: Any, default: float = 0.0) -> float: + try: + return float(v) + except Exception: + return default + + +def parse_ffmpeg_rate(rate: Any) -> float: + if rate is None: + return 0.0 + if isinstance(rate, (int, float)): + return float(rate) + + s = str(rate).strip() + if "/" in s: + a, b = s.split("/", 1) + try: + num = float(a) + den = float(b) + return num / den if den else 0.0 + except Exception: + return 0.0 + try: + return float(s) + except Exception: + return 0.0 + + +def infer_fps(data: Dict[str, Any]) -> float: + fps = safe_float(data.get("frames_per_second"), 0.0) + if fps > 0: + return fps + + fps = parse_ffmpeg_rate(data.get("metadata_video_r_frame_rate")) + if fps > 0: + return fps + + fps = 
parse_ffmpeg_rate(data.get("metadata_video_avg_frame_rate")) + if fps > 0: + return fps + + duration = safe_float(data.get("metadata_video_duration", data.get("duration")), 0.0) + nb_frames = int(safe_float(data.get("metadata_video_nb_frames"), 0)) + if duration > 0 and nb_frames > 0: + return nb_frames / duration + + return 0.0 + + +def infer_s3_key(data: Dict[str, Any], video_stem: str) -> str: + org = data.get("metadata_file_organization_reference_string", "") + site = data.get("metadata_file_site_reference_string", "") + cam = data.get("metadata_file_camera_reference_string", "") + if org and site and cam: + return f"{org}/{site}/{cam}/motion_vids/{video_stem}.mp4" + return "" + + +def iter_task_items(json_dir: Path, pattern: str = "**/*.json") -> Iterable[Dict[str, Any]]: + for path in sorted(json_dir.glob(pattern)): + try: + payload = json.loads(path.read_text(encoding="utf-8")) + except Exception: + continue + + if isinstance(payload, dict): + yield payload + elif isinstance(payload, list): + for item in payload: + if isinstance(item, dict): + yield item + + +def build_video_metadata_index(json_dir: Path) -> List[Dict[str, str]]: + rows: Dict[str, Dict[str, str]] = {} + + for item in iter_task_items(json_dir): + data = item.get("data") or {} + filename = data.get("metadata_file_filename") or data.get("video") or "" + video_stem = Path(filename).stem + if not video_stem: + continue + + row = { + "video_stem": video_stem, + "fps": str(infer_fps(data)), + "nb_frames": str(int(safe_float(data.get("metadata_video_nb_frames"), 0))), + "duration": str(safe_float(data.get("metadata_video_duration", data.get("duration")), 0.0)), + "width": str(int(safe_float(data.get("metadata_video_width"), 0))), + "height": str(int(safe_float(data.get("metadata_video_height"), 0))), + "org": str(data.get("metadata_file_organization_reference_string", "")), + "site": str(data.get("metadata_file_site_reference_string", "")), + "device": 
str(data.get("metadata_file_camera_reference_string", "")), + "s3_key": infer_s3_key(data, video_stem), + } + rows[video_stem] = row + + return list(rows.values()) + + +def write_video_metadata_index(rows: List[Dict[str, str]], out_csv: Path) -> None: + out_csv.parent.mkdir(parents=True, exist_ok=True) + fieldnames = [ + "video_stem", + "fps", + "nb_frames", + "duration", + "width", + "height", + "org", + "site", + "device", + "s3_key", + ] + with out_csv.open("w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for row in rows: + w.writerow(row) + diff --git a/training/object-detection/src/object_detection/negatives/cli.py b/training/object-detection/src/object_detection/negatives/cli.py new file mode 100644 index 0000000..4d2aa4a --- /dev/null +++ b/training/object-detection/src/object_detection/negatives/cli.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +from object_detection.negatives.conditions import create_condition_negative_shards + + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Create condition-balanced negative YOLO label shards from water-conditions CSVs") + p.add_argument("--conditions-csv", nargs="+", required=True, help="One or more water-conditions CSV files") + p.add_argument("--out-dir", required=True, help="Output directory for negative shards and manifests") + p.add_argument("--bucket", default="prod-salmonvision-edge-assets-labelstudio-source") + p.add_argument("--frames-per-video", type=int, default=5) + p.add_argument("--frame-stride", type=int, default=3) + p.add_argument("--frame-offset-mode", choices=["fixed", "video_hash"], default="video_hash") + p.add_argument("--frame-offset", type=int, default=0) + p.add_argument("--shard-size", type=int, default=100000) + p.add_argument("--negative-seed", type=int, default=42) + p.add_argument("--result-type", default="videorectangle") + 
p.add_argument("--from-name", default=None) + p.add_argument("--to-name", default=None) + p.add_argument("--aws-profile", default=None) + p.add_argument("--cache-task-json-dir", default=None) + return p + + +def main() -> None: + args = build_parser().parse_args() + + summary = create_condition_negative_shards( + csv_paths=[Path(p) for p in args.conditions_csv], + out_dir=Path(args.out_dir), + bucket=args.bucket, + frames_per_video=args.frames_per_video, + frame_stride=args.frame_stride, + frame_offset_mode=args.frame_offset_mode, + frame_offset=args.frame_offset, + shard_size=args.shard_size, + negative_seed=args.negative_seed, + result_type=args.result_type, + from_name=args.from_name, + to_name=args.to_name, + aws_profile=args.aws_profile, + cache_task_json_dir=Path(args.cache_task_json_dir) if args.cache_task_json_dir else None, + ) + + print( + f"Done. input_rows={summary['input_rows']} " + f"selected_videos={summary['written_videos']} " + f"written_negative_frames={summary['written_negative_frames']} " + f"failures={len(summary['failures'])}" + ) diff --git a/training/object-detection/src/object_detection/negatives/conditions.py b/training/object-detection/src/object_detection/negatives/conditions.py new file mode 100644 index 0000000..d68b7c8 --- /dev/null +++ b/training/object-detection/src/object_detection/negatives/conditions.py @@ -0,0 +1,706 @@ +from __future__ import annotations + +import csv +import json +import random +from collections import Counter +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple + +import boto3 + +from object_detection.utils.utils import safe_float +from object_detection.yolo_ls.shards import TarShardWriter + + +EXCLUDED_COLUMNS = { + "Project", + "Site", + "Camera", + "Filename", # Label Studio task ID, not video filename + "Date", + "Time", + "Notes:", + "Image Link:", +} + +DEFAULT_BUCKET = 
"prod-salmonvision-edge-assets-labelstudio-source" + + +@dataclass(frozen=True) +class ConditionRow: + project: str + site: str + camera: str + labelstudio_task_id: str + date: str + time: str + video_stem: str + s3_key: str + conditions: Dict[str, str] + source_csv: str + + +@dataclass +class VideoNegativeSample: + video_stem: str + s3_key: str + sampled_frames: List[int] + total_frames: int + positive_frames: int + eligible_negative_frames: int + conditions: Dict[str, str] + source_csv: str + + +@dataclass +class VideoMetadataRecord: + video_stem: str + s3_key: str + fps: float + nb_frames: int + duration: float + width: int + height: int + org: str + site: str + device: str + source_csv: str + + +def normalize_value(v: Any) -> Optional[str]: + if v is None: + return None + s = str(v).strip() + if not s or s.upper() == "NA": + return None + return s + + +def normalize_date(date_str: str) -> str: + dt = datetime.strptime(date_str.strip(), "%Y-%m-%d") + return dt.strftime("%Y%m%d") + + +def normalize_time(time_str: str) -> str: + # Handles "3:17:05" and "03:17:05" + dt = datetime.strptime(time_str.strip(), "%H:%M:%S") + return dt.strftime("%H%M%S") + + +def construct_video_stem(project: str, site: str, camera: str, date_str: str, time_str: str) -> str: + return f"{project}-{site}-{camera}_{normalize_date(date_str)}_{normalize_time(time_str)}_M" + + +def construct_task_s3_key(project: str, site: str, camera: str, video_stem: str) -> str: + return f"{project}/{site}/{camera}/labelstudio_tasks/{video_stem}.json" + + +def infer_condition_columns(fieldnames: Sequence[str]) -> List[str]: + cols: List[str] = [] + for name in fieldnames: + if name is None: + continue + s = name.strip() + if not s: + continue + if s in EXCLUDED_COLUMNS: + continue + cols.append(s) + return cols + + +def load_condition_rows(csv_paths: Sequence[Path]) -> Tuple[List[ConditionRow], List[str]]: + rows: List[ConditionRow] = [] + all_fieldnames: List[str] = [] + + for csv_path in csv_paths: + with 
csv_path.open("r", newline="", encoding="utf-8-sig") as f: + reader = csv.DictReader(f) + if reader.fieldnames: + for fn in reader.fieldnames: + if fn not in all_fieldnames: + all_fieldnames.append(fn) + + for raw in reader: + project = normalize_value(raw.get("Project")) + site = normalize_value(raw.get("Site")) + camera = normalize_value(raw.get("Camera")) + labelstudio_task_id = normalize_value(raw.get("Filename")) + date_str = normalize_value(raw.get("Date")) + time_str = normalize_value(raw.get("Time")) + + # Skip blank rows and placeholder rows like "NA" + if not project or not site or not camera or not date_str or not time_str: + continue + if not labelstudio_task_id: + continue + + try: + video_stem = construct_video_stem(project, site, camera, date_str, time_str) + except ValueError: + continue + + condition_values: Dict[str, str] = {} + for col in infer_condition_columns(reader.fieldnames or []): + v = normalize_value(raw.get(col)) + if v is not None: + condition_values[col] = v + + row = ConditionRow( + project=project, + site=site, + camera=camera, + labelstudio_task_id=labelstudio_task_id, + date=date_str, + time=time_str, + video_stem=video_stem, + s3_key=construct_task_s3_key(project, site, camera, video_stem), + conditions=condition_values, + source_csv=str(csv_path), + ) + rows.append(row) + + # Dedupe by real video stem + dedup: Dict[str, ConditionRow] = {} + for row in rows: + dedup[row.video_stem] = row + + deduped = list(dedup.values()) + condition_columns = infer_condition_columns(all_fieldnames) + return deduped, condition_columns + + +def active_condition_columns(rows: Sequence[ConditionRow], condition_columns: Sequence[str]) -> List[str]: + keep: List[str] = [] + for col in condition_columns: + vals = sorted({r.conditions[col] for r in rows if col in r.conditions}) + if len(vals) >= 2: + keep.append(col) + return keep + + +def compute_condition_targets( + rows: Sequence[ConditionRow], + condition_columns: Sequence[str], +) -> 
Tuple[Dict[Tuple[str, str], int], Dict[str, Counter]]: + per_col_counts: Dict[str, Counter] = {} + targets: Dict[Tuple[str, str], int] = {} + + for col in condition_columns: + c = Counter() + for row in rows: + if col in row.conditions: + c[row.conditions[col]] += 1 + if not c: + continue + + per_col_counts[col] = c + target = min(c.values()) + for value in c: + targets[(col, value)] = target + + return targets, per_col_counts + + +def greedy_select_balanced_rows( + rows: Sequence[ConditionRow], + condition_columns: Sequence[str], +) -> Tuple[List[ConditionRow], Dict[Tuple[str, str], int], Dict[str, Counter]]: + """ + Greedy marginal balancing: + - each condition column is balanced independently to its rarest category count + - rows that satisfy multiple deficits are preferred + """ + targets, per_col_counts = compute_condition_targets(rows, condition_columns) + deficits = dict(targets) + + remaining = list(rows) + selected: List[ConditionRow] = [] + + def row_score(row: ConditionRow) -> int: + score = 0 + for col in condition_columns: + val = row.conditions.get(col) + if val is None: + continue + score += max(deficits.get((col, val), 0), 0) + return score + + while True: + best_row = None + best_score = 0 + + for row in remaining: + score = row_score(row) + if score > best_score: + best_score = score + best_row = row + + if best_row is None or best_score <= 0: + break + + selected.append(best_row) + remaining.remove(best_row) + + for col in condition_columns: + val = best_row.conditions.get(col) + if val is None: + continue + key = (col, val) + if key in deficits and deficits[key] > 0: + deficits[key] -= 1 + + return selected, targets, per_col_counts + + +def parse_ts(s: str) -> datetime: + return datetime.fromisoformat(s.replace("Z", "+00:00")) + + +def parse_ffmpeg_rate(rate: Any) -> float: + if rate is None: + return 0.0 + if isinstance(rate, (int, float)): + return float(rate) + + s = str(rate).strip() + if "/" in s: + a, b = s.split("/", 1) + try: + num = 
float(a) + den = float(b) + return num / den if den else 0.0 + except Exception: + return 0.0 + try: + return float(s) + except Exception: + return 0.0 + + +def infer_fps(data: Dict[str, Any]) -> float: + fps = safe_float(data.get("frames_per_second"), 0.0) + if fps > 0: + return fps + + fps = parse_ffmpeg_rate(data.get("metadata_video_r_frame_rate")) + if fps > 0: + return fps + + fps = parse_ffmpeg_rate(data.get("metadata_video_avg_frame_rate")) + if fps > 0: + return fps + + duration = safe_float(data.get("metadata_video_duration", data.get("duration", 0.0)), 0.0) + nb_frames = int(safe_float(data.get("metadata_video_nb_frames"), 0)) + if duration > 0 and nb_frames > 0: + return nb_frames / duration + + return 0.0 + + +def extract_video_metadata_record( + item: dict, + *, + video_stem: str, + s3_key: str, + source_csv: str, +) -> VideoMetadataRecord: + data = item.get("data") or {} + + return VideoMetadataRecord( + video_stem=video_stem, + s3_key=s3_key, + fps=infer_fps(data), + nb_frames=int(safe_float(data.get("metadata_video_nb_frames"), 0)), + duration=safe_float(data.get("metadata_video_duration", data.get("duration", 0.0)), 0.0), + width=int(safe_float(data.get("metadata_video_width"), 0)), + height=int(safe_float(data.get("metadata_video_height"), 0)), + org=str(data.get("metadata_file_organization_reference_string", "")), + site=str(data.get("metadata_file_site_reference_string", "")), + device=str(data.get("metadata_file_camera_reference_string", "")), + source_csv=source_csv, + ) + + +def cache_task_json(task_json: Any, cache_root: Path, s3_key: str) -> Path: + out_path = cache_root / s3_key + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(task_json, indent=2, sort_keys=True) + "\n", encoding="utf-8") + return out_path + + +def interpolate_sequence(seq: Iterable[dict]) -> Dict[int, List[Tuple[float, float, float, float]]]: + kfs = sorted(seq, key=lambda k: int(safe_float(k.get("frame"), 0))) + frames_boxes: 
Dict[int, List[Tuple[float, float, float, float]]] = {}
+    # (continuation of interpolate_sequence: maps frame index -> list of
+    # (x, y, width, height) boxes, in Label Studio percent coordinates —
+    # units assumed from LS export conventions, TODO confirm)
+
+    if not kfs:
+        return frames_boxes
+
+    # Every keyframe with a valid frame number contributes its own box,
+    # regardless of its "enabled" flag.
+    # NOTE(review): if a disabled keyframe marks the END of a track in Label
+    # Studio, this counts that frame as occupied anyway — confirm intended.
+    for k in kfs:
+        f = int(safe_float(k.get("frame"), -1))
+        if f < 0:
+            continue
+        x = safe_float(k.get("x"))
+        y = safe_float(k.get("y"))
+        w = safe_float(k.get("width"))
+        h = safe_float(k.get("height"))
+        frames_boxes.setdefault(f, []).append((x, y, w, h))
+
+    # Linear interpolation of boxes strictly between consecutive keyframes;
+    # skipped when the earlier keyframe is disabled (a gap in the track).
+    for i in range(len(kfs) - 1):
+        k0 = kfs[i]
+        k1 = kfs[i + 1]
+
+        f0 = int(safe_float(k0.get("frame"), -1))
+        f1 = int(safe_float(k1.get("frame"), -1))
+        if f0 < 0 or f1 <= f0:
+            continue
+
+        enabled0 = bool(k0.get("enabled", True))
+        if not enabled0:
+            continue
+
+        x0 = safe_float(k0.get("x"))
+        y0 = safe_float(k0.get("y"))
+        w0 = safe_float(k0.get("width"))
+        h0 = safe_float(k0.get("height"))
+
+        x1 = safe_float(k1.get("x"))
+        y1 = safe_float(k1.get("y"))
+        w1 = safe_float(k1.get("width"))
+        h1 = safe_float(k1.get("height"))
+
+        for f in range(f0 + 1, f1):
+            t = (f - f0) / float(f1 - f0)
+            x = x0 + (x1 - x0) * t
+            y = y0 + (y1 - y0) * t
+            w = w0 + (w1 - w0) * t
+            h = h0 + (h1 - h0) * t
+            frames_boxes.setdefault(f, []).append((x, y, w, h))
+
+    return frames_boxes
+
+
+def extract_task_item(task_json: Any, expected_video_stem: str) -> dict:
+    """Pick the task dict for expected_video_stem out of a task JSON payload:
+    a bare dict is returned as-is; a singleton list is unwrapped; a longer
+    list is searched by filename stem, falling back to the first element.
+    NOTE(review): non-dict items in a list would raise AttributeError on
+    .get() below — assumes well-formed Label Studio exports, TODO confirm."""
+    if isinstance(task_json, dict):
+        return task_json
+    if isinstance(task_json, list):
+        if len(task_json) == 1:
+            return task_json[0]
+        for item in task_json:
+            data = item.get("data") or {}
+            stem = Path(data.get("metadata_file_filename") or data.get("video") or "").stem
+            if stem == expected_video_stem:
+                return item
+        return task_json[0]
+    raise ValueError("Unsupported task JSON structure")
+
+
+def infer_total_frames(item: dict, results: Optional[List[dict]] = None) -> int:
+    """Best-effort total frame count: ffprobe nb_frames, then any result's
+    framesCount, then round(duration * fps).  (Definition continues past
+    this review chunk.)  Returns 0 when nothing usable is present."""
+    data = item.get("data") or {}
+
+    n = int(safe_float(data.get("metadata_video_nb_frames"), 0))
+    if n > 0:
+        return n
+
+    if results:
+        for r in results:
+            value = r.get("value") or {}
+            n = int(safe_float(value.get("framesCount"), 0))
+            if n > 0:
+                return n
+
+    duration = 
safe_float(data.get("metadata_video_duration", data.get("duration", 0.0)), 0.0) + + fps = safe_float(data.get("frames_per_second"), 0.0) + if fps <= 0: + fps = parse_ffmpeg_rate(data.get("metadata_video_r_frame_rate")) + if fps <= 0: + fps = parse_ffmpeg_rate(data.get("metadata_video_avg_frame_rate")) + + if duration > 0 and fps > 0: + return int(round(duration * fps)) + + return 0 + + +def extract_latest_results( + item: dict, + result_type: str = "videorectangle", + from_name: Optional[str] = None, + to_name: Optional[str] = None, +) -> List[dict]: + annos = item.get("annotations") or [] + if not annos: + return [] + + latest_ann = max(annos, key=lambda a: parse_ts(a["updated_at"])) + out: List[dict] = [] + for r in (latest_ann.get("result") or []): + if r.get("type") != result_type: + continue + if from_name is not None and r.get("from_name") != from_name: + continue + if to_name is not None and r.get("to_name") != to_name: + continue + out.append(r) + return out + + +def extract_positive_frames( + item: dict, + result_type: str = "videorectangle", + from_name: Optional[str] = None, + to_name: Optional[str] = None, +) -> Set[int]: + results = extract_latest_results(item, result_type=result_type, from_name=from_name, to_name=to_name) + positive: Set[int] = set() + + for r in results: + value = r.get("value") or {} + seq = value.get("sequence") or [] + frame_boxes = interpolate_sequence(seq) + positive.update(frame_boxes.keys()) + + return positive + + +def stride_offset(video_stem: str, frame_stride: int, frame_offset_mode: str, frame_offset: int) -> int: + if frame_stride <= 1: + return 0 + if frame_offset_mode == "fixed": + return int(frame_offset) % frame_stride + if frame_offset_mode == "video_hash": + import zlib + return zlib.crc32(video_stem.encode("utf-8")) % frame_stride + raise ValueError(f"Invalid frame_offset_mode: {frame_offset_mode}") + + +def eligible_negative_frames( + video_stem: str, + total_frames: int, + positive_frames: Set[int], + 
    frame_stride: int,
+    frame_offset_mode: str,
+    frame_offset: int,
+) -> List[int]:
+    # (continuation of eligible_negative_frames: keep every frame on the
+    # per-video stride lattice that was never covered by a positive box.)
+    # NOTE(review): frame_stride == 0 raises ZeroDivisionError in the modulo
+    # below — stride_offset only special-cases stride <= 1.  The CLI default
+    # is 3; confirm no caller can pass 0.
+    off = stride_offset(video_stem, frame_stride, frame_offset_mode, frame_offset)
+    return [
+        f for f in range(total_frames)
+        if (f % frame_stride) == off and f not in positive_frames
+    ]
+
+
+def fetch_task_json(s3_client: Any, bucket: str, key: str) -> Any:
+    # Download and JSON-decode one Label Studio task payload from S3.
+    obj = s3_client.get_object(Bucket=bucket, Key=key)
+    return json.loads(obj["Body"].read().decode("utf-8"))
+
+
+def sample_frames(video_stem: str, eligible_frames: Sequence[int], k: int, seed: int) -> List[int]:
+    # Sample up to k frames, returned sorted.  Deterministic per video: the
+    # RNG is seeded with "seed:video_stem", so reruns pick the same frames.
+    if not eligible_frames or k <= 0:
+        return []
+    if k >= len(eligible_frames):
+        return sorted(eligible_frames)
+
+    rng = random.Random(f"{seed}:{video_stem}")
+    return sorted(rng.sample(list(eligible_frames), k))
+
+
+def create_condition_negative_shards(
+    csv_paths: Sequence[Path],
+    out_dir: Path,
+    *,
+    bucket: str = DEFAULT_BUCKET,
+    frames_per_video: int = 5,
+    frame_stride: int = 3,
+    frame_offset_mode: str = "video_hash",
+    frame_offset: int = 0,
+    shard_size: int = 100000,
+    negative_seed: int = 42,
+    result_type: str = "videorectangle",
+    from_name: Optional[str] = None,
+    to_name: Optional[str] = None,
+    aws_profile: Optional[str] = None,
+    cache_task_json_dir: Optional[Path] = None,
+) -> Dict[str, Any]:
+    """
+    End-to-end driver: select condition-balanced videos from the water
+    conditions CSVs, fetch their Label Studio tasks from S3, sample negative
+    (no-fish) frames, and write empty YOLO label shards plus manifest,
+    summary, and per-video metadata sidecar files under out_dir.
+    (Definition continues past this review chunk.)
+    """
+    out_dir.mkdir(parents=True, exist_ok=True)
+    manifest_csv = out_dir / "condition_negative_manifest.csv"
+    summary_json = out_dir / "condition_negative_summary.json"
+    metadata_csv = out_dir / "condition_negative_video_metadata.csv"
+    if cache_task_json_dir is not None:
+        cache_task_json_dir.mkdir(parents=True, exist_ok=True)
+
+    session = boto3.Session(profile_name=aws_profile) if aws_profile else boto3.Session()
+    s3_client = session.client("s3")
+
+    rows, raw_condition_columns = load_condition_rows(csv_paths)
+    # Only columns with >= 2 distinct observed values take part in balancing.
+    condition_columns = active_condition_columns(rows, raw_condition_columns)
+
+    selected_rows, targets, per_col_counts = greedy_select_balanced_rows(rows, condition_columns)
+
+    writer = 
TarShardWriter(out_dir, shard_size=shard_size, prefix="condition_negatives") + + samples: List[VideoNegativeSample] = [] + failures: List[Dict[str, str]] = [] + metadata_records: List[VideoMetadataRecord] = [] + + for row in selected_rows: + try: + task_json = fetch_task_json(s3_client, bucket, row.s3_key) + + if cache_task_json_dir is not None: + cache_task_json(task_json, cache_task_json_dir, row.s3_key) + + item = extract_task_item(task_json, row.video_stem) + + metadata_records.append( + extract_video_metadata_record( + item, + video_stem=row.video_stem, + s3_key=row.s3_key, + source_csv=row.source_csv, + ) + ) + + results = extract_latest_results( + item, + result_type=result_type, + from_name=from_name, + to_name=to_name, + ) + total_frames = infer_total_frames(item, results=results) + if total_frames <= 0: + failures.append({"video_stem": row.video_stem, "reason": "total_frames_unavailable"}) + continue + + positive = extract_positive_frames( + item, + result_type=result_type, + from_name=from_name, + to_name=to_name, + ) + eligible = eligible_negative_frames( + row.video_stem, + total_frames, + positive, + frame_stride=frame_stride, + frame_offset_mode=frame_offset_mode, + frame_offset=frame_offset, + ) + + sampled = sample_frames( + row.video_stem, + eligible, + k=frames_per_video, + seed=negative_seed, + ) + if not sampled: + failures.append({"video_stem": row.video_stem, "reason": "no_eligible_negative_frames"}) + continue + + for frame_idx in sampled: + writer.write_text(f"{row.video_stem}/frame_{frame_idx:06d}.txt", "") + + samples.append( + VideoNegativeSample( + video_stem=row.video_stem, + s3_key=row.s3_key, + sampled_frames=sampled, + total_frames=total_frames, + positive_frames=len(positive), + eligible_negative_frames=len(eligible), + conditions=row.conditions, + source_csv=row.source_csv, + ) + ) + except Exception as e: + failures.append({"video_stem": row.video_stem, "reason": repr(e)}) + + writer.close() + + with manifest_csv.open("w", 
newline="", encoding="utf-8") as f: + fieldnames = [ + "video_stem", + "s3_key", + "source_csv", + "total_frames", + "positive_frames", + "eligible_negative_frames", + "sampled_frames", + ] + condition_columns + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for s in samples: + row = { + "video_stem": s.video_stem, + "s3_key": s.s3_key, + "source_csv": s.source_csv, + "total_frames": s.total_frames, + "positive_frames": s.positive_frames, + "eligible_negative_frames": s.eligible_negative_frames, + "sampled_frames": " ".join(str(x) for x in s.sampled_frames), + } + for col in condition_columns: + row[col] = s.conditions.get(col, "") + w.writerow(row) + + with metadata_csv.open("w", newline="", encoding="utf-8") as f: + fieldnames = [ + "video_stem", + "s3_key", + "fps", + "nb_frames", + "duration", + "width", + "height", + "org", + "site", + "device", + "source_csv", + ] + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for rec in metadata_records: + w.writerow(asdict(rec)) + + selected_condition_counts: Dict[str, Counter] = {} + for col in condition_columns: + c = Counter() + for s in samples: + if col in s.conditions: + c[s.conditions[col]] += 1 + selected_condition_counts[col] = c + + summary = { + "bucket": bucket, + "csv_paths": [str(p) for p in csv_paths], + "condition_columns": condition_columns, + "input_rows": len(rows), + "selected_videos_before_fetch": len(selected_rows), + "written_videos": len(samples), + "written_negative_frames": sum(len(s.sampled_frames) for s in samples), + "frames_per_video": frames_per_video, + "frame_stride": frame_stride, + "frame_offset_mode": frame_offset_mode, + "frame_offset": frame_offset, + "metadata_csv": str(metadata_csv), + "cached_task_json_dir": str(cache_task_json_dir) if cache_task_json_dir is not None else "", + "metadata_records_written": len(metadata_records), + "targets_by_condition": { + col: {val: targets[(col, val)] for val in per_col_counts.get(col, {})} + for col in 
condition_columns + }, + "input_counts_by_condition": { + col: dict(per_col_counts[col]) for col in condition_columns + }, + "selected_counts_by_condition": { + col: dict(selected_condition_counts[col]) for col in condition_columns + }, + "failures": failures, + } + summary_json.write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + return summary diff --git a/training/object-detection/src/object_detection/splits/cli.py b/training/object-detection/src/object_detection/splits/cli.py new file mode 100644 index 0000000..cbd00b2 --- /dev/null +++ b/training/object-detection/src/object_detection/splits/cli.py @@ -0,0 +1,129 @@ +import argparse +import csv +import json +from pathlib import Path + +from object_detection.splits.splitter import ( + build_groups, + split_groups_greedy, + write_manifest, + summarize_split, +) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--labels-root", required=True, type=Path, + help="Root of exploded YOLO labels, e.g. 
data/99_work/yolo_annos_exploded") + ap.add_argument("--out-dir", required=True, type=Path, + help="Output directory for split manifests") + ap.add_argument("--sites", nargs="*", default=["tankeeah", "kitwanga", "bear"], + help="Sites to include (baseline)") + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--train-frac", type=float, default=0.80) + ap.add_argument("--val-frac", type=float, default=0.10) + ap.add_argument("--test-frac", type=float, default=0.10) + ap.add_argument("--limit-files", type=int, default=None, + help="Debug: limit to N random label files") + + # Objective weights + ap.add_argument("--w-class", type=float, default=4.0) + ap.add_argument("--w-tod", type=float, default=1.0) + ap.add_argument("--w-density", type=float, default=1.0) + ap.add_argument("--w-area", type=float, default=1.0) + ap.add_argument("--w-ar", type=float, default=1.0) + ap.add_argument("--w-size", type=float, default=2.0) + + args = ap.parse_args() + + if not args.labels_root.exists(): + raise SystemExit(f"labels-root not found: {args.labels_root}") + + ssum = args.train_frac + args.val_frac + args.test_frac + if abs(ssum - 1.0) > 1e-6: + raise SystemExit(f"train/val/test fractions must sum to 1.0; got {ssum}") + + groups = build_groups( + labels_root=args.labels_root, + sites_keep=args.sites, + seed=args.seed, + limit=args.limit_files, + ) + + # Split + weights = { + "class": args.w_class, + "tod": args.w_tod, + "density": args.w_density, + "area": args.w_area, + "ar": args.w_ar, + "size": args.w_size, + } + + train, val, test, report = split_groups_greedy( + groups=groups, + seed=args.seed, + train_frac=args.train_frac, + val_frac=args.val_frac, + test_frac=args.test_frac, + weights=weights, + ) + + out_dir = args.out_dir + out_dir.mkdir(parents=True, exist_ok=True) + + # Write manifests (relative paths to label files) + write_manifest(out_dir / "train.txt", train.frame_paths) + write_manifest(out_dir / "val.txt", val.frame_paths) + 
write_manifest(out_dir / "test.txt", test.frame_paths) + + # Group assignment CSV + with (out_dir / "group_assignments.csv").open("w", newline="") as f: + w = csv.writer(f) + w.writerow(["group_id", "split", "site", "device", "date", "n_frames", "n_boxes"]) + for s in [train, val, test]: + for gid in s.group_ids: + g = groups[gid] + w.writerow([gid, s.name, g.site, g.device, g.date, g.n_frames, g.n_boxes]) + + # JSON report + full_report = { + "params": { + "labels_root": str(args.labels_root), + "sites": args.sites, + "seed": args.seed, + "fractions": {"train": args.train_frac, "val": args.val_frac, "test": args.test_frac}, + "weights": weights, + "grouping": "group_id = site|device|YYYYMMDD", + "notes": [ + "Split is group-wise to reduce leakage from temporally adjacent frames.", + "Time-of-day bucket derives from video clip HHMMSS in stem; frames inherit clip bucket.", + "Balancing is soft; rare classes are prioritized earlier in greedy assignment.", + ], + }, + "targets": { + "total_frames": report["total_frames"], + "target_frames": report["target_frames"], + "actual_frames": report["actual_frames"], + "global_class_dist": report["class_dist"], + "global_tod_dist": report["tod_dist"], + "global_density_dist": report["density_dist"], + "global_area_dist": report["area_dist"], + "global_ar_dist": report["ar_dist"], + }, + "splits": { + "train": summarize_split(train), + "val": summarize_split(val), + "test": summarize_split(test), + }, + } + + (out_dir / "split_report.json").write_text(json.dumps(full_report, indent=2, sort_keys=True) + "\n") + + print("[make_splits] wrote:") + print(f" {out_dir / 'train.txt'} ({len(train.frame_paths)} frames)") + print(f" {out_dir / 'val.txt'} ({len(val.frame_paths)} frames)") + print(f" {out_dir / 'test.txt'} ({len(test.frame_paths)} frames)") + print(f" {out_dir / 'group_assignments.csv'}") + print(f" {out_dir / 'split_report.json'}") + + diff --git a/training/object-detection/src/object_detection/splits/parsing.py 
import re
from pathlib import Path
from typing import Dict, Tuple, Optional
from collections import Counter


def time_bucket(hhmmss: str) -> str:
    """Coarse time-of-day bucket derived from the HH of an HHMMSS string."""
    try:
        hh = int(hhmmss[0:2])
    except (TypeError, ValueError):
        # Non-string input or non-numeric hour prefix.
        return "unknown"
    if 0 <= hh <= 5:
        return "night"
    if 6 <= hh <= 11:
        return "morning"
    if 12 <= hh <= 17:
        return "afternoon"
    if 18 <= hh <= 23:
        return "evening"
    # hh parsed but outside 0-23: malformed timestamp.
    return "unknown"


def density_bin(n_boxes: int) -> str:
    """Bin label for the number of boxes in a frame."""
    if n_boxes <= 0:
        return "0"
    if n_boxes == 1:
        return "1"
    if n_boxes == 2:
        return "2"
    if 3 <= n_boxes <= 4:
        return "3-4"
    if 5 <= n_boxes <= 9:
        return "5-9"
    return "10+"


def ar_bin(w: float, h: float) -> str:
    """
    Aspect-ratio bin based on w/h.

    w, h are YOLO-normalized widths/heights in [0, 1]. Non-positive
    dimensions yield "invalid" rather than dividing by zero.
    """
    if w <= 0 or h <= 0:
        return "invalid"
    r = w / h

    # Thresholds are tunable; these give three roughly intuitive buckets.
    if r < 0.67:
        return "tall"    # height-dominant
    if r <= 1.5:
        return "square"  # roughly square-ish
    return "wide"        # width-dominant


def area_bin(area: float) -> str:
    """
    Bin label for a YOLO-normalized bbox area (w*h) in [0, 1].

    Thresholds are tunable; bins are half-open [lo, hi).
    """
    if area <= 0:
        return "0"
    if area < 0.0025:
        return "<0.0025"
    if area < 0.01:
        return "0.0025-0.01"
    if area < 0.04:
        return "0.01-0.04"
    if area < 0.16:
        return "0.04-0.16"
    return ">=0.16"


# Compiled once at import; matches e.g. "frame_000123.txt".
_FRAME_RE = re.compile(r"^frame_(\d+)\.txt$")


def parse_frame_idx(filename: str) -> Optional[int]:
    """Extract the integer frame index from a "frame_000123.txt" name, or None."""
    m = _FRAME_RE.match(filename)
    # int() cannot fail here: the captured group is guaranteed to be digits.
    return int(m.group(1)) if m else None


def read_yolo_label(path: Path) -> Tuple[int, Counter, Counter, Counter]:
    """
    Parse one YOLO label file (one "cls cx cy w h" row per box).

    Returns:
        n_boxes: number of parseable boxes,
        class_counts: class_id -> box count,
        area_counts: area-bin label -> box count,
        ar_counts: aspect-ratio-bin label -> box count.

    Unreadable files and empty files yield (0, Counter(), Counter(), Counter()).
    Rows with fewer than 5 fields or a non-integer class id are skipped.
    """
    n_boxes = 0
    class_counts: Counter = Counter()
    area_counts: Counter = Counter()
    ar_counts: Counter = Counter()

    try:
        txt = path.read_text().strip()
    except (OSError, UnicodeDecodeError):
        # Treat unreadable/undecodable files as empty rather than aborting a scan.
        return 0, Counter(), Counter(), Counter()

    if not txt:
        return 0, Counter(), Counter(), Counter()

    for line in txt.splitlines():
        parts = line.split()
        if len(parts) < 5:
            continue
        try:
            cls_id = int(parts[0])
        except ValueError:
            continue
        # Malformed w/h fall back to 0.0, which lands in the "0"/"invalid"
        # bins — same observable binning as treating each field separately.
        try:
            w = float(parts[3])
            h = float(parts[4])
        except ValueError:
            w = h = 0.0
        n_boxes += 1
        class_counts[cls_id] += 1
        area_counts[area_bin(w * h)] += 1
        ar_counts[ar_bin(w, h)] += 1

    return n_boxes, class_counts, area_counts, ar_counts
# -----------------------------
# Data structures
# -----------------------------


@dataclass
class FrameRecord:
    """Per-frame metadata for one labelled frame.

    NOTE(review): not constructed anywhere in this module's visible code;
    presumably consumed by other modules — confirm before removing.
    """

    rel_path: str          # label-file path relative to labels_root
    video_stem: str
    frame_idx: int
    org: str
    site: str
    device: str
    date: str              # YYYYMMDD
    tod: str               # time-of-day bucket (see parsing.time_bucket)

    n_boxes: int
    class_counts: Counter  # class_id -> box count
    density_bin: str
    area_bins: Counter     # area-bin label -> box count (one per box)


@dataclass
class GroupStats:
    """Aggregated statistics for one leakage group (group_id = site|device|date)."""

    group_id: str
    site: str
    device: str
    date: str

    n_frames: int
    n_boxes: int

    class_counts: Counter    # class_id -> box count
    tod_counts: Counter      # time-of-day bucket -> frame count
    density_counts: Counter  # density bin -> frame count
    area_counts: Counter     # area bin -> box count
    ar_counts: Counter       # aspect-ratio bin -> box count

    # Label-file rel_paths for every frame in the group (for manifest writing).
    frame_paths: List[str]


# -----------------------------
# Scanning labels
# -----------------------------


def iter_label_files(labels_root: Path) -> Iterable[Path]:
    """Yield every YOLO label file under labels_root.

    Expected layout: labels_root/<video_stem>/frame_000123.txt
    """
    for p in labels_root.rglob("frame_*.txt"):
        if p.is_file():
            yield p


def build_groups(
    labels_root: Path,
    sites_keep: List[str],
    seed: int,
    limit: Optional[int] = None,
) -> Dict[str, GroupStats]:
    """
    Scan label files and build per-group aggregates. Group id = site|device|date.

    Args:
        labels_root: root directory holding <video_stem>/frame_*.txt files.
        sites_keep: sites to include; an empty list keeps every site.
        seed: RNG seed — only affects which files survive `limit` subsampling.
        limit: debug option, keep at most N randomly chosen label files.

    Files whose video stem or frame name cannot be parsed are counted and
    skipped (reported via print at the end).
    """
    rnd = random.Random(seed)

    files = sorted(iter_label_files(labels_root))
    if limit is not None:
        rnd.shuffle(files)
        files = files[:limit]

    groups: Dict[str, GroupStats] = {}

    skipped = 0
    for f in files:
        video_stem = f.parent.name
        meta = parse_video_stem(video_stem)
        if meta is None:
            skipped += 1
            continue

        site = meta["site"]
        if sites_keep and site not in sites_keep:
            continue

        if parse_frame_idx(f.name) is None:
            skipped += 1
            continue

        n_boxes, class_counts, area_counts, ar_counts = read_yolo_label(f)
        # Label files are assumed to exist only for frames with boxes; if
        # negatives are ever included, empty files still land in the "0" bin.
        dens_bin = density_bin(n_boxes)
        tod = time_bucket(meta["time"])

        rel_path = str(f.relative_to(labels_root))

        group_id = f"{site}|{meta['device']}|{meta['date']}"
        g = groups.get(group_id)
        if g is None:
            g = groups[group_id] = GroupStats(
                group_id=group_id,
                site=site,
                device=meta["device"],
                date=meta["date"],
                n_frames=0,
                n_boxes=0,
                class_counts=Counter(),
                tod_counts=Counter(),
                density_counts=Counter(),
                area_counts=Counter(),
                ar_counts=Counter(),
                frame_paths=[],
            )

        g.n_frames += 1
        g.n_boxes += n_boxes
        g.class_counts.update(class_counts)
        g.tod_counts[tod] += 1
        g.density_counts[dens_bin] += 1
        g.area_counts.update(area_counts)
        g.ar_counts.update(ar_counts)
        g.frame_paths.append(rel_path)

    if skipped:
        print(f"[make_splits] skipped {skipped} files due to parse issues")
    print(f"[make_splits] groups={len(groups)} from label files={len(files)}")
    return groups


# -----------------------------
# Split objective
# -----------------------------


def normalize_counter(c: Counter) -> Dict[Any, float]:
    """Return the counter as a probability distribution; empty dict if total <= 0."""
    total = float(sum(c.values()))
    if total <= 0:
        return {}
    return {k: v / total for k, v in c.items()}


def l1_dist(p: Dict[Any, float], q: Dict[Any, float], keys: Iterable[Any]) -> float:
    """L1 distance between two distributions, evaluated over the given key set."""
    return sum(abs(p.get(k, 0.0) - q.get(k, 0.0)) for k in keys)


@dataclass
class SplitState:
    """Mutable accumulator for one split (train/val/test).

    Counter/list fields use default_factory — instead of `= None` defaults
    patched in __post_init__ — so the declared field types are honest and
    each instance gets fresh containers. Construction is unchanged:
    SplitState(name, target_frac).
    """

    name: str
    target_frac: float
    n_frames: int = 0

    class_counts: Counter = field(default_factory=Counter)
    tod_counts: Counter = field(default_factory=Counter)
    density_counts: Counter = field(default_factory=Counter)
    area_counts: Counter = field(default_factory=Counter)
    ar_counts: Counter = field(default_factory=Counter)

    group_ids: List[str] = field(default_factory=list)
    frame_paths: List[str] = field(default_factory=list)

    def add_group(self, g: GroupStats) -> None:
        """Fold one group's statistics (and frame paths) into this split."""
        self.n_frames += g.n_frames
        self.class_counts.update(g.class_counts)
        self.tod_counts.update(g.tod_counts)
        self.density_counts.update(g.density_counts)
        self.area_counts.update(g.area_counts)
        self.ar_counts.update(g.ar_counts)
        self.group_ids.append(g.group_id)
        self.frame_paths.extend(g.frame_paths)


def compute_global_targets(groups: List[GroupStats]) -> Dict[str, Any]:
    """Global (all-groups) distributions that each split should match."""
    total_frames = sum(g.n_frames for g in groups)

    global_class = Counter()
    global_tod = Counter()
    global_density = Counter()
    global_area = Counter()
    global_ar = Counter()

    for g in groups:
        global_class.update(g.class_counts)
        global_tod.update(g.tod_counts)
        global_density.update(g.density_counts)
        global_area.update(g.area_counts)
        global_ar.update(g.ar_counts)

    return {
        "total_frames": total_frames,
        "class_keys": sorted(global_class.keys()),
        "tod_keys": sorted(global_tod.keys()),
        "density_keys": sorted(global_density.keys()),
        "area_keys": sorted(global_area.keys()),
        "ar_keys": sorted(global_ar.keys()),
        "class_dist": normalize_counter(global_class),
        "tod_dist": normalize_counter(global_tod),
        "density_dist": normalize_counter(global_density),
        "area_dist": normalize_counter(global_area),
        "ar_dist": normalize_counter(global_ar),
    }


def rarity_score(g: GroupStats, global_class_dist: Dict[int, float]) -> float:
    """
    Priority for greedy assignment: higher score => assigned earlier.

    Inverse-frequency weighting over the classes present in the group
    (so rare classes are placed first), plus a small bonus for very
    dense groups.
    """
    s = 0.0
    for cls_id, cnt in g.class_counts.items():
        p = global_class_dist.get(cls_id, 1e-12)
        # Weight by how much of that class the group carries.
        s += cnt * (1.0 / max(p, 1e-6))
    s += 0.25 * g.n_boxes
    return s


def split_groups_greedy(
    groups: Dict[str, GroupStats],
    seed: int,
    train_frac: float,
    val_frac: float,
    test_frac: float,
    weights: Dict[str, float],
) -> Tuple[SplitState, SplitState, SplitState, Dict[str, Any]]:
    """
    Greedy group assignment minimizing distance to global distributions + size penalty.

    weights keys: class, tod, density, area, ar, size

    Returns (train, val, test, report) where report merges the global
    targets with total/target/actual frame counts.
    """
    rnd = random.Random(seed)
    group_list = list(groups.values())

    targets = compute_global_targets(group_list)
    total_frames = targets["total_frames"]

    # Shuffle first so equal-rarity groups get a seed-dependent (but
    # reproducible) order, then sort rare-first.
    rnd.shuffle(group_list)
    group_list.sort(key=lambda g: rarity_score(g, targets["class_dist"]), reverse=True)

    train = SplitState("train", train_frac)
    val = SplitState("val", val_frac)
    test = SplitState("test", test_frac)
    # Candidate order: smaller splits first, so score ties favor them.
    splits = [test, val, train]

    # Precompute per-split target frame counts.
    target_frames = {
        "train": train_frac * total_frames,
        "val": val_frac * total_frames,
        "test": test_frac * total_frames,
    }

    def score_split(after: SplitState) -> float:
        """Weighted objective for one split state; lower is better."""
        class_d = l1_dist(normalize_counter(after.class_counts), targets["class_dist"], targets["class_keys"])
        tod_d = l1_dist(normalize_counter(after.tod_counts), targets["tod_dist"], targets["tod_keys"])
        dens_d = l1_dist(normalize_counter(after.density_counts), targets["density_dist"], targets["density_keys"])
        area_d = l1_dist(normalize_counter(after.area_counts), targets["area_dist"], targets["area_keys"])
        ar_d = l1_dist(normalize_counter(after.ar_counts), targets["ar_dist"], targets["ar_keys"])

        # Size penalty: keep n_frames close to this split's target.
        tf = target_frames[after.name]
        size_d = abs(after.n_frames - tf) / max(tf, 1.0)

        return (
            weights["class"] * class_d
            + weights["tod"] * tod_d
            + weights["density"] * dens_d
            + weights["area"] * area_d
            + weights["ar"] * ar_d
            + weights["size"] * size_d
        )

    # Greedy: for each group, tentatively add it to each split and keep the
    # assignment minimizing the summed score across all splits (this keeps
    # every split moving toward its targets, not just the receiving one).
    for g in group_list:
        best = None
        best_score = float("inf")

        for s in splits:
            # Cheap clone of the score-relevant state (Counter copies).
            tmp = SplitState(s.name, s.target_frac)
            tmp.n_frames = s.n_frames
            tmp.class_counts = s.class_counts.copy()
            tmp.tod_counts = s.tod_counts.copy()
            tmp.density_counts = s.density_counts.copy()
            tmp.area_counts = s.area_counts.copy()
            # BUG FIX: ar_counts was previously not copied, so the
            # aspect-ratio term scored only the candidate group's boxes
            # instead of the accumulated split + group.
            tmp.ar_counts = s.ar_counts.copy()

            tmp.add_group(g)

            total = 0.0
            for other in splits:
                if other.name == s.name:
                    total += score_split(tmp)
                else:
                    total += score_split(other)

            if total < best_score:
                best_score = total
                best = s

        assert best is not None
        best.add_group(g)

    report = {
        "total_frames": total_frames,
        "target_frames": target_frames,
        "actual_frames": {s.name: s.n_frames for s in splits},
    }
    return train, val, test, {**targets, **report}


# -----------------------------
# Reporting + writing
# -----------------------------


def summarize_split(s: SplitState) -> Dict[str, Any]:
    """JSON-serializable summary of one split's size and distributions."""
    return {
        "n_frames": s.n_frames,
        "n_boxes": int(sum(s.class_counts.values())),
        "n_groups": len(s.group_ids),
        "class_counts": dict(s.class_counts),
        "tod_counts": dict(s.tod_counts),
        "density_counts": dict(s.density_counts),
        "area_counts": dict(s.area_counts),
        "ar_counts": dict(s.ar_counts),
    }


def write_manifest(out_path: Path, rel_paths: List[str]) -> None:
    """Write sorted rel_paths, one per line, with a trailing newline when non-empty."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    ordered = sorted(rel_paths)
    out_path.write_text("\n".join(ordered) + ("\n" if ordered else ""))
def safe_float(x: str, default: float = 0.0) -> float:
    """Convert *x* to float, returning *default* when conversion fails.

    Catches only the errors float() actually raises — ValueError for
    unparseable strings, TypeError for non-numeric objects (e.g. None) —
    instead of a blanket Exception that would also hide unrelated bugs.
    """
    try:
        return float(x)
    except (TypeError, ValueError):
        return default