diff --git a/modules/detect_target/detect_target_factory.py b/modules/detect_target/detect_target_factory.py index eb0f9817..ca37a14b 100644 --- a/modules/detect_target/detect_target_factory.py +++ b/modules/detect_target/detect_target_factory.py @@ -6,6 +6,7 @@ from . import base_detect_target from . import detect_target_ultralytics +from ..common.modules.logger import logger class DetectTargetOption(enum.Enum): @@ -21,6 +22,7 @@ def create_detect_target( device: "str | int", model_path: str, override_full: bool, + local_logger: logger.Logger, show_annotations: bool, save_name: str, ) -> tuple[bool, base_detect_target.BaseDetectTarget | None]: @@ -33,6 +35,7 @@ def create_detect_target( device, model_path, override_full, + local_logger, show_annotations, save_name, ) diff --git a/modules/detect_target/detect_target_ultralytics.py b/modules/detect_target/detect_target_ultralytics.py index 975edd75..4f9bddc1 100644 --- a/modules/detect_target/detect_target_ultralytics.py +++ b/modules/detect_target/detect_target_ultralytics.py @@ -10,6 +10,7 @@ from . import base_detect_target from .. import image_and_time from .. import detections_and_time +from ..common.modules.logger import logger class DetectTargetUltralytics(base_detect_target.BaseDetectTarget): @@ -22,6 +23,7 @@ def __init__( device: "str | int", model_path: str, override_full: bool, + local_logger: logger.Logger, show_annotations: bool = False, save_name: str = "", ) -> None: @@ -36,6 +38,7 @@ def __init__( self.__model = ultralytics.YOLO(model_path) self.__counter = 0 self.__enable_half_precision = not self.__device == "cpu" + self.__local_logger = local_logger self.__show_annotations = show_annotations if override_full: self.__enable_half_precision = False @@ -54,6 +57,8 @@ def run( Return: Success and the detections. """ image = data.image + start_time = time.time() + predictions = self.__model.predict( source=image, half=self.__enable_half_precision, @@ -91,15 +96,15 @@ def run( detections.append(detection) + end_time = time.time() + self.__local_logger.info( + f"{time.time()}: Count: {self.__counter}. Target detection took {end_time - start_time} seconds. Objects detected: {detections}." + ) + # Logging if self.__filename_prefix != "": filename = self.__filename_prefix + str(self.__counter) - # Object detections - with open(filename + ".txt", "w", encoding="utf-8") as file: - # Use internal string representation - file.write(repr(detections)) - # Annotated image cv2.imwrite(filename + ".png", image_annotated) # type: ignore diff --git a/modules/detect_target/detect_target_worker.py b/modules/detect_target/detect_target_worker.py index 7d03b9fa..8ce5cbc4 100644 --- a/modules/detect_target/detect_target_worker.py +++ b/modules/detect_target/detect_target_worker.py @@ -47,6 +47,7 @@ def detect_target_worker( device, model_path, override_full, + local_logger, show_annotations, save_name, ) diff --git a/modules/detection_in_world.py b/modules/detection_in_world.py index a5c22943..cc96e855 100644 --- a/modules/detection_in_world.py +++ b/modules/detection_in_world.py @@ -58,3 +58,9 @@ def __str__(self) -> str: To string. """ return f"{self.__class__}, vertices: {self.vertices.tolist()}, centre: {self.centre}, label: {self.label}, confidence: {self.confidence}" + + def __repr__(self) -> str: + """ + For collections (e.g. list). + """ + return str(self) diff --git a/modules/detections_and_time.py b/modules/detections_and_time.py index fcedcd9e..75df2e02 100644 --- a/modules/detections_and_time.py +++ b/modules/detections_and_time.py @@ -57,6 +57,12 @@ def __str__(self) -> str: """ return f"cls: {self.label}, conf: {self.confidence}, bounds: {self.x_1} {self.y_1} {self.x_2} {self.y_2}" + def __repr__(self) -> str: + """ + For collections (e.g. list). + """ + return str(self) + def get_centre(self) -> "tuple[float, float]": """ Gets the xy centre of the bounding box. diff --git a/tests/unit/test_detect_target_ultralytics.py b/tests/unit/test_detect_target_ultralytics.py index dab5a771..b7159086 100644 --- a/tests/unit/test_detect_target_ultralytics.py +++ b/tests/unit/test_detect_target_ultralytics.py @@ -13,6 +13,7 @@ from modules.detect_target import detect_target_ultralytics from modules import image_and_time from modules import detections_and_time +from modules.common.modules.logger import logger TEST_PATH = pathlib.Path("tests", "model_example") @@ -108,8 +109,13 @@ def detector() -> detect_target_ultralytics.DetectTargetUltralytics: # type: ig """ Construct DetectTargetUltralytics. """ + result, test_logger = logger.Logger.create("test_logger", False) + + assert result + assert test_logger is not None + detection = detect_target_ultralytics.DetectTargetUltralytics( - DEVICE, str(MODEL_PATH), OVERRIDE_FULL + DEVICE, str(MODEL_PATH), OVERRIDE_FULL, test_logger ) yield detection # type: ignore