
Tubelet Embedding is not implemented. #11

@CreamMeatball

Description


Hi, I'm a student researcher in Korea.
Gratefully, I received a lot of help by referring to your code.
But I think there is an omission in your code.
The original ViViT paper from Google uses 'Tubelet Embedding'.
The paper describes two embedding methods, 'Uniform Frame Sampling' and 'Tubelet Embedding', but your code implements only one of them, Uniform Frame Sampling.

I wrote example code for Tubelet Embedding below.

Create the `TubeletEmbedding` class

```python
import torch.nn as nn
from einops.layers.torch import Rearrange


class TubeletEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, tubelet_size, in_channels, dim):
        super().__init__()
        # Conv3d expects input of shape (B, C, T, H, W)
        self.tubelet_embedding = nn.Sequential(
            Rearrange('b t c h w -> b c t h w'),
            nn.Conv3d(
                in_channels,
                dim,
                kernel_size=(tubelet_size, patch_size, patch_size),
                stride=(tubelet_size, patch_size, patch_size),
            ),
            Rearrange('b d t h w -> b (t h w) d'),
        )

    def forward(self, x):
        return self.tubelet_embedding(x)
```
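For reference, the token-count arithmetic behind the `Conv3d` above can be sketched in plain Python. This is a minimal sketch with illustrative sizes (224×224 frames, 16×16 patches, tubelets of 2 frames — assumptions for the example, not values from the repo):

```python
def tubelet_token_count(num_frames, image_size, patch_size, tubelet_size):
    """Number of tokens produced by a Conv3d with kernel = stride =
    (tubelet_size, patch_size, patch_size) on a (T, H, W) video."""
    assert num_frames % tubelet_size == 0
    assert image_size % patch_size == 0
    t = num_frames // tubelet_size          # temporal positions
    hw = (image_size // patch_size) ** 2    # spatial positions per slice
    return t * hw

# Example: 16 frames of a 224x224 video, 16x16 patches, tubelets of 2 frames
print(tubelet_token_count(16, 224, 16, 2))  # 8 * 14 * 14 = 1568
```

With uniform frame sampling the same clip would yield 16 × 196 = 3136 tokens, so tubelet embedding halves the sequence length here.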

In the `ViViT` class, add a branch that selects between the two embedding methods

```python
        super().__init__()

        ...

        self.embedding_method = embedding_method

        ...

        # --- Tokenization (embedding) method selection ---
        if self.embedding_method == 'tubelet':
            assert num_frames % tubelet_size == 0, \
                f"For tubelet embedding, num_frames ({num_frames}) must be divisible by tubelet_size ({tubelet_size})."
            self.to_patch_embedding = TubeletEmbedding(image_size, patch_size, tubelet_size, in_channels, dim)
            self.temporal_seq_length = num_frames // tubelet_size
        else:  # 'uniform_frame_sampling'
            self.to_patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)
            self.temporal_seq_length = num_frames

        ...

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches_per_frame + 1, dim))
        self.space_token = nn.Parameter(torch.randn(1, 1, dim))
```
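The effect of that branch on the temporal sequence length can be shown in isolation. A minimal sketch, assuming the same two method names as above (`'tubelet'` and `'uniform_frame_sampling'`):

```python
def temporal_seq_length(num_frames, tubelet_size, method):
    # Mirrors the selection branch: tubelet embedding compresses the
    # temporal axis by tubelet_size; uniform frame sampling keeps one
    # token row per frame.
    if method == 'tubelet':
        assert num_frames % tubelet_size == 0
        return num_frames // tubelet_size
    return num_frames

print(temporal_seq_length(16, 2, 'tubelet'))                 # 8
print(temporal_seq_length(16, 2, 'uniform_frame_sampling'))  # 16
```

This matters for the factorized temporal transformer later in the model: its positional embedding and attention run over `temporal_seq_length` positions, so the value must match the tokenizer's output.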
