From 3ba12a70250cfbea57484e97eab91137620a9a8a Mon Sep 17 00:00:00 2001
From: Kin
Date: Sat, 10 Jan 2026 18:40:25 +0100
Subject: [PATCH 1/2] chore(visualization): refactor the open3d visualization, merge fn together.

* reset fire class usage directly.
* add save screenshot easily with multi-view.
* sync viewpoint across different windows.
* visualize lidar center tf if slc is set to True.
---
 README.md              |  14 +-
 src/dataset.py         |   2 +-
 src/trainer.py         |   2 +-
 src/utils/mics.py      |  72 ++++++
 src/utils/o3d_view.py  | 502 ++++++++++++++++++++++++--------------
 tools/README.md        |  28 ++-
 tools/visualization.py | 537 +++++++++++++++++++++++++++--------------
 7 files changed, 779 insertions(+), 378 deletions(-)

diff --git a/README.md b/README.md
index 3f3d59f..3603a46 100644
--- a/README.md
+++ b/README.md
@@ -301,7 +301,7 @@ pip install "evalai"
 evalai set-token
 
 # Step 3: Copy the command pop above and submit to leaderboard
-evalai challenge 2010 phase 4018 submit --file av2_submit.zip --large --private
+# evalai challenge 2010 phase 4018 submit --file av2_submit.zip --large --private
 evalai challenge 2210 phase 4396 submit --file av2_submit_v2.zip --large --private
 ```
 
@@ -318,21 +318,24 @@ python eval.py model=nsfp dataset_path=/home/kin/data/av2/h5py/demo/val
 # The output of above command will be like:
 Model: DeFlow, Checkpoint from: /home/kin/model_zoo/v2/seflow_best.ckpt
 We already write the flow_est into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:
-python tools/visualization.py --res_name 'seflow_best' --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis
+python tools/visualization.py vis --res_name 'seflow_best' --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis
 Enjoy! ^v^ ------
 
 # Then run the command in the terminal:
-python tools/visualization.py --res_name 'seflow_best' --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis
+python tools/visualization.py vis --res_name 'seflow_best' --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis
 ```
 
 https://github.com/user-attachments/assets/f031d1a2-2d2f-4947-a01f-834ed1c146e6
 
 For easy comparison with ground truth and other methods, we also provide a multi-window Open3D visualization:
 ```bash
-python tools/visualization.py --mode mul --res_name "['flow', 'seflow_best']" --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis
+python tools/visualization.py vis --res_name "['flow', 'seflow_best']" --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis
 ```
-Or another way to interact with [rerun](https://github.com/rerun-io/rerun) but please only vis scene by scene, not all at once.
+**Tips**: To quickly create qualitative results for all methods, use the multi-result comparison mode, pick a good viewpoint, and then save screenshots for all frames by pressing the `P` key. You will find all methods' results saved in the output folder (default: `logs/imgs`). Enjoy!
+
+
+_Rerun_: Another way to interact with the results is [rerun](https://github.com/rerun-io/rerun), but please visualize only one scene at a time, not all at once.
 
 ```bash
 python tools/visualization_rerun.py --data_dir /home/kin/data/av2/h5py/demo/train --res_name "['flow', 'deflow']"
 ```
 
 https://github.com/user-attachments/assets/07e8d430-a867-42b7-900a-11755949de21
 
-
 ## Cite Us
 [*OpenSceneFlow*](https://github.com/KTH-RPL/OpenSceneFlow) is originally designed by [Qingwen Zhang](https://kin-zhang.github.io/) from DeFlow and SeFlow.
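Editor's note on the README workflow above: once a result name (e.g. `seflow_best`) has been written into the preprocessed dataset, it can be inspected directly with `h5py` before visualizing, much like the `tools/read_h5.py` helper mentioned in the updated `tools/README.md`. The sketch below is illustrative only and not part of this patch: it assumes the layout that `HDF5Dataset` reads in `src/dataset.py` (one `.h5` file per scene, one group per timestamp, one dataset per key such as `flow` or the chosen `res_name`), and the path is a placeholder.

```python
# Illustrative sketch (not part of this patch): list which result names are stored
# for the first frame of one scene. Layout assumed from src/dataset.py usage
# (per-scene .h5 files with per-timestamp groups); the path below is a placeholder.
from pathlib import Path
import h5py

data_dir = Path("/home/kin/data/av2/preprocess_v2/sensor/vis")  # placeholder
scene_file = sorted(data_dir.glob("*.h5"))[0]

with h5py.File(scene_file, "r") as f:
    timestamps = sorted(f.keys())
    print(f"scene {scene_file.stem}: {len(timestamps)} frames")
    for key, ds in f[timestamps[0]].items():
        print(f"  {key}: shape={getattr(ds, 'shape', None)}, dtype={getattr(ds, 'dtype', None)}")
```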
diff --git a/src/dataset.py b/src/dataset.py index dd04995..9d431b9 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -347,7 +347,7 @@ def __getitem__(self, index_): data_dict[f'gmh{i+1}'] = past_gm data_dict[f'poseh{i+1}'] = past_pose - for data_key in self.vis_name + ['ego_motion', 'lidar_dt', + for data_key in self.vis_name + ['ego_motion', 'lidar_dt', 'lidar_center', # ground truth information: 'flow', 'flow_is_valid', 'flow_category_indices', 'flow_instance_id', 'dufo']: if data_key in f[key]: diff --git a/src/trainer.py b/src/trainer.py index 84b22c5..de064f1 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -243,7 +243,7 @@ def on_validation_epoch_end(self): with open(str(self.save_res_path)+'.pkl', 'wb') as f: pickle.dump((self.metrics.epe_3way, self.metrics.bucketed, self.metrics.epe_ssf), f) print(f"We already write the {self.res_name} into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:") - print(f"python tools/visualization.py --res_name '{self.res_name}' --data_dir {self.dataset_path}") + print(f"python tools/visualization.py vis --res_name '{self.res_name}' --data_dir {self.dataset_path}") print(f"Enjoy! ^v^ ------ \n") self.metrics = OfficialMetrics() diff --git a/src/utils/mics.py b/src/utils/mics.py index 2ee5c65..8ba1a8e 100644 --- a/src/utils/mics.py +++ b/src/utils/mics.py @@ -172,6 +172,78 @@ def make_colorwheel(transitions: tuple=DEFAULT_TRANSITIONS) -> np.ndarray: return colorwheel +def error_to_color( + error_magnitude: np.ndarray, + max_error: float, + color_map: str = "jet" +) -> np.ndarray: + """ + Convert flow error to RGB color visualization. + Args: + color_map: Color map to use for visualization ("jet" recommended for error visualization) + + Returns: + RGB color representation of the error of shape (..., 3) + """ + if max_error > 0: + normalized_error = np.clip(error_magnitude / max_error, 0, 1) + else: + normalized_error = np.zeros_like(error_magnitude) + + # Create colormap + if color_map == "jet": + # Simple jet colormap implementation + colors = np.zeros((*normalized_error.shape, 3), dtype=np.uint8) + + # Blue to cyan to green to yellow to red + # Blue (low error) + idx = normalized_error < 0.25 + colors[idx, 2] = 255 + colors[idx, 0] = 0 + colors[idx, 1] = np.uint8(255 * normalized_error[idx] * 4) + + # Cyan to green + idx = (normalized_error >= 0.25) & (normalized_error < 0.5) + colors[idx, 1] = 255 + colors[idx, 0] = 0 + colors[idx, 2] = np.uint8(255 * (1 - (normalized_error[idx] - 0.25) * 4)) + + # Green to yellow + idx = (normalized_error >= 0.5) & (normalized_error < 0.75) + colors[idx, 1] = 255 + colors[idx, 2] = 0 + colors[idx, 0] = np.uint8(255 * (normalized_error[idx] - 0.5) * 4) + + # Yellow to red (high error) + idx = normalized_error >= 0.75 + colors[idx, 0] = 255 + colors[idx, 2] = 0 + colors[idx, 1] = np.uint8(255 * (1 - (normalized_error[idx] - 0.75) * 4)) + + elif color_map == "hot": + # Hot colormap: black -> red -> yellow -> white + colors = np.zeros((*normalized_error.shape, 3), dtype=np.uint8) + + # Black to red + idx = normalized_error < 0.33 + colors[idx, 0] = np.uint8(255 * normalized_error[idx] * 3) + + # Red to yellow + idx = (normalized_error >= 0.33) & (normalized_error < 0.67) + colors[idx, 0] = 255 + colors[idx, 1] = np.uint8(255 * (normalized_error[idx] - 0.33) * 3) + + # Yellow to white + idx = normalized_error >= 0.67 + colors[idx, 0] = 255 + colors[idx, 1] = 255 + colors[idx, 2] = np.uint8(255 * (normalized_error[idx] - 0.67) * 3) + + else: + raise 
ValueError(f"Unsupported color map: {color_map}. Use 'jet' or 'hot'.") + + return colors + def flow_to_rgb( flow: np.ndarray, flow_max_radius: Optional[float]=None, diff --git a/src/utils/o3d_view.py b/src/utils/o3d_view.py index e5de457..ea2316f 100644 --- a/src/utils/o3d_view.py +++ b/src/utils/o3d_view.py @@ -1,219 +1,349 @@ -''' -# @date: 2023-1-26 16:38 -# @author: Qingwen Zhang (https://kin-zhang.github.io/), Ajinkya Khoche (https://ajinkyakhoche.github.io/) -# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology -# @detail: -# 1. Play the data you want in open3d, and save the view control to json file. -# 2. Use the json file to view the data again. -# 3. Save the screen shot and view file for later check and animation. -# -# code gits: https://gist.github.com/Kin-Zhang/77e8aa77a998f1a4f7495357843f24ef -# -# CHANGELOG: -# 2024-08-23 21:41(Qingwen): remove totally on view setting from scratch but use open3d>=0.18.0 version for set_view from json text func. -# 2024-04-15 12:06(Qingwen): show a example json text. add hex_to_rgb, color_map_hex, color_map (for color points if needed) -# 2024-01-27 0:41(Qingwen): update MyVisualizer class, reference from kiss-icp: https://github.com/PRBonn/kiss-icp/blob/main/python/kiss_icp/tools/visualizer.py -# 2024-09-10 (Ajinkya): Add MyMultiVisualizer class to view multiple windows at once, allow forward and backward playback, create bev square for giving a sense of metric scale. -''' +""" +Open3D Visualizer for Scene Flow +================================ +@date: 2023-1-26 16:38 +@author: Qingwen Zhang (https://kin-zhang.github.io/), Ajinkya Khoche (https://ajinkyakhoche.github.io/) +Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology + +# This file is part of OpenSceneFlow (https://github.com/KTH-RPL/OpenSceneFlow). +# If you find this repo helpful, please cite the respective publication as +# listed on the above website. + +Features: + - Single or multi-window visualization + - Viewpoint sync across windows (press S) + - Forward/backward playback + - Screenshot and viewpoint save + +CHANGELOG: +2026-01-10 (Qingwen): Unified single/multi visualizer, added viewpoint sync with S key +2024-09-10 (Ajinkya): Add multi-window support, forward/backward playback +2024-08-23 (Qingwen): Use open3d>=0.18.0 set_view_status API +""" import open3d as o3d -import os, time -from typing import List, Callable +import os +import time +from typing import List, Callable, Union from functools import partial import numpy as np -def hex_to_rgb(hex_color): + +def hex_to_rgb(hex_color: str) -> tuple: + """Convert hex color string to RGB tuple (0-1 range).""" hex_color = hex_color.lstrip("#") return tuple(int(hex_color[i:i + 2], 16) / 255.0 for i in (0, 2, 4)) -color_map_hex = ['#a6cee3', '#de2d26', '#1f78b4','#b2df8a','#33a02c','#fb9a99','#e31a1c','#fdbf6f','#ff7f00',\ - '#cab2d6','#6a3d9a','#ffff99','#b15928', '#8dd3c7','#ffffb3','#bebada','#fb8072','#80b1d3',\ - '#fdb462','#b3de69','#fccde5','#d9d9d9','#bc80bd','#ccebc5','#ffed6f'] - -color_map = [hex_to_rgb(color) for color in color_map_hex] +COLOR_MAP_HEX = [ + '#a6cee3', '#de2d26', '#1f78b4', '#b2df8a', '#33a02c', '#fb9a99', '#e31a1c', + '#fdbf6f', '#ff7f00', '#cab2d6', '#6a3d9a', '#ffff99', '#b15928', '#8dd3c7', + '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', + '#d9d9d9', '#bc80bd', '#ccebc5', '#ffed6f' +] +color_map = [hex_to_rgb(c) for c in COLOR_MAP_HEX] + + +class O3DVisualizer: + """ + Unified Open3D visualizer supporting single or multiple windows. 
+ + Args: + view_file: Path to JSON file with saved viewpoint + res_name: Single name or list of names for windows + save_folder: Folder to save screenshots + point_size: Point size for rendering + bg_color: Background color as RGB tuple (0-1 range) + screen_width: Screen width for multi-window layout + screen_height: Screen height for multi-window layout + + Usage: + # Single window + viz = O3DVisualizer(res_name="flow") + + # Multiple windows + viz = O3DVisualizer(res_name=["flow", "flow_est"]) + """ + + def __init__( + self, + view_file: str = None, + res_name: Union[str, List[str]] = "flow", + save_folder: str = "logs/imgs", + point_size: float = 3.0, + bg_color: tuple = (80/255, 90/255, 110/255), + screen_width: int = 1375, + screen_height: int = 2500, + ): + # Normalize res_name to list + self.res_names = [res_name] if isinstance(res_name, str) else list(res_name) + self.num_windows = len(self.res_names) -class MyVisualizer: - def __init__(self, view_file=None, window_title="Default", save_folder="logs/imgs"): - self.params = None - self.vis = o3d.visualization.VisualizerWithKeyCallback() - self.vis.create_window(window_name=window_title) self.view_file = view_file - + self.save_folder = save_folder + self.point_size = point_size + self.bg_color = np.asarray(bg_color) + + os.makedirs(self.save_folder, exist_ok=True) + + # State self.block_vis = True self.play_crun = False self.reset_bounding_box = True - self.save_img_folder = save_folder - os.makedirs(self.save_img_folder, exist_ok=True) + self.playback_direction = 1 # 1: forward, -1: backward + self.curr_index = -1 + self.tmp_value = None + self._should_save = False + self._should_sync = False + self._sync_source_idx = 0 + + # Create windows + self.vis: List[o3d.visualization.VisualizerWithKeyCallback] = [] + self._create_windows(screen_width, screen_height) + self._setup_render_options() + self._register_callbacks() + self._print_help() + + def _create_windows(self, screen_width: int, screen_height: int): + """Create visualizer windows.""" + if self.num_windows == 1: + v = o3d.visualization.VisualizerWithKeyCallback() + title = self._window_title(self.res_names[0]) + v.create_window(window_name=title) + self.vis.append(v) + else: + window_width = screen_width // 2 + window_height = screen_height // 4 + epsilon = 150 + positions = [ + (0, 0), + (screen_width - window_width + epsilon, 0), + (0, screen_height - window_height + epsilon), + (screen_width - window_width + epsilon, screen_height - window_height + epsilon), + ] + for i, name in enumerate(self.res_names): + v = o3d.visualization.VisualizerWithKeyCallback() + title = self._window_title(name) + pos = positions[i % len(positions)] + v.create_window(window_name=title, width=window_width, height=window_height, + left=pos[0], top=pos[1]) + self.vis.append(v) + + def _window_title(self, name: str) -> str: + label = "ground truth flow" if name == "flow" else name + return f"View {label} | SPACE: play/pause" + + def _setup_render_options(self): + """Configure render options for all windows.""" + for v in self.vis: + opt = v.get_render_option() + opt.background_color = self.bg_color + opt.point_size = self.point_size + + def _register_callbacks(self): + """Register keyboard callbacks for all windows.""" + callbacks = [ + (["Ā", "Q", "\x1b"], self._quit), + ([" "], self._start_stop), + (["D"], self._next_frame), + (["A"], self._prev_frame), + (["P"], self._save_screen), + (["E"], self._save_error_bar), + (["S"], self._sync_viewpoint), + ] + for keys, callback in callbacks: + for key 
in keys: + for idx, v in enumerate(self.vis): + v.register_key_callback(ord(str(key)), partial(callback, src_idx=idx)) + + def _print_help(self): + sync_hint = "[S] sync viewpoint across windows\n" if self.num_windows > 1 else "" print( - f"\n{window_title.capitalize()} initialized. Press:\n" - "\t[SPACE] to pause/start\n" - "\t[ESC/Q] to exit\n" - "\t [P] to save screen and viewpoint\n" - "\t [D] to step next\n" + f"\nVisualizer initialized ({self.num_windows} window(s)). Keys:\n" + f" [SPACE] play/pause [D] next frame [A] prev frame\n" + f" [P] save screenshot [E] save error bar\n" + f" {sync_hint}" + f" [ESC/Q] quit\n" ) - self._register_key_callback(["Ā", "Q", "\x1b"], self._quit) - self._register_key_callback(["P"], self._save_screen) - self._register_key_callback([" "], self._start_stop) - self._register_key_callback(["D"], self._next_frame) - - def show(self, assets: List): - self.vis.clear_geometries() - - for asset in assets: - self.vis.add_geometry(asset) - if self.view_file is not None: - self.vis.set_view_status(open(self.view_file).read()) - - self.vis.update_renderer() - self.vis.poll_events() - self.vis.run() - self.vis.destroy_window() - - def update(self, assets: List, clear: bool = True): - if clear: - self.vis.clear_geometries() - - for asset in assets: - self.vis.add_geometry(asset, reset_bounding_box=False) - self.vis.update_geometry(asset) + # ------------------------------------------------------------------------- + # Public API + # ------------------------------------------------------------------------- + + def update(self, assets: Union[List, List[List]], index: int = -1, value: float = None): + """ + Update visualizer with new assets. + + Args: + assets: For single window - list of geometries + For multi window - list of lists of geometries + index: Current frame index (for screenshot naming) + value: Optional value (e.g., max error for colorbar) + """ + self.curr_index = index + self.tmp_value = value + + # Normalize to list of lists + if self.num_windows == 1: + assets_list = [assets] if not self._is_nested_list(assets) else assets + else: + assets_list = assets + + # Clear and add geometries + for v in self.vis: + v.clear_geometries() + + for i, window_assets in enumerate(assets_list): + if i >= len(self.vis): + break + for asset in window_assets: + self.vis[i].add_geometry(asset, reset_bounding_box=False) + self.vis[i].update_geometry(asset) + + # Reset view on first frame if self.reset_bounding_box: - self.vis.reset_view_point(True) - if self.view_file is not None: - self.vis.set_view_status(open(self.view_file).read()) + for v in self.vis: + v.reset_view_point(True) + if self.view_file is not None: + v.set_view_status(open(self.view_file).read()) self.reset_bounding_box = False - - self.vis.update_renderer() + + # Render and wait + for v in self.vis: + v.update_renderer() + while self.block_vis: - self.vis.poll_events() + for v in self.vis: + v.poll_events() + if self._should_sync: + self._do_sync_viewpoint() + if self._should_save: + self._do_save_screen() if self.play_crun: break + self.block_vis = not self.block_vis - def _register_key_callback(self, keys: List, callback: Callable): - for key in keys: - self.vis.register_key_callback(ord(str(key)), partial(callback)) - - def _next_frame(self, vis): - self.block_vis = not self.block_vis - - def _start_stop(self, vis): - self.play_crun = not self.play_crun - - def _quit(self, vis): + def show(self, assets: List): + """Show assets and run visualization loop (blocking).""" + for v in self.vis: + 
v.clear_geometries() + for asset in assets: + for v in self.vis: + v.add_geometry(asset) + if self.view_file is not None: + v.set_view_status(open(self.view_file).read()) + for v in self.vis: + v.update_renderer() + v.poll_events() + self.vis[0].run() + for v in self.vis: + v.destroy_window() + + def _is_nested_list(self, obj) -> bool: + """Check if obj is a list of lists.""" + return isinstance(obj, list) and len(obj) > 0 and isinstance(obj[0], list) + + # ------------------------------------------------------------------------- + # Callbacks + # ------------------------------------------------------------------------- + + def _quit(self, vis, src_idx=0): print("Destroying Visualizer. Thanks for using ^v^.") - vis.destroy_window() + for v in self.vis: + v.destroy_window() os._exit(0) - def _save_screen(self, vis): - timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - png_file = f"{self.save_img_folder}/ScreenShot_{timestamp}.png" - view_json_file = f"{self.save_img_folder}/ScreenView_{timestamp}.json" - with open(view_json_file, 'w') as f: - f.write(vis.get_view_status()) - vis.capture_screen_image(png_file) - print(f"ScreenShot saved to: {png_file}, Please check it.") - - -def create_bev_square(size=409.6, color=[68/255,114/255,196/255]): - # Create the vertices of the square - half_size = size / 2.0 - vertices = np.array([ - [-half_size, -half_size, 0], - [half_size, -half_size, 0], - [half_size, half_size, 0], - [-half_size, half_size, 0] - ]) - - # Define the square as a LineSet for visualization - lines = [[0, 1], [1, 2], [2, 3], [3, 0]] - colors = [color for _ in lines] - - line_set = o3d.geometry.LineSet( - points=o3d.utility.Vector3dVector(vertices), - lines=o3d.utility.Vector2iVector(lines) - ) - line_set.colors = o3d.utility.Vector3dVector(colors) - - return line_set - -class MyMultiVisualizer(MyVisualizer): - def __init__(self, view_file=None, flow_mode=['flow'], screen_width=2500, screen_height = 1375): - self.params = None - self.view_file = view_file - self.block_vis = True - self.play_crun = False - self.reset_bounding_box = True - self.playback_direction = 1 # 1:forward, -1:backward - - self.vis = [] - # self.o3d_vctrl = [] - - # Define width and height for each window - window_width = screen_width // 2 - window_height = screen_height // 2 - # Define positions for the four windows - epsilon = 150 - positions = [ - (0, 0), # Top-left - (screen_width - window_width + epsilon, 0), # Top-right - (0, screen_height - window_height + epsilon), # Bottom-left - (screen_width - window_width + epsilon, screen_height - window_height + epsilon) # Bottom-right - ] - - for i, mode in enumerate(flow_mode): - window_title = f"view {'ground truth flow' if mode == 'flow' else f'{mode} flow'}, `SPACE` start/stop" - v = o3d.visualization.VisualizerWithKeyCallback() - v.create_window(window_name=window_title, width=window_width, height=window_height, left=positions[i%len(positions)][0], top=positions[i%len(positions)][1]) - # self.o3d_vctrl.append(ViewControl(v.get_view_control(), view_file=view_file)) - self.vis.append(v) - - self._register_key_callback(["Ā", "Q", "\x1b"], self._quit) - self._register_key_callback([" "], self._start_stop) - self._register_key_callback(["D"], self._next_frame) - self._register_key_callback(["A"], self._prev_frame) - print( - f"\n{window_title.capitalize()} initialized. 
Press:\n" - "\t[SPACE] to pause/start\n" - "\t[ESC/Q] to exit\n" - "\t [P] to save screen and viewpoint\n" - "\t [D] to step next\n" - "\t [A] to step previous\n" - ) - - def update(self, assets_list: List, clear: bool = True): - if clear: - [v.clear_geometries() for v in self.vis] - - for i, assets in enumerate(assets_list): - [self.vis[i].add_geometry(asset, reset_bounding_box=False) for asset in assets] - self.vis[i].update_geometry(assets[-1]) - - if self.reset_bounding_box: - [v.reset_view_point(True) for v in self.vis] - if self.view_file is not None: - # [o.read_viewTfile(self.view_file) for o in self.o3d_vctrl] - [v.set_view_status(open(self.view_file).read()) for v in self.vis] - self.reset_bounding_box = False - - [v.update_renderer() for v in self.vis] - while self.block_vis: - [v.poll_events() for v in self.vis] - if self.play_crun: - break - self.block_vis = not self.block_vis + def _start_stop(self, vis, src_idx=0): + self.play_crun = not self.play_crun - def _register_key_callback(self, keys: List, callback: Callable): - for key in keys: - [v.register_key_callback(ord(str(key)), partial(callback)) for v in self.vis] - def _next_frame(self, vis): + def _next_frame(self, vis, src_idx=0): self.block_vis = not self.block_vis self.playback_direction = 1 - def _prev_frame(self, vis): + + def _prev_frame(self, vis, src_idx=0): self.block_vis = not self.block_vis self.playback_direction = -1 + def _save_screen(self, vis, src_idx=0): + self._should_save = True + # NOTE: sync viewpoint before saving + self._should_sync = True + return False + + def _sync_viewpoint(self, vis, src_idx=0): + """Sync all windows to the viewpoint of the source window.""" + self._should_sync = True + self._sync_source_idx = src_idx + return False + + def _do_sync_viewpoint(self): + """Actually perform viewpoint sync (called from main loop).""" + if self.num_windows <= 1: + self._should_sync = False + return + + source_view = self.vis[self._sync_source_idx].get_view_status() + for i, v in enumerate(self.vis): + if i != self._sync_source_idx: + v.set_view_status(source_view) + v.update_renderer() + print(f"Synced viewpoint from window {self._sync_source_idx} to all windows.") + self._should_sync = False + + def _do_save_screen(self): + """Save screenshots from all windows.""" + timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + + for i, v in enumerate(self.vis): + v.poll_events() + v.update_renderer() + + name = self.res_names[i] if i < len(self.res_names) else f"window{i}" + prefix = f"{self.curr_index}_{name}" if self.curr_index != -1 else name + png_file = f"{self.save_folder}/{prefix}_{timestamp}.png" + v.capture_screen_image(png_file) + + if i == 0: + view_file = f"{self.save_folder}/{prefix}_{timestamp}.json" + with open(view_file, 'w') as f: + f.write(v.get_view_status()) + + print(f"Screenshots saved to {self.save_folder}/") + self._should_save = False + + def _save_error_bar(self, vis, src_idx=0): + """Save error colorbar as image.""" + if self.tmp_value is None: + print("No error value set, skipping error bar save.") + return + + import matplotlib.pyplot as plt + import matplotlib as mpl + + timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + prefix = f"{self.curr_index}_error" if self.curr_index != -1 else "error" + png_file = f"{self.save_folder}/{prefix}_{timestamp}.png" + + fig, ax = plt.subplots(figsize=(10, 1)) + max_val = self.tmp_value * 100 + norm = mpl.colors.Normalize(vmin=0, vmax=max_val) + cb = mpl.colorbar.ColorbarBase(ax, cmap=plt.cm.hot, norm=norm, 
orientation='horizontal') + + ticks = np.linspace(0, max_val, 5) + cb.set_ticks(ticks) + cb.set_ticklabels([f"{t:.1f}" for t in ticks]) + cb.set_label('Error Magnitude (cm)') + + plt.savefig(png_file, bbox_inches='tight') + plt.close() + print(f"Error bar saved to: {png_file}") + + +# Backward compatibility aliases; FIXME: remove in near future +MyVisualizer = O3DVisualizer +MyMultiVisualizer = O3DVisualizer + if __name__ == "__main__": json_content = """{ @@ -236,12 +366,12 @@ def _prev_frame(self, vis): "version_minor" : 0 } """ - # write to json file view_json_file = "view.json" with open(view_json_file, 'w') as f: f.write(json_content) + sample_ply_data = o3d.data.PLYPointCloud() pcd = o3d.io.read_point_cloud(sample_ply_data.path) - - viz = MyVisualizer(view_json_file, window_title="Qingwen's View") + + viz = O3DVisualizer(view_json_file, res_name="Demo") viz.show([pcd]) \ No newline at end of file diff --git a/tools/README.md b/tools/README.md index a479138..9ccdb44 100644 --- a/tools/README.md +++ b/tools/README.md @@ -11,18 +11,36 @@ Here we introduce some tools to help you: run `tools/visualization.py` to view the scene flow dataset with ground truth flow. Note the color wheel in under world coordinate. ```bash -# view gt flow -python3 tools/visualization.py --data_dir /home/kin/data/av2/preprocess/sensor/mini --res_name flow +# Visualize flow with color coding +python tools/visualization.py vis --data_dir /path/to/data --res_name flow -# view est flow -python3 tools/visualization.py --data_dir /home/kin/data/av2/preprocess/sensor/mini --res_name deflow_best -python3 tools/visualization.py --data_dir /home/kin/data/av2/preprocess/sensor/mini --res_name seflow_best +# Compare multiple results side-by-side +python tools/visualization.py vis --data_dir /path/to/data --res_name "[flow, deflow, deltaflow, ssf]" + +# Show flow as vector lines +python tools/visualization.py vector --data_dir /path/to/data + +# Check flow with pc0, pc1, and flowed pc0 +python tools/visualization.py check --data_dir /path/to/data + +# Show error heatmap +python tools/visualization.py error --data_dir /path/to/data --res_name "[flow, deflow, deltaflow, ssf]" ``` Demo Effect (press `SPACE` to stop and start in the visualization window): https://github.com/user-attachments/assets/f031d1a2-2d2f-4947-a01f-834ed1c146e6 +**Tips**: To quickly create qualitative results for all methods, you can use multiple results comparison mode, select a good viewpoint and then save screenshots for all frames by pressing `P` key. You will found all methods' results are saved in the output folder (default is `logs/imgs`). + +## Quick Read .h5 Files + +You can quickly read all keys and shapes in a .h5 file by: + +```bash +python tools/read_h5.py --file_path /path/to/file.h5 +``` + ## Conversion run `tools/zero2ours.py` to convert the ZeroFlow pretrained model to our codebase. diff --git a/tools/visualization.py b/tools/visualization.py index 84c30e2..98b1ad6 100644 --- a/tools/visualization.py +++ b/tools/visualization.py @@ -1,204 +1,383 @@ """ -# Created: 2023-11-29 21:22 -# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology -# Author: Qingwen Zhang (https://kin-zhang.github.io/), Ajinkya Khoche (https://ajinkyakhoche.github.io/) -# -# This file is part of OpenSceneFlow (https://github.com/KTH-RPL/OpenSceneFlow). -# If you find this repo helpful, please cite the respective publication as -# listed on the above website. -# -# Description: view scene flow dataset after preprocess. 
+Scene Flow Visualization Tool +============================= +Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology +Author: Qingwen Zhang (https://kin-zhang.github.io/), Ajinkya Khoche (https://ajinkyakhoche.github.io/) -# CHANGELOG: -# 2024-09-10 (Ajinkya): Add vis_multiple(), to visualize multiple flow modes at once. - -# Usage: (flow is ground truth flow, `other_name` is the estimated flow from the model) -* python tools/visualization.py --data_dir /home/kin/data/av2/h5py/demo/train --res_name 'flow' --mode vis -* python tools/visualization.py --data_dir /home/kin/data/av2/h5py/demo/train --res_name "['flow', 'deflow' , 'ssf']" --mode mul +Part of OpenSceneFlow (https://github.com/KTH-RPL/OpenSceneFlow). +Usage (Fire class-based): + # Visualize flow with color coding + python tools/visualization.py vis --data_dir /path/to/data --res_name flow + + # Compare multiple results side-by-side + python tools/visualization.py vis --data_dir /path/to/data --res_name "[flow, flow_est]" + + # Show flow as vector lines + python tools/visualization.py vector --data_dir /path/to/data + + # Check flow with pc0, pc1, and flowed pc0 + python tools/visualization.py check --data_dir /path/to/data + + # Show error heatmap + python tools/visualization.py error --data_dir /path/to/data --res_name "[raw, flow_est]" + +Keys: + [SPACE] play/pause [D] next frame [A] prev frame + [P] save screenshot [E] save error bar + [S] sync viewpoint across windows (multi-window mode) + [ESC/Q] quit """ import numpy as np -import fire, time +import fire +import time from tqdm import tqdm - import open3d as o3d -import os, sys -BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '..' )) +import os +import sys + +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(BASE_DIR) -from src.utils.mics import HDF5Data, flow_to_rgb -from src.utils.o3d_view import MyVisualizer, MyMultiVisualizer, color_map, create_bev_square +from src.utils.mics import flow_to_rgb, error_to_color +from src.utils import npcal_pose0to1 +from src.utils.o3d_view import O3DVisualizer, color_map +from src.dataset import HDF5Dataset VIEW_FILE = f"{BASE_DIR}/assets/view/demo.json" +NO_COLOR = [1, 1, 1] -def check_flow( - data_dir: str ="/home/kin/data/av2/preprocess/sensor/mini", - res_name: str = "flow", # "flow", "flow_est" - start_id: int = 0, - point_size: float = 3.0, -): - dataset = HDF5Data(data_dir, vis_name=res_name, flow_view=True) - o3d_vis = MyVisualizer(view_file=VIEW_FILE, window_title=f"view {'ground truth flow' if res_name == 'flow' else f'{res_name} flow'}, `SPACE` start/stop") - - opt = o3d_vis.vis.get_render_option() - opt.background_color = np.asarray([80/255, 90/255, 110/255]) - opt.point_size = point_size - - for data_id in (pbar := tqdm(range(start_id, len(dataset)))): - data = dataset[data_id] - now_scene_id = data['scene_id'] - pbar.set_description(f"id: {data_id}, scene_id: {now_scene_id}, timestamp: {data['timestamp']}") - gm0 = data['gm0'] - pc0 = data['pc0'][~gm0] - - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) - pcd.paint_uniform_color([1.0, 0.0, 0.0]) # red: pc0 - - pc1 = data['pc1'] - pcd1 = o3d.geometry.PointCloud() - pcd1.points = o3d.utility.Vector3dVector(pc1[:, :3][~data['gm1']]) - pcd1.paint_uniform_color([0.0, 1.0, 0.0]) # green: pc1 - - pcd2 = o3d.geometry.PointCloud() - # pcd2.points = o3d.utility.Vector3dVector(pc0[:, :3] + pose_flow) # if you want to check pose_flow - pcd2.points = 
o3d.utility.Vector3dVector(pc0[:, :3] + data[res_name][~gm0]) - pcd2.paint_uniform_color([0.0, 0.0, 1.0]) # blue: pc0 + flow - o3d_vis.update([pcd, pcd1, pcd2, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) -def vis( - data_dir: str ="/home/kin/data/av2/h5py/demo/val", - res_name: str = "flow", # any res_name we write before in HDF5Data - start_id: int = 0, - point_size: float = 2.0, - mode: str = "vis", -): - if mode != "vis": - return - dataset = HDF5Data(data_dir, vis_name=res_name, flow_view=True) - o3d_vis = MyVisualizer(view_file=VIEW_FILE, window_title=f"view {'ground truth flow' if res_name == 'flow' else f'{res_name} flow'}, `SPACE` start/stop") +def _ensure_list(val): + """Ensure value is a list.""" + if val is None: + return [] + return val if isinstance(val, list) else [val] - opt = o3d_vis.vis.get_render_option() - # opt.background_color = np.asarray([216, 216, 216]) / 255.0 - opt.background_color = np.asarray([80/255, 90/255, 110/255]) - # opt.background_color = np.asarray([1, 1, 1]) - opt.point_size = point_size - for data_id in (pbar := tqdm(range(start_id, len(dataset)))): - data = dataset[data_id] - now_scene_id = data['scene_id'] - pbar.set_description(f"id: {data_id}, scene_id: {now_scene_id}, timestamp: {data['timestamp']}") +class SceneFlowVisualizer: + """ + Open3D-based Scene Flow Visualizer. + + Supports multiple visualization modes as class methods, + compatible with python-fire for CLI usage. + """ + + def __init__( + self, + data_dir: str = "/home/kin/data/av2/preprocess/sensor/mini", + res_name: str = "flow", + start_id: int = 0, + num_frames: int = 2, + rgm: bool = True, # remove ground mask + slc: bool = False, # show lidar centers + point_size: float = 3.0, + max_distance: float = 50.0, + bg_color: tuple = (80/255, 90/255, 110/255), + ): + """ + Initialize the visualizer. 
+ + Args: + data_dir: Path to HDF5 dataset directory + res_name: Result name(s) to visualize (string or list) + start_id: Starting frame index + point_size: Point size for rendering + rgm: Remove ground mask if True + slc: Show LiDAR sensor centers if True + num_frames: Number of frames for history mode + max_distance: Maximum distance filter for points + bg_color: Background color as RGB tuple (0-1 range) + """ + self.data_dir = data_dir + self.res_names = _ensure_list(res_name) + self.start_id = start_id + self.point_size = point_size + self.rgm = rgm + self.num_frames = num_frames + self.max_distance = max_distance + self.bg_color = bg_color + self.show_lidar_centers = slc + + def _load_dataset(self, vis_name=None, n_frames=2): + """Load HDF5 dataset.""" + vis_name = vis_name or self.res_names + return HDF5Dataset(self.data_dir, vis_name=vis_name, n_frames=n_frames) + + def _create_visualizer(self, res_name=None): + """Create O3DVisualizer instance.""" + res_name = res_name or self.res_names + return O3DVisualizer( + view_file=VIEW_FILE, + res_name=res_name, + point_size=self.point_size, + bg_color=self.bg_color, + ) + + def _filter_ground_and_distance(self, pc, gm): + """Apply ground mask and distance filter.""" + if not self.rgm: + gm = np.zeros_like(gm) + distance = np.linalg.norm(pc[:, :3], axis=1) + return gm | (distance > self.max_distance) + + def _compute_pose_flow(self, pc0, pose0, pose1): + """Compute ego-motion flow.""" + ego_pose = npcal_pose0to1(pose0, pose1) + return pc0[:, :3] @ ego_pose[:3, :3].T + ego_pose[:3, 3] - pc0[:, :3] + + # ------------------------------------------------------------------------- + # Visualization Modes (Fire subcommands) + # ------------------------------------------------------------------------- + + def vis(self): + """ + Visualize scene flow with color-coded dynamic motion. + + Supports single or multiple result names for side-by-side comparison. + Colors represent flow direction (after ego-motion compensation). 
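+
+        Example (editor's sketch, not part of this patch; the path is a placeholder
+        and 'seflow_best' stands for any result name written into the dataset):
+            SceneFlowVisualizer(data_dir="/path/to/preprocess/sensor/vis",
+                                res_name=["flow", "seflow_best"]).vis()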
+ """ + dataset = self._load_dataset() + o3d_vis = self._create_visualizer() + + data_id = self.start_id + pbar = tqdm(range(self.start_id, len(dataset))) + + while 0 <= data_id < len(dataset): + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + pc0 = data['pc0'] + gm0 = self._filter_ground_and_distance(pc0, data['gm0']) + pose_flow = self._compute_pose_flow(pc0, data['pose0'], data['pose1']) + + if self.rgm: + pc0 = pc0[~gm0] + pose_flow = pose_flow[~gm0] + + pcd_list = [] + for single_res in self.res_names: + pcd = o3d.geometry.PointCloud() + + # Instance/cluster visualization + if single_res in ['dufo', 'cluster', 'dufocluster', 'flow_instance_id', + 'ground_mask', 'pc0_dynamic'] and single_res in data: + labels = data[single_res][~gm0] if self.rgm else data[single_res] + pcd = self._color_by_labels(pc0, labels) + + # Flow visualization + elif single_res in data: + pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) + flow = (data[single_res][~gm0] if self.rgm else data[single_res]) - pose_flow + flow_color = flow_to_rgb(flow) / 255.0 + is_dynamic = np.linalg.norm(flow, axis=1) > 0.08 + flow_color[~is_dynamic] = NO_COLOR + if not self.rgm: + flow_color[gm0] = NO_COLOR + pcd.colors = o3d.utility.Vector3dVector(flow_color) + + # Raw point cloud + elif single_res == 'raw': + pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) + + pcd_list.append([pcd, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) + + # show lidar centers + if self.show_lidar_centers and 'lidar_center' in data: + lidar_center = data['lidar_center'] + for lidar_num in range(lidar_center.shape[0]): + pcd_list[-1].append( + o3d.geometry.TriangleMesh.create_coordinate_frame(size=1).transform( + lidar_center[lidar_num] + ) + ) + + o3d_vis.update(pcd_list, index=data_id) + data_id += o3d_vis.playback_direction + pbar.update(o3d_vis.playback_direction) + + def check(self): + """ + Check flow by showing pc0 (red), pc1 (green), and pc0+flow (blue). + + Useful for verifying flow correctness. 
+ """ + res_name = self.res_names[0] if self.res_names else "flow" + dataset = self._load_dataset(vis_name=res_name) + o3d_vis = self._create_visualizer(res_name=res_name) + + data_id = self.start_id + pbar = tqdm(range(self.start_id, len(dataset))) + + while 0 <= data_id < len(dataset): + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + if res_name not in dataset[data_id]: + print(f"'{res_name}' not in dataset, skipping id {data_id}") + data_id += 1 + continue + + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + pc0, pc1 = data['pc0'], data['pc1'] + if self.rgm: + pc0 = pc0[~data['gm0']] + pc1 = pc1[~data['gm1']] + + # Red: pc0 + pcd0 = o3d.geometry.PointCloud() + pcd0.points = o3d.utility.Vector3dVector(pc0[:, :3]) + pcd0.paint_uniform_color([1.0, 0.0, 0.0]) + + # Green: pc1 + pcd1 = o3d.geometry.PointCloud() + pcd1.points = o3d.utility.Vector3dVector(pc1[:, :3]) + pcd1.paint_uniform_color([0.0, 1.0, 0.0]) + + # Blue: pc0 + flow + res_flow = data[res_name][~data['gm0']] if self.rgm else data[res_name] + pcd2 = o3d.geometry.PointCloud() + pcd2.points = o3d.utility.Vector3dVector(pc0[:, :3] + res_flow) + pcd2.paint_uniform_color([0.0, 0.0, 1.0]) + + o3d_vis.update([pcd0, pcd1, pcd2, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) + data_id += o3d_vis.playback_direction + pbar.update(o3d_vis.playback_direction) - pc0 = data['pc0'] - gm0 = data['gm0'] - pose0 = data['pose0'] - pose1 = data['pose1'] - ego_pose = np.linalg.inv(pose1) @ pose0 + def vector(self): + """ + Visualize flow as red vector lines from source to target. + + Shows pc0 (green), pc1 (blue), and flow vectors (red lines). + """ + res_name = self.res_names[0] if self.res_names else "flow" + dataset = self._load_dataset(vis_name=res_name) + o3d_vis = O3DVisualizer( + view_file=VIEW_FILE, + res_name=res_name, + point_size=self.point_size, + bg_color=(1, 1, 1), # White background for vector mode + ) + + data_id = self.start_id + pbar = tqdm(range(self.start_id, len(dataset))) + + while 0 <= data_id < len(dataset): + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + if res_name not in dataset[data_id]: + print(f"'{res_name}' not in dataset, skipping id {data_id}") + data_id += 1 + continue - pose_flow = pc0[:, :3] @ ego_pose[:3, :3].T + ego_pose[:3, 3] - pc0[:, :3] + pc0 = data['pc0'] + gm0 = self._filter_ground_and_distance(pc0, data['gm0']) + + ego_pose = np.linalg.inv(data['pose1']) @ data['pose0'] + pose_flow = pc0[:, :3] @ ego_pose[:3, :3].T + ego_pose[:3, 3] - pc0[:, :3] + flow = data[res_name] - pose_flow + + # Green: pc0 transformed + vis_pc = pc0[:, :3][~gm0] + pose_flow[~gm0] + pcd0 = o3d.geometry.PointCloud() + pcd0.points = o3d.utility.Vector3dVector(vis_pc) + pcd0.paint_uniform_color([0, 1, 0]) + + # Blue: pc1 + pcd1 = o3d.geometry.PointCloud() + pcd1.points = o3d.utility.Vector3dVector(data['pc1'][:, :3][~data['gm1']]) + pcd1.paint_uniform_color([0.0, 0.0, 1]) + + # Red: flow vectors + line_set = self._create_flow_lines(vis_pc, flow[~gm0], color=[1, 0, 0]) + + o3d_vis.update([pcd0, pcd1, line_set, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)], + index=data_id) + data_id += o3d_vis.playback_direction + pbar.update(o3d_vis.playback_direction) + + def error(self, max_error: float = 0.35): + """ + Visualize flow error as heatmap (hot colormap). 
+ + Args: + max_error: Maximum error for color scaling (meters) + """ + dataset = self._load_dataset() + o3d_vis = self._create_visualizer() + o3d_vis.bg_color = np.asarray([216, 216, 216]) / 255.0 # Off-white + data_id = self.start_id + pbar = tqdm(range(0, len(dataset))) + + while 0 <= data_id < len(dataset): + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + pc0 = data['pc0'] + gm0 = self._filter_ground_and_distance(pc0, data['gm0']) + + gt_flow = data["flow"][~gm0] if self.rgm else data["flow"] + if self.rgm: + pc0 = pc0[~gm0] + + pcd_list = [] + for single_res in self.res_names: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) + + res_flow = None + if single_res in data: + res_flow = data[single_res][~gm0] if self.rgm else data[single_res] + elif single_res == 'raw': + res_flow = self._compute_pose_flow(pc0, data['pose0'], data['pose1']) + + if res_flow is not None: + error_mag = np.linalg.norm(gt_flow - res_flow, axis=-1) + error_mag[error_mag < 0.05] = 0 + error_color = error_to_color(error_mag, max_error=max_error, color_map="hot") / 255.0 + if not self.rgm: + error_color[gm0] = NO_COLOR + pcd.colors = o3d.utility.Vector3dVector(error_color) + pcd_list.append([pcd, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) + + o3d_vis.update(pcd_list, index=data_id, value=max_error) + data_id += o3d_vis.playback_direction + pbar.update(o3d_vis.playback_direction) + + # ------------------------------------------------------------------------- + # Helper Methods + # ------------------------------------------------------------------------- + + def _color_by_labels(self, pc, labels): + """Create point cloud colored by instance labels.""" pcd = o3d.geometry.PointCloud() - if res_name == 'raw': # no result, only show **raw point cloud** - pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) - pcd.paint_uniform_color([1.0, 1.0, 1.0]) - elif res_name in ['dufo', 'label']: - labels = data[res_name] + for label_i in np.unique(labels): pcd_i = o3d.geometry.PointCloud() - for label_i in np.unique(labels): - pcd_i.points = o3d.utility.Vector3dVector(pc0[labels == label_i][:, :3]) - if label_i <= 0: - pcd_i.paint_uniform_color([1.0, 1.0, 1.0]) - else: - pcd_i.paint_uniform_color(color_map[label_i % len(color_map)]) - pcd += pcd_i - elif res_name in data: - pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) - flow = data[res_name] - pose_flow # ego motion compensation here. 
- flow_color = flow_to_rgb(flow) / 255.0 - is_dynamic = np.linalg.norm(flow, axis=1) > 0.1 - flow_color[~is_dynamic] = [1, 1, 1] - flow_color[gm0] = [1, 1, 1] - pcd.colors = o3d.utility.Vector3dVector(flow_color) - o3d_vis.update([pcd, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) - - -def vis_multiple( - data_dir: str ="/home/kin/data/av2/h5py/demo/val", - res_name: list = ["flow"], - start_id: int = 0, - point_size: float = 3.0, - tone: str = 'dark', - mode: str = "mul", -): - if mode != "mul": - return - assert isinstance(res_name, list), "vis_multiple() needs a list as flow_mode" - dataset = HDF5Data(data_dir, vis_name=res_name, flow_view=True) - o3d_vis = MyMultiVisualizer(view_file=VIEW_FILE, flow_mode=res_name) - - for v in o3d_vis.vis: - opt = v.get_render_option() - if tone == 'bright': - background_color = np.asarray([216, 216, 216]) / 255.0 # offwhite - # background_color = np.asarray([1, 1, 1]) - pcd_color = [0.25, 0.25, 0.25] - elif tone == 'dark': - background_color = np.asarray([80/255, 90/255, 110/255]) # dark - pcd_color = [1., 1., 1.] - - opt.background_color = background_color - opt.point_size = point_size - - data_id = start_id - pbar = tqdm(range(0, len(dataset))) - - while data_id >= 0 and data_id < len(dataset): - data = dataset[data_id] - now_scene_id = data['scene_id'] - pbar.set_description(f"id: {data_id}, scene_id: {now_scene_id}, timestamp: {data['timestamp']}") - - pc0 = data['pc0'] - gm0 = data['gm0'] - pose0 = data['pose0'] - pose1 = data['pose1'] - ego_pose = np.linalg.inv(pose1) @ pose0 - - pose_flow = pc0[:, :3] @ ego_pose[:3, :3].T + ego_pose[:3, 3] - pc0[:, :3] - - pcd_list = [] - for mode in res_name: - pcd = o3d.geometry.PointCloud() - if mode in ['dufo', 'label']: - labels = data[mode] - pcd_i = o3d.geometry.PointCloud() - for label_i in np.unique(labels): - pcd_i.points = o3d.utility.Vector3dVector(pc0[labels == label_i][:, :3]) - if label_i <= 0: - pcd_i.paint_uniform_color([1.0, 1.0, 1.0]) - else: - pcd_i.paint_uniform_color(color_map[label_i % len(color_map)]) - pcd += pcd_i - elif mode in data: - pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) - flow = data[mode] - pose_flow # ego motion compensation here. 
- flow_color = flow_to_rgb(flow) / 255.0 - is_dynamic = np.linalg.norm(flow, axis=1) > 0.1 - flow_color[~is_dynamic] = pcd_color - flow_color[gm0] = pcd_color - pcd.colors = o3d.utility.Vector3dVector(flow_color) - pcd_list.append([pcd, create_bev_square(), - create_bev_square(size=204.8, color=[195/255,86/255,89/255]), - o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) - o3d_vis.update(pcd_list) - - data_id += o3d_vis.playback_direction - pbar.update(o3d_vis.playback_direction) - + pcd_i.points = o3d.utility.Vector3dVector(pc[labels == label_i][:, :3]) + if label_i <= 0: + pcd_i.paint_uniform_color(NO_COLOR) + else: + pcd_i.paint_uniform_color(color_map[label_i % len(color_map)]) + pcd += pcd_i + return pcd + + def _create_flow_lines(self, source_pts, flow, color=[1, 0, 0]): + """Create line set for flow visualization.""" + line_set = o3d.geometry.LineSet() + target_pts = source_pts + flow + line_set_points = np.concatenate([source_pts, target_pts], axis=0) + lines = np.array([[i, i + len(source_pts)] for i in range(len(source_pts))]) + line_set.points = o3d.utility.Vector3dVector(line_set_points) + line_set.lines = o3d.utility.Vector2iVector(lines) + line_set.paint_uniform_color(color) + return line_set if __name__ == '__main__': start_time = time.time() - # fire.Fire(check_flow) - fire.Fire(vis) - fire.Fire(vis_multiple) + fire.Fire(SceneFlowVisualizer) print(f"Time used: {time.time() - start_time:.2f} s") \ No newline at end of file From cff7ce80937b13502e49d5f17d4661112389a701 Mon Sep 17 00:00:00 2001 From: Kin Date: Mon, 12 Jan 2026 12:31:39 +0100 Subject: [PATCH 2/2] fix(flow): add index_flow for 2hz gt view etc. --- src/dataset.py | 5 +++-- tools/visualization.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/dataset.py b/src/dataset.py index 9d431b9..13cbc87 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -187,7 +187,7 @@ class HDF5Dataset(Dataset): def __init__(self, directory, \ transform=None, n_frames=2, ssl_label=None, \ eval = False, leaderboard_version=1, \ - vis_name=''): + vis_name='', index_flow=False): ''' Args: directory: the directory of the dataset, the folder should contain some .h5 file and index_total.pkl. @@ -199,6 +199,7 @@ def __init__(self, directory, \ * eval: if True, use the eval index (only used it for leaderboard evaluation) * leaderboard_version: 1st or 2nd, default is 1. If '2', we will use the index_eval_v2.pkl from assets/docs. * vis_name: the data of the visualization, default is ''. + * index_flow: if True, use the flow index for training or visualization. ''' super(HDF5Dataset, self).__init__() self.directory = directory @@ -247,7 +248,7 @@ def __init__(self, directory, \ # for some dataset that annotated HZ is different.... like truckscene and nuscene etc. self.train_index = None - if not eval and ssl_label is None and transform is not None: # transform indicates whether we are in training mode. + if (not eval and ssl_label is None and transform is not None) or index_flow: # transform indicates whether we are in training mode. # check if train seq all have gt. 
one_scene_id = list(self.scene_id_bounds.keys())[0] check_flow_exist = True diff --git a/tools/visualization.py b/tools/visualization.py index 98b1ad6..859259b 100644 --- a/tools/visualization.py +++ b/tools/visualization.py @@ -103,7 +103,7 @@ def __init__( def _load_dataset(self, vis_name=None, n_frames=2): """Load HDF5 dataset.""" vis_name = vis_name or self.res_names - return HDF5Dataset(self.data_dir, vis_name=vis_name, n_frames=n_frames) + return HDF5Dataset(self.data_dir, vis_name=vis_name, n_frames=n_frames, index_flow='flow' in vis_name) def _create_visualizer(self, res_name=None): """Create O3DVisualizer instance.""" @@ -289,7 +289,8 @@ def vector(self): # Blue: pc1 pcd1 = o3d.geometry.PointCloud() - pcd1.points = o3d.utility.Vector3dVector(data['pc1'][:, :3][~data['gm1']]) + gm1 = self._filter_ground_and_distance(data['pc1'], data['gm1']) + pcd1.points = o3d.utility.Vector3dVector(data['pc1'][:, :3][~gm1]) pcd1.paint_uniform_color([0.0, 0.0, 1]) # Red: flow vectors
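
Editor's note: for reviewers who want to exercise the unified visualizer outside the Fire CLI, here is a minimal sketch of driving `O3DVisualizer` directly. The class, its constructor arguments, and the `update()` semantics are taken from `src/utils/o3d_view.py` as added in this patch; the demo point cloud comes from Open3D's bundled sample, and the window labels and folder are placeholders.

```python
# Minimal sketch (not part of this patch): two windows fed with the same demo cloud.
# Requires the repo root on PYTHONPATH; all dataset-specific handling is omitted.
import open3d as o3d
from src.utils.o3d_view import O3DVisualizer

pcd = o3d.io.read_point_cloud(o3d.data.PLYPointCloud().path)
axes = o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)

# res_name entries here only label the windows and the screenshot files.
viz = O3DVisualizer(res_name=["left", "right"], save_folder="logs/imgs", point_size=2.0)

# One asset list per window; update() blocks until SPACE/D/A is pressed,
# so this loop steps frame-by-frame like the Fire CLI modes do.
for frame_idx in range(3):
    viz.update([[pcd, axes], [pcd, axes]], index=frame_idx)
```

Pressing `S` in either window copies that window's viewpoint to the other, and `P` saves one screenshot per window (plus the view JSON of window 0) into `save_folder`, matching the behavior described in the commit message.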