diff --git a/pyproject.toml b/pyproject.toml index 4749b92..c14665a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,14 @@ dynamic = ["version", "description"] dev = ["black", "bumpver", "isort", "pip-tools", "pytest", "ruff"] mediapipe = ["mediapipe==0.10.14"] yolo = ["ultralytics~=8.0.202"] +dlc = ["deeplabcut==3.0.0rc9", "pillow>=11.1.0", "pyyaml>=6.0.2"] +all = ["ultralytics~=8.0.202", "mediapipe==0.10.14"] + +[project.optional-dependencies] +dev = ["black", "bumpver", "isort", "pip-tools", "pytest", "ruff"] +mediapipe = ["mediapipe==0.10.14"] +yolo = ["ultralytics~=8.0.202"] +dlc = ["deeplabcut==3.0.0rc9", "pillow>=11.1.0", "pyyaml>=6.0.2"] all = ["ultralytics~=8.0.202", "mediapipe==0.10.14"] [project.urls] diff --git a/skellytracker/__init__.py b/skellytracker/__init__.py index ee55780..bba7d46 100644 --- a/skellytracker/__init__.py +++ b/skellytracker/__init__.py @@ -21,9 +21,6 @@ configure_logging(LOG_LEVEL) -# - -# # try: # from skellytracker.trackers.mediapipe_tracker.mediapipe_holistic_tracker import ( # MediapipeHolisticTracker, diff --git a/skellytracker/trackers/base_tracker/base_tracker_abcs.py b/skellytracker/trackers/base_tracker/base_tracker_abcs.py index 0f2f284..ec0498e 100644 --- a/skellytracker/trackers/base_tracker/base_tracker_abcs.py +++ b/skellytracker/trackers/base_tracker/base_tracker_abcs.py @@ -2,7 +2,7 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import List +from typing import List, Sequence import cv2 import numpy as np @@ -22,30 +22,31 @@ class BaseObservation(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) - frame_number: int # the frame number of the image in which this observation was made - tracker_type: TrackerTypeString = Field(description="Name of the tracker that made this observation.") - + frame_number: ( + int # the frame number of the image in which this observation was made + ) + tracker_type: TrackerTypeString = Field( + description="Name of the tracker that made this observation." + ) @classmethod @abstractmethod def from_detection_results(cls, *args, **kwargs): pass - @abstractmethod - def to_tracked_points(cls, *args, **kwargs) -> dict[TrackedPointIdString, TrackedPoint2d]: + def to_tracked_points( + cls, *args, **kwargs + ) -> dict[TrackedPointIdString, TrackedPoint2d]: pass - @abstractmethod def to_array(self) -> np.ndarray: pass - def to_json_string(self) -> str: return json.dumps(self.model_dump_json(), indent=4) - def to_json_bytes(self) -> bytes: return self.to_json_string().encode("utf-8") @@ -59,7 +60,7 @@ class BaseImageAnnotatorConfig(BaseModel, ABC): class BaseImageAnnotator(BaseModel, ABC): config: BaseImageAnnotatorConfig - observations: BaseObservations # make it a list to allow plotting trails, etc. 
+ observations: BaseObservations @classmethod @abstractmethod @@ -67,19 +68,33 @@ def create(cls, config: BaseImageAnnotatorConfig): pass @abstractmethod - def annotate_image(self, image: np.ndarray, latest_observation: BaseObservation) -> np.ndarray: + def annotate_image( + self, image: np.ndarray, latest_observation: BaseObservation + ) -> np.ndarray: pass @staticmethod - def draw_doubled_text(image: np.ndarray, - text: str, - x: int, - y: int, - font_scale: float, - color: tuple[int, int, int], - thickness: int): - cv2.putText(image, text, (x, y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness * 3) - cv2.putText(image, text, (x, y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness) + def draw_doubled_text( + image: np.ndarray, + text: str, + x: int, + y: int, + font_scale: float, + color: tuple[int, int, int], + thickness: int, + ): + cv2.putText( + image, + text, + (x, y), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + (0, 0, 0), + thickness * 3, + ) + cv2.putText( + image, text, (x, y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness + ) class BaseDetectorConfig(BaseModel, ABC): @@ -100,21 +115,25 @@ class BaseDetector(BaseModel, ABC): @classmethod def create(cls, config: BaseDetectorConfig): - raise NotImplementedError("Must implement a method to create a detector from a config.") + raise NotImplementedError( + "Must implement a method to create a detector from a config." + ) @abstractmethod - def detect(self, - frame_number: int, - image: np.ndarray) -> BaseObservation: + def detect(self, frame_number: int, image: np.ndarray) -> BaseObservation: pass class BaseRecorder(BaseModel, ABC): + # TODO: could be called ObservationGroup observations: List[BaseObservation] = Field(default_factory=list) def add_observation(self, observation: BaseObservation): self.observations.append(observation) + def add_observations(self, observations: Sequence[BaseObservation]): + self.observations.extend(observations) + # I'm imagining these can be used if you want the data but want to handle saving elsewhere @property def to_array(self) -> np.ndarray: @@ -122,8 +141,10 @@ def to_array(self) -> np.ndarray: @property def to_json_string(self) -> str: - output_dict = {frame_number: observation.model_dump_json() for frame_number, observation in - enumerate(self.observations)} + output_dict = { + frame_number: observation.model_dump_json() + for frame_number, observation in enumerate(self.observations) + } return json.dumps(output_dict, indent=4) # and these are used if you want skellytracker to handle the saving @@ -131,7 +152,7 @@ def save_array(self, output_path: Path): np.save(file=output_path, arr=self.to_array) def save_json_file(self, output_path: Path): - with open(output_path, 'w') as json_file: + with open(output_path, "w") as json_file: json_file.write(self.to_json_string) def clear(self): @@ -139,7 +160,7 @@ def clear(self): class BaseObservationManager(BaseModel, ABC): - observations: List[BaseObservation] + observations: Sequence[BaseObservation] @abstractmethod def create_observation(self, **kwargs) -> BaseObservation: @@ -150,92 +171,83 @@ class BaseTracker(BaseModel, ABC): config: BaseTrackerConfig detector: BaseDetector annotator: BaseImageAnnotator - recorder: BaseRecorder | None = None + recorder: BaseRecorder | None @classmethod def create(cls, config: BaseTrackerConfig): - raise NotImplementedError("Must implement a method to create a tracker from a config.") + raise NotImplementedError( + "Must implement a method to create a tracker from a config." 
+ ) - def process_image(self, - frame_number: int, - image: np.ndarray, - record_observation: bool = True) -> BaseObservation: - latest_observation = self.detector.detect(image=image, frame_number=frame_number) + def process_image( + self, frame_number: int, image: np.ndarray, record_observation: bool = True + ) -> BaseObservation: + latest_observation = self.detector.detect( + image=image, frame_number=frame_number + ) if record_observation and self.recorder is not None: self.recorder.add_observation(observation=latest_observation) return latest_observation - def annotate_image(self, image: np.ndarray, latest_observation: BaseObservation) -> np.ndarray: - return self.annotator.annotate_image(image=image, latest_observation=latest_observation) + def annotate_image( + self, image: np.ndarray, latest_observation: BaseObservation + ) -> np.ndarray: + return self.annotator.annotate_image( + image=image, latest_observation=latest_observation + ) def demo(self) -> None: camera_viewer = WebcamDemoViewer( - tracker=self, - window_title=self.__class__.__name__ + tracker=self, window_title=self.__class__.__name__ ) camera_viewer.run() def image_demo(self, image_path: Path) -> None: """ Run tracker on single image - + :return: None """ image_viewer = ImageDemoViewer(self, self.__class__.__name__) image_viewer.run(image_path=image_path) -# -# class BaseCumulativeTracker(BaseTracker): -# """ -# A base class for tracking algorithms that run cumulatively, i.e are not able to process videos frame by frame. -# Throws a descriptive error for the abstract methods of BaseTracker that do not apply to this type of tracker. -# Trackers inheriting from this will need to overwrite the `process_video` method. -# """ -# -# def __init__( -# self, -# tracked_object_names: List[str], -# recorder: BaseCumulativeRecorder, -# **data: Any, -# ): -# super().__init__( -# tracked_object_names=tracked_object_names, recorder=recorder, **data -# ) -# -# def process_image(self, **kwargs) -> None: -# raise NotImplementedError( -# "This tracker does not support processing individual images, please use process_video instead." -# ) -# -# def annotate_image(self, **kwargs) -> None: -# raise NotImplementedError( -# "This tracker does not support processing individual images, please use process_video instead." -# ) -# -# @abstractmethod -# def process_video( -# self, -# input_video_filepath: Union[str, Path], -# output_video_filepath: Optional[Union[str, Path]] = None, -# save_data_bool: bool = False, -# use_tqdm: bool = True, -# **kwargs, -# ) -> Union[np.ndarray, None]: -# """ -# Run the tracker on a video. -# -# :param input_video_filepath: Path to video file. -# :param output_video_filepath: Path to save annotated video to, does not save video if None. -# :param save_data_bool: Whether to save the data to a file. -# :param use_tqdm: Whether to use tqdm to show a progress bar -# :return: Array of tracked keypoint data -# """ -# pass -# -# def image_demo(self, image_path: Path) -> None: -# raise NotImplementedError( -# "This tracker does not support processing individual images, please use process_video instead." -# ) + +class CumulativeBaseTracker(BaseTracker): + def process_image( + self, frame_number: int, image: np.ndarray, record_observation: bool = True + ) -> BaseObservation: + raise NotImplementedError( + "This tracker does not support processing individual images, please use process_video instead." 
+ ) + + def annotate_image( + self, image: np.ndarray, latest_observation: BaseObservation + ) -> np.ndarray: + raise NotImplementedError( + "This tracker does not support processing individual images, please use process_video instead." + ) + + def demo(self) -> None: + raise NotImplementedError( + "This tracker does not support processing individual images, please use process_video instead." + ) + + def image_demo(self, image_path: Path) -> None: + raise NotImplementedError( + "This tracker does not support processing individual images, please use process_video instead." + ) + + @abstractmethod + def process_video( + self, input_video_filepath: Path, **kwargs + ) -> Sequence[BaseObservation]: + pass + + @abstractmethod + def annotate_video( + self, input_video_filepath: Path, output_video_filepath: Path, **kwargs + ) -> None: + pass diff --git a/skellytracker/trackers/dlc_tracker/__dlc_tracker.py b/skellytracker/trackers/dlc_tracker/__dlc_tracker.py new file mode 100644 index 0000000..5bcf9a1 --- /dev/null +++ b/skellytracker/trackers/dlc_tracker/__dlc_tracker.py @@ -0,0 +1,133 @@ +from pathlib import Path +import cv2 +from deeplabcut.utils import auxiliaryfunctions +import numpy as np +import pandas as pd +from pydantic import BaseModel +from skellytracker.trackers.base_tracker.base_tracker_abcs import BaseRecorder, CumulativeBaseTracker +from skellytracker.trackers.dlc_tracker.dlc_annotator import DeepLabCutAnnotatorConfig, DeepLabCutImageAnnotator +from skellytracker.trackers.dlc_tracker.dlc_detector import DeepLabCutDetector, DeepLabCutDetectorConfig +from skellytracker.trackers.dlc_tracker.dlc_observation import DeepLabCutObservation + +class DeepLabCutTrackerConfig(BaseModel): + tracker_name: str + iteration: int + detector_config: DeepLabCutDetectorConfig + annotator_config: DeepLabCutAnnotatorConfig + + @classmethod + def from_config_yaml(cls, config_path: str): + config = auxiliaryfunctions.read_config(config_path) + return cls( + tracker_name=config.get("Task", "DeepLabCutTracker"), + iteration=config.get("iteration", 0), + detector_config=DeepLabCutDetectorConfig(dlc_config=config_path), + annotator_config=DeepLabCutAnnotatorConfig(), + ) + +class DeepLabCutRecorder(BaseRecorder): + def load_deeplabcut_csv(self, csv_path: Path, image_size: tuple[int, int] = (1280, 720)) -> list[DeepLabCutObservation]: + df = pd.read_csv(csv_path, header=[1,2]) + + df = df.iloc[:, 1:] + + if df.shape[1] % 3 != 0: + raise ValueError(f"csv file {csv_path} has {df.shape[1]} columns, which is not divisible by 3") + + try: + points = df.values.reshape(df.shape[0], df.shape[1] // 3, 3) + except ValueError as e: + raise ValueError(f"Reshape failed for csv file {csv_path} with shape {df.shape}: {e}") + + observations = [ + DeepLabCutObservation(frame_number=i, + pose_points=points[i, :, :2], + confidence_values=points[i, :, 2], + image_size=image_size + ) for i in range(len(points)) + ] + + self.clear() + self.add_observations(observations=observations) + + return observations + +class DeepLabCutTracker(CumulativeBaseTracker): + config: DeepLabCutTrackerConfig + detector: DeepLabCutDetector + recorder: DeepLabCutRecorder + annotator: DeepLabCutImageAnnotator | None = None + + @classmethod + def create(cls, config: DeepLabCutTrackerConfig): + detector = DeepLabCutDetector.create(config.detector_config) + + return cls( + config=config, + detector=detector, + annotator=DeepLabCutImageAnnotator.create(config.annotator_config), + recorder=DeepLabCutRecorder(), + ) + + def process_video(self, 
input_video_filepath: Path, **kwargs) -> list[DeepLabCutObservation]: + observations = self.detector.detect_video(input_video_filepath, **kwargs) + + self.recorder.add_observations(observations=observations) + + return observations + + def annotate_image( + self, image: np.ndarray, latest_observation: DeepLabCutObservation + ) -> np.ndarray: + if self.annotator is None: + raise ValueError("No annotator configured") + return self.annotator.annotate_image( + image=image, latest_observation=latest_observation + ) + + def annotate_video(self, input_video_filepath: Path, output_video_filepath: Path, **kwargs) -> None: + cap = cv2.VideoCapture(str(input_video_filepath)) + + num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if len(self.recorder.observations) < num_frames: + cap.release() + raise ValueError(f"Not enough observations to annotate video (video has {num_frames} frames, but only {len(self.recorder.observations)} observations)") + + writer = cv2.VideoWriter( + str(output_video_filepath), + cv2.VideoWriter.fourcc(*"AVC1"), + cap.get(cv2.CAP_PROP_FPS), + (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))), + ) + + i = 0 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + annotated_frame = self.annotate_image(frame, self.recorder.observations[i]) + + writer.write(annotated_frame) + + i += 1 + + cap.release() + writer.release() + +if __name__ == "__main__": + tracker = DeepLabCutTracker.create( + DeepLabCutTrackerConfig( + detector_config=DeepLabCutDetectorConfig(dlc_config="/Users/philipqueen/clicker_testing/clicker_testing/config.yaml"), + annotator_config=DeepLabCutAnnotatorConfig(), + ) + ) + input_video = "/Users/philipqueen/freemocap_data/recording_sessions/freemocap_test_data/synchronized_videos/sesh_2022-09-19_16_16_50_in_class_jsm_synced_Cam2.mp4" + + tracker.process_video( + input_video_filepath=Path(input_video), + ) + tracker.annotate_video( + input_video_filepath=Path(input_video), + output_video_filepath=Path("/Users/philipqueen/clicker_testing/clicker_testing/test_annotated.mp4"), + ) \ No newline at end of file diff --git a/skellytracker/trackers/dlc_tracker/__init__.py b/skellytracker/trackers/dlc_tracker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/skellytracker/trackers/dlc_tracker/dlc_annotator.py b/skellytracker/trackers/dlc_tracker/dlc_annotator.py new file mode 100644 index 0000000..e2f6fae --- /dev/null +++ b/skellytracker/trackers/dlc_tracker/dlc_annotator.py @@ -0,0 +1,53 @@ +import cv2 +import numpy as np +from numpydantic import NDArray, Shape + + +from skellytracker.trackers.base_tracker.base_tracker_abcs import BaseImageAnnotator, BaseImageAnnotatorConfig +from skellytracker.trackers.dlc_tracker.dlc_observation import DeepLabCutObservation + + +class DeepLabCutAnnotatorConfig(BaseImageAnnotatorConfig): + show_tracks: int | None = 15 + show_overlay: bool = True + marker_type: int = cv2.MARKER_DIAMOND + marker_size: int = 10 + marker_thickness: int = 2 + marker_color: tuple[int, int, int] = (0, 0, 255) + + text_color: tuple[int, int, int] = (215, 115, 40) + text_size: float = .5 + text_thickness: int = 2 + text_font: int = cv2.FONT_HERSHEY_SIMPLEX + + +class DeepLabCutImageAnnotator(BaseImageAnnotator): + config: DeepLabCutAnnotatorConfig + observations: list[DeepLabCutObservation] + + @classmethod + def create(cls, config: DeepLabCutAnnotatorConfig): + return cls(config=config, observations=[]) + + def annotate_image( + self, + image: NDArray[Shape["* width, * height, 1-4 
channels"], np.uint8], + latest_observation: DeepLabCutObservation | None = None, + ) -> np.ndarray: + if latest_observation is None: + return image.copy() + # Copy the original image for annotation + annotated_image = image.copy() + + for marker in range(latest_observation.pose_points.shape[0]): + point = latest_observation.pose_points[marker, :2] + cv2.drawMarker( + img=annotated_image, + position=(int(point[0]), int(point[1])), + color=self.config.marker_color, + markerType=self.config.marker_type, + markerSize=self.config.marker_size, + thickness=self.config.marker_thickness, + ) + + return annotated_image diff --git a/skellytracker/trackers/dlc_tracker/dlc_detector.py b/skellytracker/trackers/dlc_tracker/dlc_detector.py new file mode 100644 index 0000000..5f4f79d --- /dev/null +++ b/skellytracker/trackers/dlc_tracker/dlc_detector.py @@ -0,0 +1,248 @@ +from __future__ import annotations +from pathlib import Path +import cv2 +from pydantic import ConfigDict +import torch.multiprocessing as mp +import albumentations as A +import numpy as np +import time + +from deeplabcut.compat import _update_device +from deeplabcut.pose_estimation_pytorch.apis.videos import ( + VideoIterator, + _generate_metadata, + video_inference, +) +import deeplabcut.pose_estimation_pytorch.apis.utils as utils +from deeplabcut.pose_estimation_pytorch.apis.videos import ( + _validate_destfolder, +) +import deeplabcut.pose_estimation_pytorch.runners.shelving as shelving +from deeplabcut.core.engine import Engine +from deeplabcut.pose_estimation_pytorch.runners import DynamicCropper +from deeplabcut.pose_estimation_pytorch.task import Task +from deeplabcut.utils import auxiliaryfunctions + +from skellytracker.trackers.base_tracker.base_tracker_abcs import BaseDetectorConfig, BaseDetector +from skellytracker.trackers.dlc_tracker.dlc_observation import DeepLabCutObservation + +class DeepLabCutDetectorConfig(BaseDetectorConfig): + model_config = ConfigDict(arbitrary_types_allowed=True) + dlc_config: str + videotype: str = "" + shuffle: int = 1 + trainingsetindex: int = 0 + gputouse: int | None = None + save_as_csv: bool = False + destfolder: str | None = None + cropping: list[int] | None = None + dynamic: tuple[bool, float, int] = (False, 0.5, 10) + modelprefix: str = "" + robust_nframes: bool = False + use_shelve: bool = False + auto_track: bool = True + n_tracks: int | None = None + animal_names: list[str] | None = None + identity_only: bool = False + snapshot_index: int | str | None = None + detector_snapshot_index: int | str | None = None + device: str | None = None + batch_size: int | None = None + detector_batch_size: int | None = None + transform: A.Compose | None = None + overwrite: bool = False + save_as_df: bool = False + + @classmethod + def from_config(cls, config_path: str | Path): + config = auxiliaryfunctions.read_config(config_path) + return cls(**config) # TODO: test this + + +class DeepLabCutDetector(BaseDetector): + config: DeepLabCutDetectorConfig + + @classmethod + def create(cls, config: DeepLabCutDetectorConfig): + return cls( + config=config, + ) + + def detect(self, frame_number: int, image: np.ndarray) -> DeepLabCutObservation: + raise NotImplementedError( + "This detector does not support processing individual images, please use detect_video instead." 
+ ) + + # TODO: get point names from dlc config + + def detect_video( + self, + video_path: str | Path, + **torch_kwargs, + ) -> list[DeepLabCutObservation]: + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + _update_device(self.config.gputouse, torch_kwargs) + + video = Path(video_path) + + # Create the output folder + _validate_destfolder(self.config.destfolder) + + # Load the project configuration + cfg = auxiliaryfunctions.read_config(self.config.dlc_config) + project_path = Path(cfg["project_path"]) + train_fraction = cfg["TrainingFraction"][self.config.trainingsetindex] + model_folder = project_path / auxiliaryfunctions.get_model_folder( + train_fraction, + self.config.shuffle, + cfg, + modelprefix=self.config.modelprefix, + engine=Engine.PYTORCH, + ) + train_folder = model_folder / "train" + + # Read the inference configuration, load the model + model_cfg_path = train_folder / Engine.PYTORCH.pose_cfg_name + model_cfg = auxiliaryfunctions.read_plainconfig(model_cfg_path) + pose_task = Task(model_cfg["method"]) + + pose_cfg_path = model_folder / "test" / "pose_cfg.yaml" + pose_cfg = auxiliaryfunctions.read_plainconfig(pose_cfg_path) + + snapshot_index, detector_snapshot_index = utils.parse_snapshot_index_for_analysis( + cfg, + model_cfg, + self.config.snapshot_index, + self.config.detector_snapshot_index, + ) + + if self.config.cropping is None and cfg.get("cropping", False): + self.config.cropping = [cfg["x1"], cfg["x2"], cfg["y1"], cfg["y2"]] + + # Get general project parameters + multi_animal = cfg["multianimalproject"] + bodyparts = model_cfg["metadata"]["bodyparts"] + unique_bodyparts = model_cfg["metadata"]["unique_bodyparts"] + individuals = model_cfg["metadata"]["individuals"] + max_num_animals = len(individuals) + + if self.config.device is not None: + model_cfg["device"] = self.config.device + + if self.config.batch_size is None: + batch_size = cfg.get("batch_size", 1) + + if not multi_animal: + save_as_df = True + if self.config.use_shelve: + print( + "The ``use_shelve`` parameter cannot be used for single animal " + "projects. Setting ``use_shelve=False``." + ) + self.config.use_shelve = False + + dynamic = DynamicCropper.build(*self.config.dynamic) + if pose_task != Task.BOTTOM_UP and dynamic is not None: + print( + "Turning off dynamic cropping. It should only be used for bottom-up " + f"pose estimation models, but you are using a top-down model." + ) + dynamic = None + + snapshot = utils.get_model_snapshots(snapshot_index, train_folder, pose_task)[0] + print(f"Analyzing videos with {snapshot.path}") + pose_runner = utils.get_pose_inference_runner( + model_config=model_cfg, + snapshot_path=snapshot.path, + max_individuals=max_num_animals, + batch_size=batch_size, + transform=self.config.transform, + dynamic=dynamic, + ) + detector_runner = None + + detector_path, detector_snapshot = None, None + if pose_task == Task.TOP_DOWN: + if detector_snapshot_index is None: + raise ValueError( + "Cannot run videos analysis for top-down models without a detector " + "snapshot! Please specify your desired detector_snapshotindex in your " + "project's configuration file." 
+ ) + + if self.config.detector_batch_size is None: + detector_batch_size = cfg.get("detector_batch_size", 1) + + detector_snapshot = utils.get_model_snapshots( + detector_snapshot_index, train_folder, Task.DETECT + )[0] + print(f" -> Using detector {detector_snapshot.path}") + detector_runner = utils.get_detector_inference_runner( + model_config=model_cfg, + snapshot_path=detector_snapshot.path, + max_individuals=max_num_animals, + batch_size=detector_batch_size, + ) + + dlc_scorer = utils.get_scorer_name( + cfg, + self.config.shuffle, + train_fraction, + snapshot_uid=utils.get_scorer_uid(snapshot, detector_snapshot), + modelprefix=self.config.modelprefix, + ) + if self.config.destfolder is None: + output_path = video.parent + else: + output_path = Path(self.config.destfolder) + + output_prefix = video.stem + dlc_scorer + output_pkl = output_path / f"{output_prefix}_full.pickle" + + video_iterator = VideoIterator(video, cropping=self.config.cropping) + + image_size = video_iterator.video.get(cv2.CAP_PROP_FRAME_WIDTH), video_iterator.video.get(cv2.CAP_PROP_FRAME_HEIGHT) + + shelf_writer = None + if self.config.use_shelve: + shelf_writer = shelving.ShelfWriter( + pose_cfg=pose_cfg, + filepath=output_pkl, + num_frames=video_iterator.get_n_frames(robust=self.config.robust_nframes), + ) + + runtime = [time.time()] + predictions = video_inference( + video=video_iterator, + pose_runner=pose_runner, + detector_runner=detector_runner, + shelf_writer=shelf_writer, + robust_nframes=self.config.robust_nframes, + ) + + runtime.append(time.time()) + metadata = _generate_metadata( + cfg=cfg, + pytorch_config=model_cfg, + dlc_scorer=dlc_scorer, + train_fraction=train_fraction, + batch_size=batch_size, + cropping=self.config.cropping, + runtime=(runtime[0], runtime[1]), + video=video_iterator, + robust_nframes=self.config.robust_nframes, + ) + + # TODO: how to save out metadata? 
+ return [ + DeepLabCutObservation.from_detection_results(frame_number=i, pose_prediction=prediction, image_size=image_size) + for i, prediction in enumerate(predictions) + ] + + +if __name__=="__main__": + config = DeepLabCutDetectorConfig(dlc_config="") \ No newline at end of file diff --git a/skellytracker/trackers/dlc_tracker/dlc_observation.py b/skellytracker/trackers/dlc_tracker/dlc_observation.py new file mode 100644 index 0000000..11f068d --- /dev/null +++ b/skellytracker/trackers/dlc_tracker/dlc_observation.py @@ -0,0 +1,41 @@ +from typing import NamedTuple + +import numpy as np +from numpydantic import NDArray, Shape + +from skellytracker.trackers.base_tracker.base_tracker_abcs import BaseObservation, TrackerTypeString, TrackedPoint2d + +class DeepLabCutObservation(BaseObservation): + tracker_type:TrackerTypeString = 'dlc_tracker' + frame_number: int # the frame number of the image in which this observation was made + pose_points: np.ndarray # Num Markers x 2 + confidence_values: np.ndarray # num markers + image_size: tuple[int, int] + + @classmethod + def from_detection_results(cls, + frame_number: int, + pose_prediction: dict[str, NDArray[Shape["1,N,3"], float]], + image_size: tuple[int, int]): + # TODO: this will not work for multi animal dlc models + prediction_values = pose_prediction["bodyparts"].squeeze() + return cls( + frame_number=frame_number, + pose_points=prediction_values[:, :2], + confidence_values=prediction_values[:, 2], + image_size=image_size + ) + + def to_tracked_points(self) -> dict[str, TrackedPoint2d]: + tracked_points_dict = {} + for i in range(self.pose_points.shape[0]): + tracked_points_dict[f"Point-{i}"] = self.pose_points[i, :2] + tracked_points_dict[f"Point-{i}-Confidence"] = self.confidence_values[i] + + return tracked_points_dict + + def to_array(self) -> NDArray[Shape["N, 2"], float]: + return self.pose_points + + +DLCObservations = list[DeepLabCutObservation]
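
For reviewers trying the new tracker locally, here is a minimal usage sketch of the API this diff adds. It is a sketch, not part of the change: the paths are placeholders, it assumes the new `dlc` extra is installed, and the import path simply mirrors the new file location (skellytracker/trackers/dlc_tracker/__dlc_tracker.py). It uses DeepLabCutTrackerConfig.from_config_yaml, which fills the required tracker_name and iteration fields from the DeepLabCut project config.

from pathlib import Path

from skellytracker.trackers.dlc_tracker.__dlc_tracker import (
    DeepLabCutTracker,
    DeepLabCutTrackerConfig,
)

# Placeholder paths - point these at a trained DLC project and a video to analyze
dlc_project_config = "/path/to/dlc_project/config.yaml"
input_video = Path("/path/to/video.mp4")
annotated_video = Path("/path/to/video_annotated.mp4")

# Build the tracker config from the DLC project config; this populates
# tracker_name and iteration, which DeepLabCutTrackerConfig requires
config = DeepLabCutTrackerConfig.from_config_yaml(dlc_project_config)
tracker = DeepLabCutTracker.create(config)

# Cumulative trackers process whole videos rather than single frames;
# process_video runs DLC inference and records one observation per frame
observations = tracker.process_video(input_video_filepath=input_video)
print(f"Tracked {len(observations)} frames")

# Re-reads the input video and writes a copy with markers drawn from the
# recorded observations
tracker.annotate_video(
    input_video_filepath=input_video,
    output_video_filepath=annotated_video,
)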
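
The new DeepLabCutRecorder can also be populated from a previously exported DeepLabCut prediction CSV instead of re-running inference. A small sketch, assuming a standard single-animal CSV (scorer/bodyparts/coords header rows) and a placeholder path:

from pathlib import Path

from skellytracker.trackers.dlc_tracker.__dlc_tracker import DeepLabCutRecorder

# Placeholder path to a single-animal DeepLabCut prediction CSV
csv_path = Path("/path/to/videoDLC_predictions.csv")

recorder = DeepLabCutRecorder()
# image_size is stored on each observation; pass the source video's (width, height)
observations = recorder.load_deeplabcut_csv(csv_path=csv_path, image_size=(1280, 720))

# Each observation holds (num_markers x 2) pose points plus per-marker confidences
first = observations[0]
print(first.pose_points.shape, first.confidence_values.shape)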