diff --git a/README.md b/README.md
index c5048017..cc94463f 100644
--- a/README.md
+++ b/README.md
@@ -88,11 +88,15 @@ We offer a number of other ways to interact with CoTracker:
    [Google Colab](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb).
    - Or explore the notebook located at [`notebooks/demo.ipynb`](./notebooks/demo.ipynb).
 2. You can [install](#installation-instructions) CoTracker _locally_ and then:
-   - Run an *offline* demo with 10 ⨉ 10 points sampled on a grid on the first frame of a video (results will be saved to `./saved_videos/demo.mp4`)):
+   - Run an *offline* demo with 10 ⨉ 10 points sampled on a grid on the first frame of a video or an image folder (results will be saved to `./saved_videos/demo.mp4`):
 
      ```bash
      python demo.py --grid_size 10
      ```
+     Or select the query points interactively with mouse clicks:
+     ```bash
+     python demo.py --video_path ./your/images/or/video --checkpoint ./checkpoints/cotracker2.pth --grid_query_frame 5 --interactive_query
+     ```
    - Run an *online* demo:
 
      ```bash
diff --git a/cotracker/utils/visualizer.py b/cotracker/utils/visualizer.py
index 88287c37..19a76cdd 100644
--- a/cotracker/utils/visualizer.py
+++ b/cotracker/utils/visualizer.py
@@ -14,18 +14,84 @@
 import matplotlib.pyplot as plt
 from PIL import Image, ImageDraw
+
+# Event handler for mouse clicks: records the clicked point as a query
+def on_click(event, ax, colormap, queries, frame_idx):
+    if event.button == 1 and event.inaxes == ax:  # Left mouse button clicked
+        x, y = int(np.round(event.xdata)), int(np.round(event.ydata))
+
+        # Add the clicked point to the list of queries as (frame index, x, y)
+        queries.append([frame_idx, x, y])
+
+        # Update the plot to show the new point,
+        # mapping the index of the query to a value between 0 and 1
+        color_index = len(queries) % 20  # Cycle through the colormap
+        color = colormap(color_index / 20)  # Normalize the index
+
+        # The colormap returns RGBA, but we only need the RGB values
+        color = color[:3]  # Remove the alpha channel
+
+        ax.plot(x, y, 'o', color=color, markersize=2)
+        plt.draw()
+
+
+# Collect query points from mouse clicks on a single video frame
+def get_queries_from_clicks(frame, frame_idx=0):
+    # Initialize queries as an empty list
+    queries = []
+
+    # Convert the frame tensor to a numpy array normalized to [0, 1]
+    frame_np = frame.permute(1, 2, 0).cpu().numpy()
+    frame_np = (frame_np - frame_np.min()) / (frame_np.max() - frame_np.min())
+
+    # Display the frame and set up the event handler
+    fig, ax = plt.subplots()
+    ax.imshow(frame_np)
+    colormap = plt.cm.get_cmap('tab20')
+    cid = fig.canvas.mpl_connect(
+        'button_press_event', lambda event: on_click(event, ax, colormap, queries, frame_idx)
+    )
+
+    # Wait for user input (blocks until the window is closed)
+    plt.show()
+
+    # Disconnect the event handler
+    fig.canvas.mpl_disconnect(cid)
+
+    # Convert the list of queries to an (N, 3) tensor of (frame index, x, y)
+    queries_tensor = torch.tensor(queries)
+
+    # Move the queries to the GPU if one is available
+    if torch.cuda.is_available():
+        queries_tensor = queries_tensor.cuda()
+
+    return queries_tensor
 
 
 def read_video_from_path(path):
     try:
-        reader = imageio.get_reader(path)
+        # Check if the path is a video file
+        if os.path.isfile(path):
+            reader = imageio.get_reader(path)
+            frames = []
+            for i, im in enumerate(reader):
+                frames.append(np.array(im))
+            return np.stack(frames)
+        # Check if the path is a directory of images
+        elif os.path.isdir(path):
+            images = []
+            # Get all image files in the directory and sort them
+            filenames = sorted(os.listdir(path))
+            for filename in filenames:
+                if filename.endswith(('.png', '.jpg', '.jpeg')):
+                    img = imageio.imread(os.path.join(path, filename))
+                    images.append(img)
+            return np.stack(images)
+        else:
+            print("Error: Invalid path")
+            return None
     except Exception as e:
-        print("Error opening video file: ", e)
+        print("Error opening video file or images folder: ", e)
         return None
-    frames = []
-    for i, im in enumerate(reader):
-        frames.append(np.array(im))
-    return np.stack(frames)
-
 
 def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
     # Create a draw object
diff --git a/demo.py b/demo.py
index 2f265879..c23a41c4 100644
--- a/demo.py
+++ b/demo.py
@@ -10,7 +10,7 @@
 import numpy as np
 
 from PIL import Image
-from cotracker.utils.visualizer import Visualizer, read_video_from_path
+from cotracker.utils.visualizer import Visualizer, read_video_from_path, get_queries_from_clicks
 from cotracker.predictor import CoTrackerPredictor
 
 # Unfortunately MPS acceleration does not support all the features we require,
@@ -58,6 +58,14 @@
         help="Compute tracks in both directions, not only forward",
     )
 
+    # Flag to enable interactive selection of query points
+    parser.add_argument(
+        "--interactive_query",
+        action="store_true",
+        default=False,
+        help="Select query points interactively with mouse clicks instead of a grid",
+    )
+
     args = parser.parse_args()
 
     # load the input video frame by frame
@@ -73,8 +81,15 @@
     model = model.to(DEFAULT_DEVICE)
     video = video.to(DEFAULT_DEVICE)
     # video = video[:, :20]
+    # Determine the queries based on interactive mode
+    if args.interactive_query:
+        queries = get_queries_from_clicks(video[0][args.grid_query_frame], args.grid_query_frame).float()[None]
+    else:
+        queries = None
+
     pred_tracks, pred_visibility = model(
         video,
+        queries=queries,
         grid_size=args.grid_size,
         grid_query_frame=args.grid_query_frame,
         backward_tracking=args.backward_tracking,
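
For reference, a minimal sketch (not part of the patch) of the query format this change feeds into the tracker: each click becomes a `(frame_index, x, y)` triple, and `demo.py` adds a batch dimension before calling the model. The two clicked points below are placeholder values chosen only for illustration.

```python
import torch

# Placeholder clicks, as get_queries_from_clicks would collect them:
# one (frame_index, x, y) entry per left-click on the displayed frame.
clicked_points = [[5, 400, 350], [5, 600, 500]]

queries = torch.tensor(clicked_points).float()  # shape (N, 3)
queries = queries[None]                         # add batch dim -> (1, N, 3)
print(queries.shape)  # torch.Size([1, 2, 3])

# demo.py then passes this tensor to the predictor, e.g.:
# pred_tracks, pred_visibility = model(video, queries=queries, ...)
# pred_tracks has shape (1, T, N, 2) and pred_visibility has shape (1, T, N).
```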