diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..6fea3a1b0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+.idea/
+.vscode/
+*.py[cod]
+
+configs/custom
+models
diff --git a/slowfast/config/defaults.py b/slowfast/config/defaults.py
index 718801a92..6d4efc191 100644
--- a/slowfast/config/defaults.py
+++ b/slowfast/config/defaults.py
@@ -705,6 +705,9 @@
 # The number of overlapping frames cannot be larger than
 # half of the sequence length `cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE`
 _C.DEMO.BUFFER_SIZE = 0
+# Draw the output predictions onto the input video.
+# If disabled, predictions are logged to a file instead of being drawn on the video.
+_C.DEMO.OUTPUT_DISPLAY = True
 # If specified, the visualized outputs will be written this a video file of
 # this path. Otherwise, the visualized outputs will be displayed in a window.
 _C.DEMO.OUTPUT_FILE = ""
diff --git a/slowfast/utils/logging.py b/slowfast/utils/logging.py
index b26a53c02..8a27a2b29 100644
--- a/slowfast/utils/logging.py
+++ b/slowfast/utils/logging.py
@@ -64,11 +64,16 @@ def setup_logging(output_dir=None):
         logger.addHandler(ch)
 
     if output_dir is not None and du.is_master_proc(du.get_world_size()):
-        filename = os.path.join(output_dir, "stdout.log")
-        fh = logging.StreamHandler(_cached_log_stream(filename))
-        fh.setLevel(logging.DEBUG)
-        fh.setFormatter(plain_formatter)
-        logger.addHandler(fh)
+        setup_file_logger(logger, output_dir, "stdout.log", plain_formatter)
+
+
+def setup_file_logger(logger, output_dir, file_name, formatter=None):
+    filename = os.path.join(output_dir, file_name)
+    fh = logging.StreamHandler(_cached_log_stream(filename))
+    fh.setLevel(logging.DEBUG)
+    fh.setFormatter(formatter or logging.Formatter("%(message)s"))
+    logger.addHandler(fh)
+    logger.setLevel(logging.DEBUG)
 
 
 def get_logger(name):
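Note on the refactor above: `setup_file_logger` factors the file-handler wiring out of `setup_logging` so that additional log files can be attached to any logger. A minimal usage sketch (the logger name and paths below are made up for illustration, not part of the diff):

```python
# Attach a dedicated file handler via the new helper; get_logger and
# setup_file_logger are the functions introduced above. The logger name
# "my-side-channel" and the ./output directory are illustrative.
from slowfast.utils import logging

side_log = logging.get_logger("my-side-channel")
logging.setup_file_logger(side_log, "./output", "side_channel.log")
side_log.info("one plain line per record")  # no formatter passed
```

With `formatter=None` the handler falls back to a bare `%(message)s` formatter, which is what the prediction log added later in this diff relies on.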
""" num_workers = mp.cpu_count() if n_workers is None else n_workers @@ -168,10 +173,11 @@ def __init__(self, video_vis, n_workers=None): self.procs = [] self.result_data = {} self.put_id = -1 + predictor = prediction_processor or draw_predictions for _ in range(max(num_workers, 1)): self.procs.append( AsyncVis._VisWorker( - video_vis, self.task_queue, self.result_queue + video_vis, self.task_queue, self.result_queue, predictor ) ) @@ -309,10 +315,66 @@ def draw_predictions(task, video_vis): boxes, keyframe_idx=keyframe_idx, draw_range=draw_range, + task_id=task.id, ) else: frames = video_vis.draw_clip_range( - frames, preds, keyframe_idx=keyframe_idx, draw_range=draw_range + frames, + preds, + keyframe_idx=keyframe_idx, + draw_range=draw_range, + task_id=task.id, + ) + del task + + return buffer + frames + + +def log_predictions(task, video_vis): + """ + Log prediction for the given task. + Args: + task (TaskInfo object): task object that contain + the necessary information for logging. (e.g. frames, preds) + All attributes must lie on CPU devices. + video_vis (VideoVisualizer object): the video visualizer object. + """ + boxes = task.bboxes + frames = task.frames + preds = task.action_preds + if boxes is not None: + img_width = task.img_width + img_height = task.img_height + if boxes.device != torch.device("cpu"): + boxes = boxes.cpu() + boxes = cv2_transform.revert_scaled_boxes( + task.crop_size, boxes, img_height, img_width + ) + + keyframe_idx = len(frames) // 2 - task.num_buffer_frames + draw_range = [ + keyframe_idx - task.clip_vis_size, + keyframe_idx + task.clip_vis_size, + ] + buffer = frames[: task.num_buffer_frames] + frames = frames[task.num_buffer_frames :] + if boxes is not None: + if len(boxes) != 0: + frames = video_vis.draw_clip_range( + frames, + preds, + boxes, + keyframe_idx=keyframe_idx, + draw_range=draw_range, + task_id=task.id, + ) + else: + frames = video_vis.draw_clip_range( + frames, + preds, + keyframe_idx=keyframe_idx, + draw_range=draw_range, + task_id=task.id, ) del task diff --git a/slowfast/visualization/predictor.py b/slowfast/visualization/predictor.py index 3007aa58f..0e989e2df 100644 --- a/slowfast/visualization/predictor.py +++ b/slowfast/visualization/predictor.py @@ -33,6 +33,8 @@ def __init__(self, cfg, gpu_id=None): self.gpu_id = ( torch.cuda.current_device() if gpu_id is None else gpu_id ) + else: + self.gpu_id = None # Build the video model and print model statistics. self.model = build_model(cfg, gpu_id=gpu_id) diff --git a/slowfast/visualization/video_visualizer.py b/slowfast/visualization/video_visualizer.py index faa127294..6fd29f051 100644 --- a/slowfast/visualization/video_visualizer.py +++ b/slowfast/visualization/video_visualizer.py @@ -13,6 +13,7 @@ logger = logging.get_logger(__name__) log.getLogger("matplotlib").setLevel(log.ERROR) +pred_log = logging.get_logger("slowfast-predictions") def _create_text_labels(classes, scores, class_names, ground_truth=False): @@ -521,6 +522,7 @@ def draw_clip_range( keyframe_idx=None, draw_range=None, repeat_frame=1, + task_id=None, ): """ Draw predicted labels or ground truth classes to clip. Draw bouding boxes to clip @@ -537,6 +539,7 @@ def draw_clip_range( draw_range (Optional[list[ints]): only draw frames in range [start_idx, end_idx] inclusively in the clip. If None, draw on the entire clip. repeat_frame (int): repeat each frame in draw_range for `repeat_frame` time for slow-motion effect. + task_id (int): reference index of the task where frames and predictions originated from. 
""" if draw_range is None: draw_range = [0, len(frames) - 1] @@ -675,3 +678,92 @@ def _get_thres_array(self, common_class_names=None): ) thres_array[common_class_ids] = self.thres self.thres = torch.from_numpy(thres_array) + + +class VideoLogger(VideoVisualizer): + """ + Log predictions to file instead of drawing onto output video frames. + + Core is identical to `VideoVisualizer`. Override draw method to log. + """ + def __init__(self, *_, **__): + super(VideoLogger, self).__init__(*_, **__) + self.clip_index = -1 + + def draw_clip_range( + self, + frames, + preds, + bboxes=None, + text_alpha=0.5, + ground_truth=False, + keyframe_idx=None, + draw_range=None, + repeat_frame=1, + task_id=None, + ): + self.clip_index = task_id or self.clip_index + 1 + num_frames = len(frames) + start_frame = self.clip_index * num_frames + frame_range = [start_frame, start_frame + num_frames - 1] + + if isinstance(preds, torch.Tensor): + if preds.ndim == 1: + preds = preds.unsqueeze(0) + n_instances = preds.shape[0] + elif isinstance(preds, list): + n_instances = len(preds) + else: + logger.error("Unsupported type of prediction input.") + return + + if ground_truth: + method = "ground-truth" + top_scores, top_classes = [None] * n_instances, preds + elif self.mode == "top-k": + method = "top-k={}".format(self.top_k) + top_scores, top_classes = torch.topk(preds, k=self.top_k) + top_scores, top_classes = top_scores.tolist(), top_classes.tolist() + elif self.mode == "thres": + method = "thres>={}".format(self.thres) + top_scores, top_classes = [], [] + for pred in preds: + mask = pred >= self.thres + top_scores.append(pred[mask].tolist()) + top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist() + top_classes.append(top_class) + else: + logger.error("Unknown mode: %s", self.mode) + return + + text_labels = [] + for i in range(n_instances): + text_labels.append( + _create_text_labels( + top_classes[i], + top_scores[i], + self.class_names, + ground_truth=ground_truth, + ) + ) + + frames_info = "{:04d} [{:08d}, {:08d}]:".format(self.clip_index, frame_range[0], frame_range[1]) + if bboxes is not None: + assert len(preds) == len(bboxes), \ + "Encounter {} predictions and {} bounding boxes".format(len(preds), len(bboxes)) + pred_log.info(frames_info) + for i, box in enumerate(bboxes): + top_labels = [self.class_names[i] for i in top_classes[i]] + txt_scores = [float("{:.4f}".format(float(score))) for score in top_scores[i]] + label = " labeled '{}'".format(text_labels[i]) if ground_truth else "" + text_box = "bbox: {},".format(list(float("{:04.2f}".format(float(c))) for c in list(box))) + pred_log.info(" %s%s is predicted to class %s, %s: %s, %s", + text_box, label, text_labels[i][0], method, top_labels, txt_scores) + else: + label = " labeled '{}'".format(text_labels[0]) if ground_truth else "" + top_labels = [self.class_names[i] for i in top_classes[0]] + txt_scores = [float("{:.4f}".format(float(score))) for score in top_scores[0]] + pred_log.info("%s%s is predicted to class %s, %s: %s, %s", + frames_info, label, text_labels[0], method, top_labels, txt_scores) + + return [] # drop frames to speed up process (no writing) diff --git a/tools/demo_net.py b/tools/demo_net.py index a7e98ebde..405dc2517 100644 --- a/tools/demo_net.py +++ b/tools/demo_net.py @@ -5,15 +5,17 @@ import time import torch import tqdm +import os from slowfast.utils import logging from slowfast.visualization.async_predictor import AsyncDemo, AsyncVis from slowfast.visualization.ava_demo_precomputed_boxes import ( 
diff --git a/tools/demo_net.py b/tools/demo_net.py
index a7e98ebde..405dc2517 100644
--- a/tools/demo_net.py
+++ b/tools/demo_net.py
@@ -5,15 +5,17 @@
 import time
 import torch
 import tqdm
+import os
 
 from slowfast.utils import logging
 from slowfast.visualization.async_predictor import AsyncDemo, AsyncVis
 from slowfast.visualization.ava_demo_precomputed_boxes import (
     AVAVisualizerWithPrecomputedBox,
 )
+from slowfast.visualization.async_predictor import draw_predictions, log_predictions
 from slowfast.visualization.demo_loader import ThreadVideoManager, VideoManager
 from slowfast.visualization.predictor import ActionPredictor
-from slowfast.visualization.video_visualizer import VideoVisualizer
+from slowfast.visualization.video_visualizer import VideoVisualizer, VideoLogger
 
 logger = logging.get_logger(__name__)
 
@@ -42,7 +44,16 @@ def run_demo(cfg, frame_provider):
         else None
     )
 
-    video_vis = VideoVisualizer(
+    if not cfg.DEMO.OUTPUT_DISPLAY:
+        video_vis_cls = VideoLogger
+        pred_processor = log_predictions
+        pred_log = logging.get_logger("slowfast-predictions")
+        logging.setup_file_logger(pred_log, cfg.OUTPUT_DIR, "predictions.log")
+    else:
+        video_vis_cls = VideoVisualizer
+        pred_processor = draw_predictions
+
+    video_vis = video_vis_cls(
         num_classes=cfg.MODEL.NUM_CLASSES,
         class_names_path=cfg.DEMO.LABEL_FILE_PATH,
         top_k=cfg.TENSORBOARD.MODEL_VIS.TOPK_PREDS,
@@ -52,8 +63,11 @@ def run_demo(cfg, frame_provider):
         colormap=cfg.TENSORBOARD.MODEL_VIS.COLORMAP,
         mode=cfg.DEMO.VIS_MODE,
     )
-
-    async_vis = AsyncVis(video_vis, n_workers=cfg.DEMO.NUM_VIS_INSTANCES)
+    async_vis = AsyncVis(
+        video_vis,
+        n_workers=cfg.DEMO.NUM_VIS_INSTANCES,
+        prediction_processor=pred_processor,
+    )
 
     if cfg.NUM_GPUS <= 1:
         model = ActionPredictor(cfg=cfg, async_vis=async_vis)
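End-to-end, the new behavior is driven by `DEMO.OUTPUT_DISPLAY`. A sketch of enabling the logging path programmatically (this assumes the standard `get_cfg()` entry point in `slowfast/config/defaults.py`; the output directory is illustrative):

```python
# Sketch: flip the new option so run_demo() picks VideoLogger/log_predictions
# and writes <OUTPUT_DIR>/predictions.log instead of drawing on frames.
from slowfast.config.defaults import get_cfg

cfg = get_cfg()
cfg.DEMO.ENABLE = True
cfg.DEMO.OUTPUT_DISPLAY = False  # new option from this diff
cfg.OUTPUT_DIR = "./output"      # predictions.log is created here
# Configure DEMO.LABEL_FILE_PATH, the checkpoint, and the input video as
# usual, then launch the demo (e.g. tools/run_net.py) with this cfg.
```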