diff --git a/colabs/trajan_demo.ipynb b/colabs/trajan_demo.ipynb
index 2b55e84..466a923 100644
--- a/colabs/trajan_demo.ipynb
+++ b/colabs/trajan_demo.ipynb
@@ -27,44 +27,43 @@
     "id": "zuYtT7GpuVgQ"
    },
    "source": [
-    "\u003cp align=\"center\"\u003e\n",
-    "  \u003ch1 align=\"center\"\u003eTRAJAN: Direct Motion Models for Assessing Generated Videos\u003c/h1\u003e\n",
-    "  \u003cp align=\"center\"\u003e\n",
-    "    \u003ca href=\"https://k-r-allen.github.io/\"\u003eKelsey Allen*\u003c/a\u003e\n",
+    "<p align=\"center\">\n",
+    "  <h1 align=\"center\">TRAJAN: Direct Motion Models for Assessing Generated Videos</h1>\n",
+    "  <p align=\"center\">\n",
+    "    <a href=\"https://k-r-allen.github.io/\">Kelsey Allen*</a>\n",
     "    ·\n",
-    "    \u003ca href=\"http://www.carldoersch.com/\"\u003eCarl Doersch\u003c/a\u003e\n",
+    "    <a href=\"http://www.carldoersch.com/\">Carl Doersch</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://stanniszhou.github.io/\"\u003eGuangyao Zhou\u003c/a\u003e\n",
+    "    <a href=\"https://stanniszhou.github.io/\">Guangyao Zhou</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://mohammedsuhail.net/\"\u003eMohammed Suhail\u003c/a\u003e\n",
+    "    <a href=\"https://mohammedsuhail.net/\">Mohammed Suhail</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://dannydriess.github.io/\"\u003eDanny Driess\u003c/a\u003e\n",
+    "    <a href=\"https://dannydriess.github.io/\">Danny Driess</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://www.irocco.info/\"\u003eIgnacio Rocco\u003c/a\u003e\n",
+    "    <a href=\"https://www.irocco.info/\">Ignacio Rocco</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://yuliarubanova.github.io/\"\u003eYulia Rubanova\u003c/a\u003e\n",
+    "    <a href=\"https://yuliarubanova.github.io/\">Yulia Rubanova</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://tkipf.github.io/\"\u003eThomas Kipf\u003c/a\u003e\n",
+    "    <a href=\"https://tkipf.github.io/\">Thomas Kipf</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://msajjadi.com/\"\u003eMehdi S. M. Sajjadi\u003c/a\u003e\n",
+    "    <a href=\"https://msajjadi.com/\">Mehdi S. M. Sajjadi</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://scholar.google.com/citations?user=MxxZkEcAAAAJ\u0026hl=en\"\u003eKevin Murphy\u003c/a\u003e\n",
+    "    <a href=\"https://scholar.google.com/citations?user=MxxZkEcAAAAJ&hl=en\">Kevin Murphy</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ\"\u003eJoao Carreira\u003c/a\u003e\n",
+    "    <a href=\"https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ\">Joao Carreira</a>\n",
     "    ·\n",
-    "    \u003ca href=\"https://www.sjoerdvansteenkiste.com/\"\u003eSjoerd van Steenkiste*\u003c/a\u003e\n",
-    "  \u003c/p\u003e\n",
-    "  \u003ch3 align=\"center\"\u003e\u003ca href=\"\"\u003ePaper\u003c/a\u003e | \u003ca href=\"https://trajan-paper.github.io\"\u003eProject Page\u003c/a\u003e | \u003ca href=\"https://github.com/deepmind/tapnet\"\u003eGitHub\u003c/a\u003e \u003c/h3\u003e\n",
-    "  \u003cdiv align=\"center\"\u003e\u003c/div\u003e\n",
-    "\u003c/p\u003e\n",
-    "\n",
-    "\u003cp align=\"center\"\u003e\n",
-    "  \u003ca href=\"\"\u003e\n",
-    "    \u003cimg src=\"https://storage.googleapis.com/dm-tapnet/swaying_gif.gif\" alt=\"Logo\" width=\"50%\"\u003e\n",
-    "  \u003c/a\u003e\n",
-    "\u003c/p\u003e\n",
-    "\n",
-    ""
+    "    <a href=\"https://www.sjoerdvansteenkiste.com/\">Sjoerd van Steenkiste*</a>\n",
+    "  </p>\n",
+    "  <h3 align=\"center\"><a href=\"\">Paper</a> | <a href=\"https://trajan-paper.github.io\">Project Page</a> | <a href=\"https://github.com/deepmind/tapnet\">GitHub</a> </h3>\n",
+    "  <div align=\"center\"></div>\n",
+    "</p>\n",
+    "\n",
+    "<p align=\"center\">\n",
+    "  <a href=\"\">\n",
+    "    <img src=\"https://storage.googleapis.com/dm-tapnet/swaying_gif.gif\" alt=\"Logo\" width=\"50%\">\n",
+    "  </a>\n",
\n", + "\n" ] }, { @@ -407,7 +406,7 @@ " \"\"\"\n", "\n", " # note that we do not use query points in the encoding, so it is expected\n", - " # that num_support_tracks \u003e\u003e num_target_tracks\n", + " # that num_support_tracks >> num_target_tracks\n", "\n", " num_support_tracks: int\n", " num_target_tracks: int\n", @@ -436,7 +435,7 @@ " # pad to 'episode_length' frames\n", " if self.before_boundary:\n", " # if input video is longer than episode_length, crop to episode_length\n", - " if self.episode_length - visibles.shape[0] \u003c 0:\n", + " if self.episode_length - visibles.shape[0] < 0:\n", " visibles = visibles[: self.episode_length]\n", " tracks_xy = tracks_xy[: self.episode_length]\n", "\n", @@ -457,7 +456,7 @@ " np.random.shuffle(idx)\n", "\n", " assert (\n", - " num_input_tracks \u003e= self.num_support_tracks + self.num_target_tracks\n", + " num_input_tracks >= self.num_support_tracks + self.num_target_tracks\n", " ), (\n", " (\n", " f\"num_input_tracks {num_input_tracks} must be greater than\"\n", @@ -484,34 +483,58 @@ " np.transpose(support_tracks_visible, [1, 0]), axis=-1\n", " )\n", "\n", - " # [time, point_id, x/y] -\u003e [point_id, time, x/y]\n", + " # [time, point_id, x/y] -> [point_id, time, x/y]\n", " target_tracks = np.transpose(target_tracks, [1, 0, 2])\n", " target_tracks_visible = np.transpose(target_tracks_visible, [1, 0])\n", "\n", " # Sample query points as random visible points\n", " num_target_tracks = target_tracks_visible.shape[0]\n", - " num_frames = target_tracks_visible.shape[1]\n", - " random_frame = np.zeros(num_target_tracks, dtype=np.int64)\n", + " target_queries = self.sample_query_from_targets(\n", + " num_target_tracks, target_tracks, target_tracks_visible)\n", "\n", - " for i in range(num_target_tracks):\n", - " visible_indices = np.where(target_tracks_visible[i] \u003e 0)[0]\n", - " if len(visible_indices) \u003e 0:\n", - " # Choose a random frame index from the visible ones\n", - " random_frame[i] = np.random.choice(visible_indices)\n", - " else:\n", - " # If no frame is visible for a track, default to frame 0\n", - " # (or handle as appropriate for your use case)\n", - " random_frame[i] = 0\n", + " # Add channel dimension to target_tracks_visible\n", + " target_tracks_visible = np.expand_dims(target_tracks_visible, axis=-1)\n", "\n", - " # Create one-hot encoding based on the randomly selected frame for each track\n", - " idx = np.eye(num_frames, dtype=np.float32)[\n", - " random_frame\n", - " ] # [num_target_tracks, num_frames]\n", + " # Updates `features` to contain these *new* features and add batch dim.\n", + " features_new = {\n", + " \"support_tracks\": support_tracks[None, :],\n", + " \"support_tracks_visible\": support_tracks_visible[None, :],\n", + " \"query_points\": target_queries[None, :],\n", + " \"target_points\": target_tracks[None, :],\n", + " \"boundary_frame\": np.array([boundary_frame]),\n", + " \"target_tracks_visible\": target_tracks_visible[None, :],\n", + " }\n", + " features.update(features_new)\n", + " return features\n", + " \n", + " def sample_query_from_targets(\n", + " self,\n", + " num_query_tracks: int,\n", + " target_tracks: np.ndarray,\n", + " target_tracks_visible: np.ndarray,\n", + " ) -> np.ndarray:\n", + " \"\"\"Samples query points from target tracks.\"\"\"\n", + " random_frame = np.zeros(num_query_tracks, dtype=np.int64)\n", + " num_frames = target_tracks_visible.shape[1]\n", + " for i in range(num_query_tracks):\n", + " visible_indices = np.where(target_tracks_visible[i] > 0)[0]\n", + " if 
+    "      if len(visible_indices) > 0:\n",
+    "        # Choose a random frame index from the visible ones\n",
+    "        random_frame[i] = np.random.choice(visible_indices)\n",
+    "      else:\n",
+    "        # If no frame is visible for a track, default to frame 0\n",
+    "        # (or handle as appropriate for your use case)\n",
+    "        random_frame[i] = 0\n",
+    "\n",
+    "    # Create one-hot encoding based on the randomly selected frame for each track\n",
+    "    idx = np.eye(num_frames, dtype=np.float32)[\n",
+    "        random_frame\n",
+    "    ] # [num_query_tracks, num_frames]\n",
     "\n",
     "    # Use the one-hot index to select the coordinates at the chosen frame\n",
     "    target_queries_xy = np.sum(\n",
     "        target_tracks * idx[..., np.newaxis], axis=1\n",
-    "    ) # [num_target_tracks, 2]\n",
+    "    ) # [num_query_tracks, 2]\n",
     "\n",
     "    # Stack frame index and coordinates: [t, x, y]\n",
     "    target_queries = np.stack(\n",
@@ -521,22 +544,8 @@
     "            target_queries_xy[..., 1],\n",
     "        ],\n",
     "        axis=-1,\n",
-    "    ) # [num_target_tracks, 3]\n",
-    "\n",
-    "    # Add channel dimension to target_tracks_visible\n",
-    "    target_tracks_visible = np.expand_dims(target_tracks_visible, axis=-1)\n",
-    "\n",
-    "    # Updates `features` to contain these *new* features and add batch dim.\n",
-    "    features_new = {\n",
-    "        \"support_tracks\": support_tracks[None, :],\n",
-    "        \"support_tracks_visible\": support_tracks_visible[None, :],\n",
-    "        \"query_points\": target_queries[None, :],\n",
-    "        \"target_points\": target_tracks[None, :],\n",
-    "        \"boundary_frame\": np.array([boundary_frame]),\n",
-    "        \"target_tracks_visible\": target_tracks_visible[None, :],\n",
-    "    }\n",
-    "    features.update(features_new)\n",
-    "    return features"
+    "    ) # [num_query_tracks, 3]\n",
+    "    return target_queries"
    ]
   },
   {
@@ -551,7 +560,10 @@
     "# @title Run Model\n",
     "\n",
     "# Create model and define forward pass.\n",
-    "model = track_autoencoder.TrackAutoEncoder(decoder_scan_chunk_size=32)\n",
+    "model = track_autoencoder.TrackAutoEncoder(\n",
+    "    decoder_scan_chunk_size=32,  # If passing a large set of queries\n",
+    "# decoder_scan_chunk_size=None,  # If passing a small set of arbitrary queries\n",
+    ")\n",
     "\n",
     "@jax.jit\n",
     "def forward(params, inputs):\n",
@@ -573,9 +585,9 @@
     "        transforms.convert_grid_coordinates(\n",
     "            tracks + 0.5, (width, height), (1, 1)\n",
     "        ),\n",
-    "        \"q t c -\u003e t q c\",\n",
+    "        \"q t c -> t q c\",\n",
     "    ),\n",
-    "    \"visible\": einops.rearrange(visibles, \"q t -\u003e t q\"),\n",
+    "    \"visible\": einops.rearrange(visibles, \"q t -> t q\"),\n",
     "}\n",
     "\n",
     "batch = preprocessor.random_map(batch)\n",
@@ -585,6 +597,42 @@
     "outputs = forward(params, batch)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# @title Run Model on Custom Query Points\n",
+    "\n",
+    "# Create model and define forward pass.\n",
+    "model = track_autoencoder.TrackAutoEncoder(\n",
+    "# decoder_scan_chunk_size=32,  # If passing a large set of queries\n",
+    "    decoder_scan_chunk_size=None,  # If passing a small set of arbitrary queries\n",
+    ")\n",
+    "\n",
+    "@jax.jit\n",
+    "def forward(params, inputs):\n",
+    "  outputs = model.apply({'params': params}, inputs)\n",
+    "  return outputs\n",
+    "\n",
+    "# [Optional] Define custom query points [t, x, y], where t is the frame\n",
+    "# index, and x and y are normalized coordinates, e.g. [12., 0.01, 0.9].\n",
+    "# Comment this out to use the query points sampled from the target tracks.\n",
+    "query_pts = np.array(\n",
+    "    [\n",
+    "        [10., 0.5, 0.5],  # center of the image\n",
+    "        [10., 0.05, 0.05],  # top left corner\n",
+    "        [10., 0.95, 0.95],  # bottom right corner\n",
+    "        [10., 0.05, 0.95],  # bottom left corner\n",
+    "    ]\n",
+    ")  # [num_query_points, 3]\n",
+    "batch['query_points'] = query_pts[None, :]\n",
+    "\n",
+    "# Run forward pass\n",
+    "outputs = forward(params, batch)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -612,12 +660,12 @@
     "    outputs.visible_logits, outputs.certain_logits\n",
     ")\n",
     "\n",
-    "# NOTE: uncomment the lines below to also visualize the support \u0026 target tracks.\n",
+    "# NOTE: uncomment the lines below to also visualize the support & target tracks.\n",
     "video_length = video.shape[0]\n",
     "\n",
     "# video_viz = viz_utils.paint_point_track(\n",
     "# video,\n",
-    "# support_tracks_vis[:, : video.shape[1]],\n",
+    "# support_tracks_vis[:, : video.shape[0]],\n",
     "# batch['support_tracks_visible'][0, :, :video_length],\n",
     "# )\n",
     "# media.show_video(video_viz, fps=10)\n",
@@ -632,7 +680,8 @@
     "video_viz = viz_utils.paint_point_track(\n",
     "    video,\n",
     "    reconstructed_tracks[:, :video_length],\n",
-    "    batch['target_tracks_visible'][0, :, :video_length],\n",
+    "    reconstructed_visibles[0, :, :video_length],\n",
+    "    # np.ones_like(reconstructed_visibles[0, :, :video_length]),\n",
     ")\n",
     "media.show_video(video_viz, fps=10)"
    ]
diff --git a/tapnet/trajan/track_autoencoder.py b/tapnet/trajan/track_autoencoder.py
index 1386380..d4f1079 100644
--- a/tapnet/trajan/track_autoencoder.py
+++ b/tapnet/trajan/track_autoencoder.py
@@ -72,9 +72,9 @@ class TrackAutoEncoderInputs(TypedDict):
   """Track autoencoder inputs.

   Attributes:
-    query_points: The (t, x, y) locations of a set of query points on initial
-      frame. The decoder predicts the location and visibility of these query
-      points for T frames into the future.
+    query_points: The (t, x, y) locations of a set of query points on the
+      chosen frames. The decoder predicts the location and visibility of these
+      points for the remaining frames.
     boundary_frame: Int specifying the first frame of any padding in the
       support tracks. Track values starting on this frame will be masked out
       of the attention for transformers.
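
Not part of the patch: a minimal usage sketch of the new "Run Model on Custom Query Points" cell above, showing one way to build a denser set of custom query points. It assumes the notebook's `batch`, `params`, and `forward` from the cells in this diff; `make_query_grid` is a hypothetical helper, not something defined in the repo. Query points follow the `[t, x, y]` convention from the cell's comments, with normalized x/y coordinates, and per those comments the model is constructed with `decoder_scan_chunk_size=None` for a small, arbitrary set of queries.

```python
import numpy as np

def make_query_grid(frame_idx: float, points_per_side: int = 4) -> np.ndarray:
  """Returns [points_per_side**2, 3] query points (t, x, y), x/y in [0, 1]."""
  coords = np.linspace(0.05, 0.95, points_per_side)
  xs, ys = np.meshgrid(coords, coords)
  ts = np.full_like(xs, frame_idx)
  return np.stack([ts.ravel(), xs.ravel(), ys.ravel()], axis=-1).astype(np.float32)

# Mirrors the custom-query cell: add a batch dimension and run the forward pass.
query_pts = make_query_grid(frame_idx=10.0)  # [16, 3]
batch['query_points'] = query_pts[None, :]
outputs = forward(params, batch)
```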