
PyTorch implementation of a vision transformer from scratch. Implements patch embeddings, positional embeddings and multi-head self-attention.


Vision Transformer

Implementation based on this paper; see transformer.ipynb for the code.

Results

We found that smaller models, i.e. those with fewer parameters, performed better. Combined with data augmentation, this yielded the best results. The best model configuration turned out to be the following.

cfg = {
    # Architecture
    'depth': 6,
    'dropout': 0.1,          
    'mlp_ratio': 4,
    'num_patches': 8,        
    'embed_dim': 192,        

    # Optimization
    'lr': 1e-3,              
    'weight_decay': 0.05,    
}
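Assuming CIFAR-10 (32×32 inputs) and that 'num_patches' means patches per image side, the config above implies the derived sizes below. This is an illustrative sketch, not code from the repository:

```python
# Sketch: sizes implied by the config, assuming CIFAR-10 (32x32 images)
# and 'num_patches' = patches per image side (an assumption, not stated
# in the repository).
cfg = {
    'depth': 6,
    'dropout': 0.1,
    'mlp_ratio': 4,
    'num_patches': 8,
    'embed_dim': 192,
    'lr': 1e-3,
    'weight_decay': 0.05,
}

image_size = 32                                   # CIFAR-10 side length
patch_size = image_size // cfg['num_patches']     # 32 / 8 = 4
tokens = cfg['num_patches'] ** 2                  # 64 patch tokens per image
mlp_hidden = cfg['embed_dim'] * cfg['mlp_ratio']  # 192 * 4 = 768

print(patch_size, tokens, mlp_hidden)  # -> 4 64 768
```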

combined with basic data augmentation:

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([  # 50% chance to apply each
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
    ], p=0.5),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)),
    ], p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.1)),  # Very mild erasing
])

performed the best. These are mild augmentations, each applied randomly to the training images.
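The RandomApply and RandomErasing probabilities control how often each augmentation fires per image. A stdlib-only sketch (not repository code) of the expected application rates:

```python
import random

# Sketch: estimate how often each probabilistic augmentation fires,
# using the probabilities from the pipeline above.
random.seed(0)
probs = {'color_jitter': 0.5, 'gaussian_blur': 0.3, 'random_erasing': 0.2}
n = 100_000
rates = {name: sum(random.random() < p for _ in range(n)) / n
         for name, p in probs.items()}
for name, rate in rates.items():
    print(f"{name}: ~{rate:.2f}")
```

So a typical image gets the geometric augmentations (crop, flip) plus roughly half the colour jitter, a third of the blur, and a fifth of the erasing.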

More aggressive augmentation did not improve results.

The best model is in best_model.pth.

Embeddings

(figure: comparison of positional embedding variants)

We see that 2D and sinusoidal embeddings provide marginal improvements over 1D learned embeddings; none, however, is significantly worse than the others.
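For reference, the standard fixed sinusoidal positional embedding (one of the variants compared above) interleaves sines and cosines at geometrically spaced frequencies. A generic sketch, not the notebook's code:

```python
import math

def sinusoidal_embedding(num_positions, dim):
    """Standard fixed sinusoidal positional embeddings:
    PE[pos, 2i]   = sin(pos / 10000**(2i/dim))
    PE[pos, 2i+1] = cos(pos / 10000**(2i/dim))
    """
    table = []
    for pos in range(num_positions):
        row = []
        for i in range(dim // 2):
            angle = pos / (10000 ** (2 * i / dim))
            row += [math.sin(angle), math.cos(angle)]
        table.append(row)
    return table

# One row per patch token, matching the config above (64 tokens, dim 192).
pe = sinusoidal_embedding(64, 192)
```

Because the table is fixed, it adds no parameters, which is one reason it is a common baseline against learned embeddings.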

Attention Maps

(figure: attention map visualisation)
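An attention map is the row-wise softmax of the scaled dot-product scores, softmax(QKᵀ/√d), for one head. A toy stdlib-only sketch (not the notebook's code) of how one such map is computed:

```python
import math

def attention_map(Q, K):
    """Compute softmax(Q K^T / sqrt(d)) for toy lists of vectors.
    Each output row gives one query token's attention over all keys."""
    d = len(Q[0])
    scores = [[sum(q * k for q, k in zip(qrow, krow)) / math.sqrt(d)
               for krow in K] for qrow in Q]
    maps = []
    for row in scores:
        m = max(row)                      # subtract max for stability
        exps = [math.exp(s - m) for s in row]
        z = sum(exps)
        maps.append([e / z for e in exps])
    return maps

# Two tokens with orthogonal keys: each token attends most to itself.
A = attention_map([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
```

Visualising these rows (one heatmap per query token, reshaped to the patch grid) is how the attention figures above are typically produced.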
