Hi, I'm a student researcher in Korea.
Your code has helped me a lot, thank you.
However, I think something is missing from it.
The original ViViT paper from Google uses 'Tubelet Embedding'.
The paper describes two embedding methods, 'Uniform Frame Sampling' and 'Tubelet Embedding'.
However, your code implements only one of them, Uniform Frame Sampling.
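In practice, the two tokenizations differ in how long the token sequence is along time. The arithmetic sketch below assumes square frames whose side is divisible by `patch_size`. `count_tokens` is a hypothetical helper written for illustration and is not part of the repository. With `tubelet_size=1`, every frame gets its own temporal slot, which matches per-frame patching.

```python
def count_tokens(num_frames, image_size, patch_size, tubelet_size=1):
    """Hypothetical helper: return (temporal length, patches per frame, total tokens)."""
    assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
    assert num_frames % tubelet_size == 0, "num_frames must be divisible by tubelet_size"
    patches_per_frame = (image_size // patch_size) ** 2
    temporal_len = num_frames // tubelet_size  # each token spans tubelet_size frames
    return temporal_len, patches_per_frame, temporal_len * patches_per_frame


print(count_tokens(16, 224, 16))                  # (16, 196, 3136)  one slot per frame
print(count_tokens(16, 224, 16, tubelet_size=2))  # (8, 196, 1568)   tubelet embedding
```

With a tubelet size of 2, the temporal sequence is half as long. This is why the model has to know the temporal length produced by whichever tokenizer it uses.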
Below is example code for Tubelet Embedding.
First, create a `TubeletEmbedding` class:
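```python
import torch.nn as nn
from einops.layers.torch import Rearrange


class TubeletEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, tubelet_size, in_channels, dim):
        super().__init__()
        # image_size is unused; it is kept so the signature matches PatchEmbedding.
        # The model receives (B, T, C, H, W), but Conv3d expects (B, C, T, H, W).
        self.tubelet_embedding = nn.Sequential(
            Rearrange('b t c h w -> b c t h w'),
            # kernel_size == stride: each non-overlapping
            # (tubelet_size, patch_size, patch_size) tube becomes one dim-sized token
            nn.Conv3d(
                in_channels,
                dim,
                kernel_size=(tubelet_size, patch_size, patch_size),
                stride=(tubelet_size, patch_size, patch_size),
            ),
            # Keep time and space separate, giving
            # (B, num_frames // tubelet_size, num_patches_per_frame, dim),
            # so a per-frame space token can be prepended as in the uniform-sampling path.
            Rearrange('b d t h w -> b t (h w) d'),
        )

    def forward(self, x):
        return self.tubelet_embedding(x)
```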
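Tubelet embedding is a linear projection of each non-overlapping t×p×p tube. A `Conv3d` whose `kernel_size` equals its `stride` computes exactly that. The torch-only check below cuts the video into tubes by hand, multiplies them by the convolution's own weights, and compares the result with the `Conv3d` output. The sizes are arbitrary, and `tubelets_via_linear` is a hypothetical helper written only for this illustration.

```python
import torch
import torch.nn as nn


def tubelets_via_linear(x, conv):
    """Hypothetical helper: flatten non-overlapping (t, p, p) tubes and project them
    with the Conv3d's own weight and bias."""
    b, c, T, H, W = x.shape
    t, p, _ = conv.kernel_size
    # split every axis into (number of blocks, block size)
    x = x.reshape(b, c, T // t, t, H // p, p, W // p, p)
    # -> (b, T/t, H/p, W/p, c, t, p, p), then one flattened tube per token
    x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)
    x = x.reshape(b, (T // t) * (H // p) * (W // p), c * t * p * p)
    weight = conv.weight.reshape(conv.out_channels, -1)  # (dim, c*t*p*p), same (c, t, p, p) order
    return x @ weight.T + conv.bias


torch.manual_seed(0)
B, C, T, H, W = 2, 3, 8, 32, 32
tubelet_size, patch_size, dim = 2, 16, 64
conv = nn.Conv3d(C, dim,
                 kernel_size=(tubelet_size, patch_size, patch_size),
                 stride=(tubelet_size, patch_size, patch_size))

x = torch.randn(B, C, T, H, W)
out_conv = conv(x).flatten(2).transpose(1, 2)  # (b, t'*h'*w', dim)
out_linear = tubelets_via_linear(x, conv)

print(out_conv.shape)                                     # torch.Size([2, 16, 64])
print(torch.allclose(out_conv, out_linear, atol=1e-4))    # True
```

The convolution form is simply the more efficient way to express this per-tube projection.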
Then, in the `ViViT` class, add a branch that selects between the two embedding methods:
```python
super().__init__()
...
self.embedding_method = embedding_method
...
# --- Tokenization (embedding) method branch ---
if self.embedding_method == 'tubelet':
    # each token spans tubelet_size frames, so the clip must split evenly along time
    assert num_frames % tubelet_size == 0, \
        f"For tubelet embedding, num_frames ({num_frames}) must be divisible by tubelet_size ({tubelet_size})."
    self.to_patch_embedding = TubeletEmbedding(image_size, patch_size, tubelet_size, in_channels, dim)
    self.temporal_seq_length = num_frames // tubelet_size
else:  # 'uniform_frame_sampling'
    self.to_patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)
    self.temporal_seq_length = num_frames
...
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches_per_frame + 1, dim))
self.space_token = nn.Parameter(torch.randn(1, 1, dim))
```
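Both branches should give the encoder tokens in the same layout, with only the temporal length differing. The torch-only sketch below (no einops) checks this under two assumptions: the model input is `(b, t, c, h, w)`, and the factorised encoder expects `(b, temporal_seq_length, num_patches_per_frame, dim)`. `build_tokenizer`, and the `PatchEmbedding` and `TubeletEmbedding` defined here, are hypothetical stand-ins rather than the repository's code. `image_size` is omitted because neither stand-in needs it.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Hypothetical per-frame tokenizer (uniform frame sampling)."""
    def __init__(self, patch_size, in_channels, dim):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                     # x: (b, t, c, h, w)
        b, t = x.shape[:2]
        x = self.proj(x.flatten(0, 1))        # (b*t, d, h', w')
        x = x.flatten(2).transpose(1, 2)      # (b*t, n, d)
        return x.reshape(b, t, x.shape[1], x.shape[2])  # (b, t, n, d)


class TubeletEmbedding(nn.Module):
    """Hypothetical tubelet tokenizer that keeps the temporal axis separate."""
    def __init__(self, patch_size, tubelet_size, in_channels, dim):
        super().__init__()
        size = (tubelet_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels, dim, kernel_size=size, stride=size)

    def forward(self, x):                     # x: (b, t, c, h, w)
        x = self.proj(x.transpose(1, 2))      # (b, d, t', h', w')
        return x.flatten(3).permute(0, 2, 3, 1)  # (b, t', n, d)


def build_tokenizer(embedding_method, num_frames, patch_size, tubelet_size, in_channels, dim):
    """Hypothetical: return (tokenizer, temporal_seq_length), mirroring the branch above."""
    if embedding_method == "tubelet":
        assert num_frames % tubelet_size == 0, "num_frames must be divisible by tubelet_size"
        return TubeletEmbedding(patch_size, tubelet_size, in_channels, dim), num_frames // tubelet_size
    return PatchEmbedding(patch_size, in_channels, dim), num_frames


x = torch.randn(2, 16, 3, 64, 64)  # (b, t, c, h, w)
for method in ("tubelet", "uniform_frame_sampling"):
    embed, temporal_seq_length = build_tokenizer(
        method, num_frames=16, patch_size=16, tubelet_size=2, in_channels=3, dim=32)
    print(method, tuple(embed(x).shape), temporal_seq_length)
# tubelet (2, 8, 16, 32) 8
# uniform_frame_sampling (2, 16, 16, 32) 16
```

Only the second axis changes between the two branches. Because of that, the per-frame space token and a positional embedding of size `num_patches_per_frame + 1` can be shared, and `temporal_seq_length` is the one value the branch must set.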