From 37e3ecfaad57a7187ef038f1b6080663cfe79026 Mon Sep 17 00:00:00 2001
From: Jonathan Scholz
Date: Thu, 8 May 2025 08:46:28 -0400
Subject: [PATCH] Added example for querying trajan at user-provided points
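
The new colab cell lets the TRAJAN decoder be queried at arbitrary
user-provided points. A minimal usage sketch mirroring that cell (it assumes
`params` and `batch` have already been built by the earlier cells of the
demo, and that each query is [t, x, y] with t a frame index and x, y
normalized to [0, 1]):

    import jax
    import numpy as np
    from tapnet.trajan import track_autoencoder

    # No chunked decoding needed for a handful of query points.
    model = track_autoencoder.TrackAutoEncoder(decoder_scan_chunk_size=None)

    @jax.jit
    def forward(params, inputs):
      return model.apply({'params': params}, inputs)

    # Hand-picked query points: [frame, x, y] in normalized coordinates.
    query_pts = np.array([
        [10., 0.5, 0.5],    # center of the image
        [10., 0.05, 0.05],  # top left corner
    ])
    batch['query_points'] = query_pts[None, :]  # add a batch dimension

    outputs = forward(params, batch)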
---
colabs/trajan_demo.ipynb | 187 ++++++++++++++++++-----------
tapnet/trajan/track_autoencoder.py | 6 +-
2 files changed, 121 insertions(+), 72 deletions(-)
diff --git a/colabs/trajan_demo.ipynb b/colabs/trajan_demo.ipynb
index 2b55e84..466a923 100644
--- a/colabs/trajan_demo.ipynb
+++ b/colabs/trajan_demo.ipynb
@@ -27,44 +27,43 @@
"id": "zuYtT7GpuVgQ"
},
"source": [
- "\u003cp align=\"center\"\u003e\n",
- " \u003ch1 align=\"center\"\u003eTRAJAN: Direct Motion Models for Assessing Generated Videos\u003c/h1\u003e\n",
- " \u003cp align=\"center\"\u003e\n",
- " \u003ca href=\"https://k-r-allen.github.io/\"\u003eKelsey Allen*\u003c/a\u003e\n",
+ "\n",
+ "
TRAJAN: Direct Motion Models for Assessing Generated Videos
\n",
+ " \n",
+ " Kelsey Allen*\n",
" ·\n",
- " \u003ca href=\"http://www.carldoersch.com/\"\u003eCarl Doersch\u003c/a\u003e\n",
+ " Carl Doersch\n",
" ·\n",
- " \u003ca href=\"https://stanniszhou.github.io/\"\u003eGuangyao Zhou\u003c/a\u003e\n",
+ " Guangyao Zhou\n",
" ·\n",
- " \u003ca href=\"https://mohammedsuhail.net/\"\u003eMohammed Suhail\u003c/a\u003e\n",
+ " Mohammed Suhail\n",
" ·\n",
- " \u003ca href=\"https://dannydriess.github.io/\"\u003eDanny Driess\u003c/a\u003e\n",
+ " Danny Driess\n",
" ·\n",
- " \u003ca href=\"https://www.irocco.info/\"\u003eIgnacio Rocco\u003c/a\u003e\n",
+ " Ignacio Rocco\n",
" ·\n",
- " \u003ca href=\"https://yuliarubanova.github.io/\"\u003eYulia Rubanova\u003c/a\u003e\n",
+ " Yulia Rubanova\n",
" ·\n",
- " \u003ca href=\"https://tkipf.github.io/\"\u003eThomas Kipf\u003c/a\u003e\n",
+ " Thomas Kipf\n",
" ·\n",
- " \u003ca href=\"https://msajjadi.com/\"\u003eMehdi S. M. Sajjadi\u003c/a\u003e\n",
+ " Mehdi S. M. Sajjadi\n",
" ·\n",
- " \u003ca href=\"https://scholar.google.com/citations?user=MxxZkEcAAAAJ\u0026hl=en\"\u003eKevin Murphy\u003c/a\u003e\n",
+ " Kevin Murphy\n",
" ·\n",
- " \u003ca href=\"https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ\"\u003eJoao Carreira\u003c/a\u003e\n",
+ " Joao Carreira\n",
" ·\n",
- " \u003ca href=\"https://www.sjoerdvansteenkiste.com/\"\u003eSjoerd van Steenkiste*\u003c/a\u003e\n",
- " \u003c/p\u003e\n",
- " \u003ch3 align=\"center\"\u003e\u003ca href=\"\"\u003ePaper\u003c/a\u003e | \u003ca href=\"https://trajan-paper.github.io\"\u003eProject Page\u003c/a\u003e | \u003ca href=\"https://github.com/deepmind/tapnet\"\u003eGitHub\u003c/a\u003e \u003c/h3\u003e\n",
- " \u003cdiv align=\"center\"\u003e\u003c/div\u003e\n",
- "\u003c/p\u003e\n",
- "\n",
- "\u003cp align=\"center\"\u003e\n",
- " \u003ca href=\"\"\u003e\n",
- " \u003cimg src=\"https://storage.googleapis.com/dm-tapnet/swaying_gif.gif\" alt=\"Logo\" width=\"50%\"\u003e\n",
- " \u003c/a\u003e\n",
- "\u003c/p\u003e\n",
- "\n",
- ""
+ " Sjoerd van Steenkiste*\n",
+ "
\n",
+ " \n",
+ " \n",
+ "
\n",
+ "\n",
+ "\n",
+ " \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "\n"
]
},
{
@@ -407,7 +406,7 @@
" \"\"\"\n",
"\n",
" # note that we do not use query points in the encoding, so it is expected\n",
- " # that num_support_tracks \u003e\u003e num_target_tracks\n",
+ " # that num_support_tracks >> num_target_tracks\n",
"\n",
" num_support_tracks: int\n",
" num_target_tracks: int\n",
@@ -436,7 +435,7 @@
" # pad to 'episode_length' frames\n",
" if self.before_boundary:\n",
" # if input video is longer than episode_length, crop to episode_length\n",
- " if self.episode_length - visibles.shape[0] \u003c 0:\n",
+ " if self.episode_length - visibles.shape[0] < 0:\n",
" visibles = visibles[: self.episode_length]\n",
" tracks_xy = tracks_xy[: self.episode_length]\n",
"\n",
@@ -457,7 +456,7 @@
" np.random.shuffle(idx)\n",
"\n",
" assert (\n",
- " num_input_tracks \u003e= self.num_support_tracks + self.num_target_tracks\n",
+ " num_input_tracks >= self.num_support_tracks + self.num_target_tracks\n",
" ), (\n",
" (\n",
" f\"num_input_tracks {num_input_tracks} must be greater than\"\n",
@@ -484,34 +483,58 @@
" np.transpose(support_tracks_visible, [1, 0]), axis=-1\n",
" )\n",
"\n",
- " # [time, point_id, x/y] -\u003e [point_id, time, x/y]\n",
+ " # [time, point_id, x/y] -> [point_id, time, x/y]\n",
" target_tracks = np.transpose(target_tracks, [1, 0, 2])\n",
" target_tracks_visible = np.transpose(target_tracks_visible, [1, 0])\n",
"\n",
" # Sample query points as random visible points\n",
" num_target_tracks = target_tracks_visible.shape[0]\n",
- " num_frames = target_tracks_visible.shape[1]\n",
- " random_frame = np.zeros(num_target_tracks, dtype=np.int64)\n",
+ " target_queries = self.sample_query_from_targets(\n",
+ " num_target_tracks, target_tracks, target_tracks_visible)\n",
"\n",
- " for i in range(num_target_tracks):\n",
- " visible_indices = np.where(target_tracks_visible[i] \u003e 0)[0]\n",
- " if len(visible_indices) \u003e 0:\n",
- " # Choose a random frame index from the visible ones\n",
- " random_frame[i] = np.random.choice(visible_indices)\n",
- " else:\n",
- " # If no frame is visible for a track, default to frame 0\n",
- " # (or handle as appropriate for your use case)\n",
- " random_frame[i] = 0\n",
+ " # Add channel dimension to target_tracks_visible\n",
+ " target_tracks_visible = np.expand_dims(target_tracks_visible, axis=-1)\n",
"\n",
- " # Create one-hot encoding based on the randomly selected frame for each track\n",
- " idx = np.eye(num_frames, dtype=np.float32)[\n",
- " random_frame\n",
- " ] # [num_target_tracks, num_frames]\n",
+ " # Updates `features` to contain these *new* features and add batch dim.\n",
+ " features_new = {\n",
+ " \"support_tracks\": support_tracks[None, :],\n",
+ " \"support_tracks_visible\": support_tracks_visible[None, :],\n",
+ " \"query_points\": target_queries[None, :],\n",
+ " \"target_points\": target_tracks[None, :],\n",
+ " \"boundary_frame\": np.array([boundary_frame]),\n",
+ " \"target_tracks_visible\": target_tracks_visible[None, :],\n",
+ " }\n",
+ " features.update(features_new)\n",
+ " return features\n",
+ " \n",
+ " def sample_query_from_targets(\n",
+ " self,\n",
+ " num_query_tracks: int,\n",
+ " target_tracks: np.ndarray,\n",
+ " target_tracks_visible: np.ndarray,\n",
+ " ) -> np.ndarray:\n",
+ " \"\"\"Samples query points from target tracks.\"\"\"\n",
+ " random_frame = np.zeros(num_query_tracks, dtype=np.int64)\n",
+ " num_frames = target_tracks_visible.shape[1]\n",
+ " for i in range(num_query_tracks):\n",
+ " visible_indices = np.where(target_tracks_visible[i] > 0)[0]\n",
+ " if len(visible_indices) > 0:\n",
+ " # Choose a random frame index from the visible ones\n",
+ " random_frame[i] = np.random.choice(visible_indices)\n",
+ " else:\n",
+ " # If no frame is visible for a track, default to frame 0\n",
+ " # (or handle as appropriate for your use case)\n",
+ " random_frame[i] = 0\n",
+ " \n",
+ " # Create one-hot encoding based on the randomly selected frame for each track\n",
+ " idx = np.eye(num_frames, dtype=np.float32)[\n",
+ " random_frame\n",
+ " ] # [num_query_tracks, num_frames]\n",
"\n",
" # Use the one-hot index to select the coordinates at the chosen frame\n",
" target_queries_xy = np.sum(\n",
" target_tracks * idx[..., np.newaxis], axis=1\n",
- " ) # [num_target_tracks, 2]\n",
+ " ) # [num_query_tracks, 2]\n",
"\n",
" # Stack frame index and coordinates: [t, x, y]\n",
" target_queries = np.stack(\n",
@@ -521,22 +544,8 @@
" target_queries_xy[..., 1],\n",
" ],\n",
" axis=-1,\n",
- " ) # [num_target_tracks, 3]\n",
- "\n",
- " # Add channel dimension to target_tracks_visible\n",
- " target_tracks_visible = np.expand_dims(target_tracks_visible, axis=-1)\n",
- "\n",
- " # Updates `features` to contain these *new* features and add batch dim.\n",
- " features_new = {\n",
- " \"support_tracks\": support_tracks[None, :],\n",
- " \"support_tracks_visible\": support_tracks_visible[None, :],\n",
- " \"query_points\": target_queries[None, :],\n",
- " \"target_points\": target_tracks[None, :],\n",
- " \"boundary_frame\": np.array([boundary_frame]),\n",
- " \"target_tracks_visible\": target_tracks_visible[None, :],\n",
- " }\n",
- " features.update(features_new)\n",
- " return features"
+ " ) # [num_query_tracks, 3]\n",
+ " return target_queries"
]
},
{
@@ -551,7 +560,10 @@
"# @title Run Model\n",
"\n",
"# Create model and define forward pass.\n",
- "model = track_autoencoder.TrackAutoEncoder(decoder_scan_chunk_size=32)\n",
+ "model = track_autoencoder.TrackAutoEncoder(\n",
+ " decoder_scan_chunk_size=32, # If passing large queries\n",
+ "# decoder_scan_chunk_size=None, # If passing arbitrary small queries\n",
+ ")\n",
"\n",
"@jax.jit\n",
"def forward(params, inputs):\n",
@@ -573,9 +585,9 @@
" transforms.convert_grid_coordinates(\n",
" tracks + 0.5, (width, height), (1, 1)\n",
" ),\n",
- " \"q t c -\u003e t q c\",\n",
+ " \"q t c -> t q c\",\n",
" ),\n",
- " \"visible\": einops.rearrange(visibles, \"q t -\u003e t q\"),\n",
+ " \"visible\": einops.rearrange(visibles, \"q t -> t q\"),\n",
"}\n",
"\n",
"batch = preprocessor.random_map(batch)\n",
@@ -585,6 +597,42 @@
"outputs = forward(params, batch)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @title Run Model on Custom Query Points\n",
+ "\n",
+ "# Create model and define forward pass.\n",
+ "model = track_autoencoder.TrackAutoEncoder(\n",
+ "# decoder_scan_chunk_size=32, # If passing large queries\n",
+ " decoder_scan_chunk_size=None, # If passing arbitrary small queries\n",
+ ")\n",
+ "\n",
+ "@jax.jit\n",
+ "def forward(params, inputs):\n",
+ " outputs = model.apply({'params': params}, inputs)\n",
+ " return outputs\n",
+ "\n",
+ "# [Optional] Define custom query points [t, x, y], where t is the frame\n",
+ "# index, and x and y are normalized coordinates e.g. [12., 0.01, 0.9]\n",
+ "# Comment this out to use the query points sampled from the target tracks.\n",
+ "query_pts = np.array(\n",
+ " [\n",
+ " [10., 0.5, 0.5], # center of the image\n",
+ " [10., 0.05, 0.05], # top left corner\n",
+ " [10., 0.95, 0.95], # bottom right corner\n",
+ " [10., 0.05, 0.95], # bottom left corner\n",
+ " ]\n",
+ ") # [num_query_points, 3]\n",
+ "batch['query_points'] = query_pts[None, :]\n",
+ "\n",
+ "# Run forward pass\n",
+ "outputs = forward(params, batch)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -612,12 +660,12 @@
" outputs.visible_logits, outputs.certain_logits\n",
")\n",
"\n",
- "# NOTE: uncomment the lines below to also visualize the support \u0026 target tracks.\n",
+ "# NOTE: uncomment the lines below to also visualize the support & target tracks.\n",
"video_length = video.shape[0]\n",
"\n",
"# video_viz = viz_utils.paint_point_track(\n",
"# video,\n",
- "# support_tracks_vis[:, : video.shape[1]],\n",
+ "# support_tracks_vis[:, : video.shape[0]],\n",
"# batch['support_tracks_visible'][0, :, :video_length],\n",
"# )\n",
"# media.show_video(video_viz, fps=10)\n",
@@ -632,7 +680,8 @@
"video_viz = viz_utils.paint_point_track(\n",
" video,\n",
" reconstructed_tracks[:, :video_length],\n",
- " batch['target_tracks_visible'][0, :, :video_length],\n",
+ " reconstructed_visibles[0, :, :video_length],\n",
+ " # np.ones_like(reconstructed_visibles[0, :, :video_length]),\n",
")\n",
"media.show_video(video_viz, fps=10)"
]
diff --git a/tapnet/trajan/track_autoencoder.py b/tapnet/trajan/track_autoencoder.py
index 1386380..d4f1079 100644
--- a/tapnet/trajan/track_autoencoder.py
+++ b/tapnet/trajan/track_autoencoder.py
@@ -72,9 +72,9 @@ class TrackAutoEncoderInputs(TypedDict):
"""Track autoencoder inputs.
Attributes:
- query_points: The (t, x, y) locations of a set of query points on initial
- frame. The decoder predicts the location and visibility of these query
- points for T frames into the future.
+ query_points: The (t, x, y) locations of a set of query points, where t
+ is the frame index at which each point is queried. The decoder predicts
+ the location and visibility of these points for the remaining frames.
boundary_frame: Int specifying the first frame of any padding in the support
tracks. Track values starting on this frame will be masked out of the
attention for transformers.