187 changes: 118 additions & 69 deletions colabs/trajan_demo.ipynb
@@ -27,44 +27,43 @@
"id": "zuYtT7GpuVgQ"
},
"source": [
"\u003cp align=\"center\"\u003e\n",
" \u003ch1 align=\"center\"\u003eTRAJAN: Direct Motion Models for Assessing Generated Videos\u003c/h1\u003e\n",
" \u003cp align=\"center\"\u003e\n",
" \u003ca href=\"https://k-r-allen.github.io/\"\u003eKelsey Allen*\u003c/a\u003e\n",
"<p align=\"center\">\n",
" <h1 align=\"center\">TRAJAN: Direct Motion Models for Assessing Generated Videos</h1>\n",
" <p align=\"center\">\n",
" <a href=\"https://k-r-allen.github.io/\">Kelsey Allen*</a>\n",
" ·\n",
" \u003ca href=\"http://www.carldoersch.com/\"\u003eCarl Doersch\u003c/a\u003e\n",
" <a href=\"http://www.carldoersch.com/\">Carl Doersch</a>\n",
" ·\n",
" \u003ca href=\"https://stanniszhou.github.io/\"\u003eGuangyao Zhou\u003c/a\u003e\n",
" <a href=\"https://stanniszhou.github.io/\">Guangyao Zhou</a>\n",
" ·\n",
" \u003ca href=\"https://mohammedsuhail.net/\"\u003eMohammed Suhail\u003c/a\u003e\n",
" <a href=\"https://mohammedsuhail.net/\">Mohammed Suhail</a>\n",
" ·\n",
" \u003ca href=\"https://dannydriess.github.io/\"\u003eDanny Driess\u003c/a\u003e\n",
" <a href=\"https://dannydriess.github.io/\">Danny Driess</a>\n",
" ·\n",
" \u003ca href=\"https://www.irocco.info/\"\u003eIgnacio Rocco\u003c/a\u003e\n",
" <a href=\"https://www.irocco.info/\">Ignacio Rocco</a>\n",
" ·\n",
" \u003ca href=\"https://yuliarubanova.github.io/\"\u003eYulia Rubanova\u003c/a\u003e\n",
" <a href=\"https://yuliarubanova.github.io/\">Yulia Rubanova</a>\n",
" ·\n",
" \u003ca href=\"https://tkipf.github.io/\"\u003eThomas Kipf\u003c/a\u003e\n",
" <a href=\"https://tkipf.github.io/\">Thomas Kipf</a>\n",
" ·\n",
" \u003ca href=\"https://msajjadi.com/\"\u003eMehdi S. M. Sajjadi\u003c/a\u003e\n",
" <a href=\"https://msajjadi.com/\">Mehdi S. M. Sajjadi</a>\n",
" ·\n",
" \u003ca href=\"https://scholar.google.com/citations?user=MxxZkEcAAAAJ\u0026hl=en\"\u003eKevin Murphy\u003c/a\u003e\n",
" <a href=\"https://scholar.google.com/citations?user=MxxZkEcAAAAJ&hl=en\">Kevin Murphy</a>\n",
" ·\n",
" \u003ca href=\"https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ\"\u003eJoao Carreira\u003c/a\u003e\n",
" <a href=\"https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ\">Joao Carreira</a>\n",
" ·\n",
" \u003ca href=\"https://www.sjoerdvansteenkiste.com/\"\u003eSjoerd van Steenkiste*\u003c/a\u003e\n",
" \u003c/p\u003e\n",
" \u003ch3 align=\"center\"\u003e\u003ca href=\"\"\u003ePaper\u003c/a\u003e | \u003ca href=\"https://trajan-paper.github.io\"\u003eProject Page\u003c/a\u003e | \u003ca href=\"https://github.com/deepmind/tapnet\"\u003eGitHub\u003c/a\u003e \u003c/h3\u003e\n",
" \u003cdiv align=\"center\"\u003e\u003c/div\u003e\n",
"\u003c/p\u003e\n",
"\n",
"\u003cp align=\"center\"\u003e\n",
" \u003ca href=\"\"\u003e\n",
" \u003cimg src=\"https://storage.googleapis.com/dm-tapnet/swaying_gif.gif\" alt=\"Logo\" width=\"50%\"\u003e\n",
" \u003c/a\u003e\n",
"\u003c/p\u003e\n",
"\n",
""
" <a href=\"https://www.sjoerdvansteenkiste.com/\">Sjoerd van Steenkiste*</a>\n",
" </p>\n",
" <h3 align=\"center\"><a href=\"\">Paper</a> | <a href=\"https://trajan-paper.github.io\">Project Page</a> | <a href=\"https://github.com/deepmind/tapnet\">GitHub</a> </h3>\n",
" <div align=\"center\"></div>\n",
"</p>\n",
"\n",
"<p align=\"center\">\n",
" <a href=\"\">\n",
" <img src=\"https://storage.googleapis.com/dm-tapnet/swaying_gif.gif\" alt=\"Logo\" width=\"50%\">\n",
" </a>\n",
"</p>\n",
"\n"
]
},
{
@@ -407,7 +406,7 @@
" \"\"\"\n",
"\n",
" # note that we do not use query points in the encoding, so it is expected\n",
" # that num_support_tracks \u003e\u003e num_target_tracks\n",
" # that num_support_tracks >> num_target_tracks\n",
"\n",
" num_support_tracks: int\n",
" num_target_tracks: int\n",
@@ -436,7 +435,7 @@
" # pad to 'episode_length' frames\n",
" if self.before_boundary:\n",
" # if input video is longer than episode_length, crop to episode_length\n",
" if self.episode_length - visibles.shape[0] \u003c 0:\n",
" if self.episode_length - visibles.shape[0] < 0:\n",
" visibles = visibles[: self.episode_length]\n",
" tracks_xy = tracks_xy[: self.episode_length]\n",
"\n",
@@ -457,7 +456,7 @@
" np.random.shuffle(idx)\n",
"\n",
" assert (\n",
" num_input_tracks \u003e= self.num_support_tracks + self.num_target_tracks\n",
" num_input_tracks >= self.num_support_tracks + self.num_target_tracks\n",
" ), (\n",
" (\n",
" f\"num_input_tracks {num_input_tracks} must be greater than\"\n",
@@ -484,34 +483,58 @@
" np.transpose(support_tracks_visible, [1, 0]), axis=-1\n",
" )\n",
"\n",
" # [time, point_id, x/y] -\u003e [point_id, time, x/y]\n",
" # [time, point_id, x/y] -> [point_id, time, x/y]\n",
" target_tracks = np.transpose(target_tracks, [1, 0, 2])\n",
" target_tracks_visible = np.transpose(target_tracks_visible, [1, 0])\n",
"\n",
" # Sample query points as random visible points\n",
" num_target_tracks = target_tracks_visible.shape[0]\n",
" num_frames = target_tracks_visible.shape[1]\n",
" random_frame = np.zeros(num_target_tracks, dtype=np.int64)\n",
" target_queries = self.sample_query_from_targets(\n",
" num_target_tracks, target_tracks, target_tracks_visible)\n",
"\n",
" for i in range(num_target_tracks):\n",
" visible_indices = np.where(target_tracks_visible[i] \u003e 0)[0]\n",
" if len(visible_indices) \u003e 0:\n",
" # Choose a random frame index from the visible ones\n",
" random_frame[i] = np.random.choice(visible_indices)\n",
" else:\n",
" # If no frame is visible for a track, default to frame 0\n",
" # (or handle as appropriate for your use case)\n",
" random_frame[i] = 0\n",
" # Add channel dimension to target_tracks_visible\n",
" target_tracks_visible = np.expand_dims(target_tracks_visible, axis=-1)\n",
"\n",
" # Create one-hot encoding based on the randomly selected frame for each track\n",
" idx = np.eye(num_frames, dtype=np.float32)[\n",
" random_frame\n",
" ] # [num_target_tracks, num_frames]\n",
" # Updates `features` to contain these *new* features and add batch dim.\n",
" features_new = {\n",
" \"support_tracks\": support_tracks[None, :],\n",
" \"support_tracks_visible\": support_tracks_visible[None, :],\n",
" \"query_points\": target_queries[None, :],\n",
" \"target_points\": target_tracks[None, :],\n",
" \"boundary_frame\": np.array([boundary_frame]),\n",
" \"target_tracks_visible\": target_tracks_visible[None, :],\n",
" }\n",
" features.update(features_new)\n",
" return features\n",
" \n",
" def sample_query_from_targets(\n",
" self,\n",
" num_query_tracks: int,\n",
" target_tracks: np.ndarray,\n",
" target_tracks_visible: np.ndarray,\n",
" ) -> np.ndarray:\n",
" \"\"\"Samples query points from target tracks.\"\"\"\n",
" random_frame = np.zeros(num_query_tracks, dtype=np.int64)\n",
" num_frames = target_tracks_visible.shape[1]\n",
" for i in range(num_query_tracks):\n",
" visible_indices = np.where(target_tracks_visible[i] > 0)[0]\n",
" if len(visible_indices) > 0:\n",
" # Choose a random frame index from the visible ones\n",
" random_frame[i] = np.random.choice(visible_indices)\n",
" else:\n",
" # If no frame is visible for a track, default to frame 0\n",
" # (or handle as appropriate for your use case)\n",
" random_frame[i] = 0\n",
" \n",
" # Create one-hot encoding based on the randomly selected frame for each track\n",
" idx = np.eye(num_frames, dtype=np.float32)[\n",
" random_frame\n",
" ] # [num_query_tracks, num_frames]\n",
"\n",
" # Use the one-hot index to select the coordinates at the chosen frame\n",
" target_queries_xy = np.sum(\n",
" target_tracks * idx[..., np.newaxis], axis=1\n",
" ) # [num_target_tracks, 2]\n",
" ) # [num_query_tracks, 2]\n",
"\n",
" # Stack frame index and coordinates: [t, x, y]\n",
" target_queries = np.stack(\n",
@@ -521,22 +544,8 @@
" target_queries_xy[..., 1],\n",
" ],\n",
" axis=-1,\n",
" ) # [num_target_tracks, 3]\n",
"\n",
" # Add channel dimension to target_tracks_visible\n",
" target_tracks_visible = np.expand_dims(target_tracks_visible, axis=-1)\n",
"\n",
" # Updates `features` to contain these *new* features and add batch dim.\n",
" features_new = {\n",
" \"support_tracks\": support_tracks[None, :],\n",
" \"support_tracks_visible\": support_tracks_visible[None, :],\n",
" \"query_points\": target_queries[None, :],\n",
" \"target_points\": target_tracks[None, :],\n",
" \"boundary_frame\": np.array([boundary_frame]),\n",
" \"target_tracks_visible\": target_tracks_visible[None, :],\n",
" }\n",
" features.update(features_new)\n",
" return features"
" ) # [num_query_tracks, 3]\n",
" return target_queries"
]
},
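The one-hot construction in `sample_query_from_targets` above is equivalent to a direct gather along the frame axis; the sketch below (with made-up shapes) checks that both forms agree:

```python
import numpy as np

# Made-up shapes: 8 target tracks, 16 frames.
num_tracks, num_frames = 8, 16
target_tracks = np.random.rand(num_tracks, num_frames, 2)
random_frame = np.random.randint(0, num_frames, size=num_tracks)

# One-hot form, as in the notebook cell above.
one_hot = np.eye(num_frames, dtype=np.float32)[random_frame]
queries_onehot = np.sum(target_tracks * one_hot[..., np.newaxis], axis=1)

# Equivalent direct gather: pick each track's coordinates at its chosen frame.
queries_gather = target_tracks[np.arange(num_tracks), random_frame]
assert np.allclose(queries_onehot, queries_gather)
```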
{
@@ -551,7 +560,10 @@
"# @title Run Model\n",
"\n",
"# Create model and define forward pass.\n",
"model = track_autoencoder.TrackAutoEncoder(decoder_scan_chunk_size=32)\n",
"model = track_autoencoder.TrackAutoEncoder(\n",
" decoder_scan_chunk_size=32, # If passing large queries\n",
"# decoder_scan_chunk_size=None, # If passing arbitrary small queries\n",
")\n",
"\n",
"@jax.jit\n",
"def forward(params, inputs):\n",
@@ -573,9 +585,9 @@
" transforms.convert_grid_coordinates(\n",
" tracks + 0.5, (width, height), (1, 1)\n",
" ),\n",
" \"q t c -\u003e t q c\",\n",
" \"q t c -> t q c\",\n",
" ),\n",
" \"visible\": einops.rearrange(visibles, \"q t -\u003e t q\"),\n",
" \"visible\": einops.rearrange(visibles, \"q t -> t q\"),\n",
"}\n",
"\n",
"batch = preprocessor.random_map(batch)\n",
@@ -585,6 +597,42 @@
"outputs = forward(params, batch)"
]
},
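The `decoder_scan_chunk_size` toggle above trades flexibility for memory: a fixed chunk size lets the decoder scan over a large query set in fixed-size pieces, while `None` accepts arbitrarily small query batches. A minimal sketch of the chunked-scan pattern (not the `TrackAutoEncoder` implementation; `decode_fn` and the divisibility assumption are ours):

```python
import jax
import jax.numpy as jnp

def decode_in_chunks(decode_fn, queries, chunk_size):
  # Assumes queries.shape[0] is a multiple of chunk_size; a real model
  # would pad otherwise.
  num_chunks = queries.shape[0] // chunk_size
  chunks = queries.reshape(num_chunks, chunk_size, queries.shape[-1])

  def step(carry, chunk):
    # Each scan step decodes one chunk, so peak memory is bounded by
    # chunk_size rather than by the full number of queries.
    return carry, decode_fn(chunk)

  _, out = jax.lax.scan(step, None, chunks)
  return out.reshape(num_chunks * chunk_size, -1)

# Example with a stand-in decoder:
result = decode_in_chunks(lambda q: q * 2.0, jnp.ones((64, 3)), chunk_size=32)
```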
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# @title Run Model on Custom Query Points\n",
"\n",
"# Create model and define forward pass.\n",
"model = track_autoencoder.TrackAutoEncoder(\n",
"# decoder_scan_chunk_size=32, # If passing large queries\n",
" decoder_scan_chunk_size=None, # If passing arbitrary small queries\n",
")\n",
"\n",
"@jax.jit\n",
"def forward(params, inputs):\n",
" outputs = model.apply({'params': params}, inputs)\n",
" return outputs\n",
"\n",
"# [Optional] Define custom query points [t, x, y], where t is the frame\n",
"# index, and x and y are normalized coordinates e.g. [12., 0.01, 0.9]\n",
"# Comment this out to use the query points sampled from the target tracks.\n",
"query_pts = np.array(\n",
" [\n",
" [10., 0.5, 0.5], # center of the image\n",
" [10., 0.05, 0.05], # top left corner\n",
" [10., 0.95, 0.95], # bottom right corner\n",
" [10., 0.05, 0.95], # bottom left corner\n",
" ]\n",
") # [num_query_points, 3]\n",
"batch['query_points'] = query_pts[None, :]\n",
"\n",
"# Run forward pass\n",
"outputs = forward(params, batch)"
]
},
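The custom query points above are given directly in normalized coordinates. If you start from pixel locations instead, they can be mapped with the same `(pixel + 0.5) / image_size` convention the notebook applies to the tracks via `transforms.convert_grid_coordinates`. A hedged sketch, where `width`, `height`, and the pixel locations are illustrative:

```python
import numpy as np

def make_query_points(t, xy_pixels, width, height):
  # Normalize pixel centers into [0, 1], matching (tracks + 0.5) / size.
  xy = (np.asarray(xy_pixels, dtype=np.float32) + 0.5) / np.array(
      [width, height], dtype=np.float32)
  t_col = np.full((xy.shape[0], 1), float(t), dtype=np.float32)
  return np.concatenate([t_col, xy], axis=-1)  # [num_query_points, 3]

query_pts = make_query_points(10, [[128, 96], [12, 12]], width=256, height=192)
```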
{
"cell_type": "code",
"execution_count": null,
@@ -612,12 +660,12 @@
" outputs.visible_logits, outputs.certain_logits\n",
")\n",
"\n",
"# NOTE: uncomment the lines below to also visualize the support \u0026 target tracks.\n",
"# NOTE: uncomment the lines below to also visualize the support & target tracks.\n",
"video_length = video.shape[0]\n",
"\n",
"# video_viz = viz_utils.paint_point_track(\n",
"# video,\n",
"# support_tracks_vis[:, : video.shape[1]],\n",
"# support_tracks_vis[:, : video.shape[0]],\n",
"# batch['support_tracks_visible'][0, :, :video_length],\n",
"# )\n",
"# media.show_video(video_viz, fps=10)\n",
@@ -632,7 +680,8 @@
"video_viz = viz_utils.paint_point_track(\n",
" video,\n",
" reconstructed_tracks[:, :video_length],\n",
" batch['target_tracks_visible'][0, :, :video_length],\n",
" reconstructed_visibles[0, :, :video_length],\n",
" # np.ones_like(reconstructed_visibles[0, :, :video_length]),\n",
")\n",
"media.show_video(video_viz, fps=10)"
]
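`transforms.postprocess_occlusions` combines the visibility and certainty logits into the boolean flags used for painting. One plausible reading, offered as an assumption rather than the library's exact rule, is to require both to clear a sigmoid threshold:

```python
import jax

def naive_postprocess(visible_logits, certain_logits, threshold=0.5):
  # Hypothetical rule: a point counts as visible only when the model is
  # both confident it is visible and confident in its own prediction.
  visible = jax.nn.sigmoid(visible_logits) > threshold
  certain = jax.nn.sigmoid(certain_logits) > threshold
  return visible & certain
```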
6 changes: 3 additions & 3 deletions tapnet/trajan/track_autoencoder.py
@@ -72,9 +72,9 @@ class TrackAutoEncoderInputs(TypedDict):
"""Track autoencoder inputs.

Attributes:
query_points: The (t, x, y) locations of a set of query points on initial
frame. The decoder predicts the location and visibility of these query
points for T frames into the future.
query_points: The (t, x, y) locations of a set of query points on the
chosen frames. The decoder predicts the location and visibility of these
points for the remaining frames.
boundary_frame: Int specifying the first frame of any padding in the support
tracks. Track values starting on this frame will be masked out of the
attention for transformers.
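Based on the features dictionary assembled in the notebook's preprocessor, a minimal batch matching this TypedDict might look as follows; all shapes here are illustrative assumptions, not values fixed by the module:

```python
import numpy as np

# Illustrative shapes: batch 1, 448 support tracks, 4 queries, 24 frames.
T, S, Q = 24, 448, 4
batch = {
    "support_tracks": np.zeros((1, S, T, 2), np.float32),
    "support_tracks_visible": np.zeros((1, S, T, 1), np.float32),
    "query_points": np.zeros((1, Q, 3), np.float32),  # (t, x, y) per query
    "target_points": np.zeros((1, Q, T, 2), np.float32),
    "target_tracks_visible": np.zeros((1, Q, T, 1), np.float32),
    "boundary_frame": np.array([T]),  # first padded frame, if any
}
```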