Video and GitHub repo to go along with this post.

In this post, I explain the vision transformer (ViT) architecture, which has found its way into computer vision as a powerful alternative to Convolutional Neural Networks (CNNs).

This implementation will focus on classifying the CIFAR-10 dataset, but is adaptable to many tasks, including semantic segmentation, instance segmentation, and image generation. As we will see, training small ViT models is difficult, and the notebook on fine-tuning (later in this post) explains how to get around these issues.

We begin by downloading the CIFAR-10 dataset, and transforming the data to torch.Tensors.

import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import math

transform = transforms.Compose([transforms.ToTensor()])

train_data = datasets.CIFAR10(root='./data/cifar-10', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./data/cifar-10', train=False, download=True, transform=transform)

train_data, test_data
Files already downloaded and verified
Files already downloaded and verified





(Dataset CIFAR10
     Number of datapoints: 50000
     Root location: ./data/cifar-10
     Split: Train
     StandardTransform
 Transform: Compose(
                ToTensor()
            ),
 Dataset CIFAR10
     Number of datapoints: 10000
     Root location: ./data/cifar-10
     Split: Test
     StandardTransform
 Transform: Compose(
                ToTensor()
            ))

Each image has 3 channels (RGB) and 32x32 pixels. The dataset can be indexed: the first index selects a sample, and the second selects either the image tensor (0) or the target label (1). After the ToTensor transform, the pixel values are torch.float32 values from 0 to 1.

train_data.data.shape, len(train_data.targets)
((50000, 32, 32, 3), 50000)
train_data[0][0].numpy().shape, train_data[0][1]
((3, 32, 32), 6)
train_data[0][0][0], train_data[0][0].dtype
(tensor([[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
         [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
         [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
         ...,
         [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
         [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
         [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]]),
 torch.float32)

If you are familiar with the transformer architecture, you likely know that transformers work with vectors to model different modalities. For a text-based modality, this means somehow tokenizing a string of text into characters or larger chunks, and training an embedding table to represent each token as a vector. We hope that tokenization results in semantic units, so that each vector may represent a concept with a specific meaning. As an example, the string “This is a test.” may tokenize as follows:

This is a test.

To adapt the transformer architecture for image tasks, we need to represent image data as a sequence of vectors, similar to how text is tokenized. In the original ViT paper, the authors address this by dividing an image into many patches and flattening them into vectors. With CIFAR-10, an image $x \in \mathbb{R}^{H\times W\times C}$ is turned into several flattened 2D patches of the form $x_p \in \mathbb{R}^{N\times (P^2\cdot C)}$, where $(H,W)$ are the image dimensions (32x32), $C$ is the number of channels (3 for RGB), and $P$ is the patch size. The number of flattened 2D patches is then $N = \frac{HW}{P^2}$. Finally, we project the flattened patches to latent vectors of size $D$, using the linear projection $\mathbf{E} \in \mathbb{R}^{(P^2\cdot C)\times D}$.
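To make the shapes concrete, here is a small sketch of the arithmetic for CIFAR-10. The patch size of 4 and latent size of 768 are the values used later in this post; the random tensor is only a stand-in for real flattened patches.

```python
import torch
import torch.nn as nn

H, W, C = 32, 32, 3   # CIFAR-10 image height, width, channels
P = 4                 # patch size (the value used in this post)
D = 768               # latent dimension (the value used in this post)

N = (H * W) // (P * P)   # number of patches
flat_dim = P * P * C     # length of one flattened patch

# E is the learned linear projection from flattened patches to latent vectors.
E = nn.Linear(flat_dim, D, bias=False)

x_p = torch.randn(N, flat_dim)   # stand-in for one image's flattened patches
z = E(x_p)

print(N, flat_dim, tuple(z.shape))   # 64 48 (64, 768)
```

So each 32x32 image becomes a sequence of 64 vectors, which is short enough for attention to be cheap.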

patch_size = 4
for i in range(0, 32, patch_size):
    for j in range(0, 32, patch_size):
        patch = train_data[0][0][:, i:i+patch_size, j:j+patch_size]

        # do something with patch...

print(patch.shape)
torch.Size([3, 4, 4])

A naive implementation of turning the images into patches may look like the code above. However, we can accelerate this process by using torch.Tensor.unfold.

image = torch.arange(0.,48).reshape(1, 3, 4, 4) # batch size, channels, height, width
image, image.unfold(2, 2, 2).unfold(3, 2, 2) # first unfold height, then width, into 2x2 patches
(tensor([[[[ 0.,  1.,  2.,  3.],
           [ 4.,  5.,  6.,  7.],
           [ 8.,  9., 10., 11.],
           [12., 13., 14., 15.]],
 
          [[16., 17., 18., 19.],
           [20., 21., 22., 23.],
           [24., 25., 26., 27.],
           [28., 29., 30., 31.]],
 
          [[32., 33., 34., 35.],
           [36., 37., 38., 39.],
           [40., 41., 42., 43.],
           [44., 45., 46., 47.]]]]),
 tensor([[[[[[ 0.,  1.],
             [ 4.,  5.]],
 
            [[ 2.,  3.],
             [ 6.,  7.]]],
 
 
           [[[ 8.,  9.],
             [12., 13.]],
 
            [[10., 11.],
             [14., 15.]]]],
 
 
 
          [[[[16., 17.],
             [20., 21.]],
 
            [[18., 19.],
             [22., 23.]]],
 
 
           [[[24., 25.],
             [28., 29.]],
 
            [[26., 27.],
             [30., 31.]]]],
 
 
 
          [[[[32., 33.],
             [36., 37.]],
 
            [[34., 35.],
             [38., 39.]]],
 
 
           [[[40., 41.],
             [44., 45.]],
 
            [[42., 43.],
             [46., 47.]]]]]]))

We can then reshape this tensor into flat patches separated by channel. We want to combine patches from the same location across channels, so we then permute the dimensions and reshape once again.

# N = H x W / P^2 is the number of patches (in this toy example N = 4 and P^2 = 4)
image.unfold(2,2,2).unfold(3,2,2).reshape(1, -1, 4, 4) # B x C x N x (P^2)
tensor([[[[ 0.,  1.,  4.,  5.],
          [ 2.,  3.,  6.,  7.],
          [ 8.,  9., 12., 13.],
          [10., 11., 14., 15.]],

         [[16., 17., 20., 21.],
          [18., 19., 22., 23.],
          [24., 25., 28., 29.],
          [26., 27., 30., 31.]],

         [[32., 33., 36., 37.],
          [34., 35., 38., 39.],
          [40., 41., 44., 45.],
          [42., 43., 46., 47.]]]])
image.unfold(2,2,2).unfold(3,2,2).reshape(1, -1, 4, 4).permute(0, 2, 1, 3) # B x N x C x (P^2)
image.unfold(2,2,2).unfold(3,2,2).reshape(1, -1, 4, 4).permute(0, 2, 1, 3).reshape(1, 4, -1) # B x N x (C*P^2)
tensor([[[ 0.,  1.,  4.,  5., 16., 17., 20., 21., 32., 33., 36., 37.],
         [ 2.,  3.,  6.,  7., 18., 19., 22., 23., 34., 35., 38., 39.],
         [ 8.,  9., 12., 13., 24., 25., 28., 29., 40., 41., 44., 45.],
         [10., 11., 14., 15., 26., 27., 30., 31., 42., 43., 46., 47.]]])
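As a sanity check, the unfold-based approach can be compared against the naive loop for an arbitrary batch size and patch size. The helper names `patchify_loop` and `patchify_unfold` below are made up for this sketch. Note that the reshape uses the batch size and channel count explicitly, which matters once the batch size is larger than 1.

```python
import torch

def patchify_loop(images, patch_size):
    """Reference version: slice out each patch and flatten it (channel, row, column)."""
    B, C, H, W = images.shape
    patches = []
    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            patch = images[:, :, i:i + patch_size, j:j + patch_size]  # B x C x P x P
            patches.append(patch.reshape(B, -1))                      # B x (C*P^2)
    return torch.stack(patches, dim=1)                                # B x N x (C*P^2)

def patchify_unfold(images, patch_size):
    """Same result using unfold, without Python loops."""
    B, C, H, W = images.shape
    num_patches = (H // patch_size) * (W // patch_size)
    x = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    x = x.reshape(B, C, num_patches, patch_size * patch_size)   # B x C x N x P^2
    return x.permute(0, 2, 1, 3).reshape(B, num_patches, -1)    # B x N x (C*P^2)

batch = torch.randn(2, 3, 8, 8)
same = torch.equal(patchify_loop(batch, 4), patchify_unfold(batch, 4))
print(same)   # True
```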

Another embedding that is critical for the transformer architecture to understand context is the positional embedding. In the text modality, the positional embedding is often implemented using fixed sine and cosine functions. To represent 2D position, however, we will use a standard embedding table that is learned during training. There are other possibilities for representing position, including 2D-aware positional embeddings, but these are harder to implement and result in negligible performance differences.
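For comparison, here is a rough sketch of the two options: a fixed sinusoidal encoding in the style used for text transformers, and a learned table like the one used in this post. The function name `sinusoidal_positions` is made up for this example, and the sketch assumes an even embedding dimension.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_positions(num_positions, dim):
    """Fixed sine/cosine encoding (dim is assumed to be even)."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)  # (num_positions, 1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim)
    )
    pe = torch.zeros(num_positions, dim)
    pe[:, 0::2] = torch.sin(position * div_term)   # even feature indices
    pe[:, 1::2] = torch.cos(position * div_term)   # odd feature indices
    return pe

num_tokens, dim = 65, 768   # 64 patches + 1 class token, as in this post

fixed = sinusoidal_positions(num_tokens, dim)              # computed once, never trained
learned = nn.Parameter(torch.zeros(1, num_tokens, dim))    # updated by the optimizer

print(fixed.shape, learned.shape)
```

The learned table has no built-in notion of 2D layout; the model has to discover from data which positions are neighbours.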

Finally, we also want to add a specific vector for information about the class of the image to the transformer input. Through each transformer block, this vector is modified, before it is fed into a multi-layer perceptron at the last block to determine the class of the image. We will use a learnable embedding for this and prepend it to the other embedded patches.

class PatchEmbedding(nn.Module):
    def __init__(self, img_size = 32, patch_size = 4, in_chans = 3, embed_dim = 768):
        super().__init__()
        self.img_size   = img_size
        self.patch_size = patch_size    # P
        self.in_chans   = in_chans      # C
        self.embed_dim  = embed_dim     # D

        self.num_patches = (img_size // patch_size) ** 2        # N = H*W/P^2
        self.flatten_dim = patch_size * patch_size * in_chans   # P^2*C
        
        self.proj = nn.Linear(self.flatten_dim, embed_dim) # (P^2*C,D)

        self.position_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.class_embed    = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        B, C, H, W = x.shape

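        x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        # keep the batch and channel dimensions explicit so patches from different images are not mixed
        x = x.reshape(B, C, self.num_patches, self.patch_size * self.patch_size)  # B x C x N x P^2
        x = x.permute(0, 2, 1, 3).reshape(B, self.num_patches, -1)                # B x N x (C*P^2)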

        x = self.proj(x)

        cls_emb = self.class_embed.expand(B, -1, -1)
        x = torch.cat((cls_emb, x), dim = 1)

        x = x + self.position_embed
        return x
patch_embed = PatchEmbedding()

embeddings = patch_embed(torch.stack([train_data[i][0] for i in range(10)]))

embeddings, embeddings.shape
(tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-3.9495e-01,  2.6499e-01, -9.8315e-02,  ...,  1.9426e-01,
            3.5573e-01, -1.2782e-01],
          [-4.9758e-01,  2.3686e-01, -1.7172e-01,  ...,  3.5498e-02,
            1.5119e-01,  6.9413e-03],
          ...,
          [-2.0157e-01,  1.8175e-01, -1.1618e-01,  ...,  6.3366e-02,
            1.4141e-01, -2.3075e-01],
          [-1.4961e-01,  2.5366e-01, -4.7240e-02,  ...,  3.0900e-02,
            1.3584e-01, -1.5386e-01],
          [-1.2538e-01,  1.3091e-01, -1.2969e-01,  ...,  1.0788e-01,
            5.3261e-02, -1.0373e-01]],
 
         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-3.2531e-01,  2.7973e-01, -2.7797e-01,  ...,  5.6054e-02,
            2.4562e-01, -5.4279e-02],
          [-3.7752e-01,  3.7624e-01, -1.9639e-01,  ..., -4.6192e-02,
            2.7113e-01, -6.9035e-02],
          ...,
          [-6.4879e-02, -9.6972e-04, -2.6319e-01,  ...,  2.6897e-01,
           -6.9230e-02, -7.6847e-02],
          [-2.5097e-01,  1.2145e-01, -3.3001e-01,  ...,  1.3163e-01,
            3.4776e-01, -5.5025e-02],
          [-2.2934e-01,  1.7568e-01, -2.5048e-01,  ...,  1.0340e-01,
            9.4081e-02, -3.7091e-02]],
 
         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-4.5876e-01,  2.7328e-01, -1.4612e-01,  ...,  1.1323e-01,
            2.5706e-01, -6.0638e-02],
          [-3.0678e-01,  3.2576e-01, -1.7295e-01,  ..., -2.9040e-02,
            2.4391e-01, -1.0266e-01],
          ...,
          [-5.1798e-01,  1.2231e-01, -2.8323e-01,  ...,  1.2910e-01,
            1.0092e-01, -5.8057e-02],
          [-4.3961e-01,  3.0831e-01, -6.3392e-02,  ...,  1.3885e-01,
            3.7307e-01, -3.5249e-01],
          [-1.5562e-01,  1.1467e-01, -2.3594e-01,  ...,  2.0727e-01,
            1.0254e-01, -2.6508e-02]],
 
         ...,
 
         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-4.2991e-01,  2.3180e-01, -1.6275e-01,  ...,  1.3453e-01,
            2.4938e-01, -5.4029e-02],
          [-2.7091e-01,  2.7341e-01, -2.3737e-01,  ..., -7.2334e-02,
            1.3330e-01, -1.0379e-01],
          ...,
          [-5.8417e-01,  2.4607e-01, -2.9059e-01,  ...,  1.0830e-02,
            1.0258e-01,  2.7538e-02],
          [-3.6148e-01,  1.1907e-01, -3.6477e-02,  ...,  5.2896e-02,
            4.4220e-02, -9.9179e-02],
          [-4.7744e-01,  1.0193e-01, -1.2614e-01,  ...,  2.0482e-01,
            5.7551e-02,  1.5926e-01]],
 
         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-9.7016e-01,  4.5935e-01, -2.7754e-01,  ...,  1.2938e-01,
            5.4929e-01, -1.2435e-01],
          [-5.6197e-01,  4.6641e-01, -1.0930e-02,  ...,  1.4106e-01,
            6.7737e-01, -1.1047e-01],
          ...,
          [-6.8178e-03,  1.5677e-01, -1.4674e-01,  ...,  9.6236e-02,
           -1.9692e-02, -1.9123e-01],
          [-9.7842e-02,  1.4272e-01, -3.2074e-01,  ..., -2.7982e-02,
            4.1490e-02, -1.2003e-01],
          [ 7.7636e-02,  7.0059e-02, -9.6226e-02,  ...,  1.1831e-01,
           -1.0042e-02, -1.0293e-01]],
 
         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-3.6146e-01,  1.9650e-01, -8.9722e-02,  ...,  2.0228e-01,
            2.3418e-01, -2.1788e-02],
          [-5.1785e-01,  7.8911e-02, -2.4814e-01,  ...,  6.2777e-02,
            1.8336e-01,  1.5384e-02],
          ...,
          [ 9.7187e-02,  6.5835e-02, -1.2366e-01,  ...,  8.6316e-02,
           -1.8488e-02, -8.9563e-02],
          [-1.8432e-02,  1.1963e-01, -1.7618e-01,  ..., -3.2793e-02,
           -2.6466e-02, -6.3163e-02],
          [-1.1296e-01,  1.0668e-01, -8.1302e-02,  ...,  1.3928e-01,
            7.5195e-02, -1.1134e-01]]], grad_fn=<AddBackward0>),
 torch.Size([10, 65, 768]))

As seen above, we are able to embed batches into our desired embedding dimension, with the correct number of vectors, $N+1 = 65$ (64 patches plus the class token).

We can now continue to implement the standard transformer architecture, with one notable change from GPT-like architectures. The attention mechanism of GPT involves multi-headed causal self-attention, which means that vectors are only allowed to query and interact with previous vectors. Although this makes sense in a language model, which must not look at future tokens when predicting the next one, we want every patch vector to communicate with every other patch vector, so we do not apply an attention mask. Otherwise, the implementation remains unchanged.
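The difference between the two attention patterns can be seen on a toy example. This is only an illustrative sketch with a single head and random queries and keys; the sizes are arbitrary.

```python
import math
import torch

torch.manual_seed(0)
N, d = 5, 8   # toy sequence length and head dimension
q, k = torch.randn(N, d), torch.randn(N, d)

scores = (q @ k.T) / math.sqrt(d)

# ViT-style: no mask, every token attends to every token.
full = scores.softmax(dim=-1)

# GPT-style: positions after the current one are masked out before the softmax.
causal_mask = torch.tril(torch.ones(N, N, dtype=torch.bool))
causal = scores.masked_fill(~causal_mask, float("-inf")).softmax(dim=-1)

print((full > 0).sum().item())     # 25: all N*N weights are non-zero
print((causal > 0).sum().item())   # 15: only the lower triangle, N*(N+1)/2
```

In both cases each row of weights sums to 1; the causal version simply redistributes that weight over fewer positions.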

class SelfAttention(nn.Module):
    def __init__(self, embed_dim = 768, num_heads = 4, bias = False, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim   = embed_dim
        self.num_heads   = num_heads
        self.head_dim    = embed_dim // num_heads

        self.query   = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key     = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value   = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out     = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, _ = x.size()

        q = self.query(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.key(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.value(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # do NOT apply a causal mask: patches have positions but no causal order, so every patch attends to every other patch
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attn = attn.softmax(dim=-1)

        out = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, self.embed_dim)

        out = self.out(out)

        return out
MSA = SelfAttention()
LN = nn.LayerNorm(embeddings.shape[-1], bias=False)  # normalize over the embedding dimension only

MSA(LN(embeddings))
tensor([[[-0.5860,  0.3954,  0.2497,  ...,  0.1984,  0.0070,  0.1630],
         [-0.6067,  0.4070,  0.2400,  ...,  0.2375,  0.0075,  0.1689],
         [-0.5888,  0.3900,  0.2408,  ...,  0.2250,  0.0061,  0.1623],
         ...,
         [-0.5961,  0.3959,  0.2409,  ...,  0.2233,  0.0063,  0.1664],
         [-0.5848,  0.3911,  0.2453,  ...,  0.2099,  0.0056,  0.1627],
         [-0.5837,  0.3875,  0.2426,  ...,  0.2105,  0.0060,  0.1622]],

        [[-0.5661,  0.3710,  0.2337,  ...,  0.2130,  0.0105,  0.1403],
         [-0.5794,  0.3729,  0.2260,  ...,  0.2429,  0.0077,  0.1412],
         [-0.5740,  0.3701,  0.2325,  ...,  0.2309,  0.0033,  0.1431],
         ...,
         [-0.5679,  0.3716,  0.2228,  ...,  0.2239,  0.0129,  0.1400],
         [-0.5739,  0.3683,  0.2193,  ...,  0.2312,  0.0209,  0.1412],
         [-0.5778,  0.3683,  0.2226,  ...,  0.2343,  0.0075,  0.1407]],

        [[-0.5962,  0.3924,  0.2470,  ...,  0.2167,  0.0121,  0.1635],
         [-0.6102,  0.3943,  0.2456,  ...,  0.2575,  0.0085,  0.1720],
         [-0.6101,  0.3979,  0.2475,  ...,  0.2511,  0.0041,  0.1728],
         ...,
         [-0.6054,  0.3946,  0.2409,  ...,  0.2755,  0.0149,  0.1746],
         [-0.6069,  0.3830,  0.2432,  ...,  0.2586,  0.0115,  0.1723],
         [-0.5958,  0.3992,  0.2414,  ...,  0.2405,  0.0254,  0.1696]],

        ...,

        [[-0.5766,  0.3838,  0.2403,  ...,  0.2068,  0.0180,  0.1598],
         [-0.5917,  0.3846,  0.2338,  ...,  0.2433,  0.0178,  0.1638],
         [-0.5838,  0.3799,  0.2379,  ...,  0.2283,  0.0100,  0.1623],
         ...,
         [-0.5829,  0.3680,  0.2296,  ...,  0.2442,  0.0090,  0.1594],
         [-0.5840,  0.3761,  0.2376,  ...,  0.2264,  0.0136,  0.1623],
         [-0.5744,  0.3830,  0.2314,  ...,  0.2276,  0.0190,  0.1604]],

        [[-0.5098,  0.3486,  0.2086,  ...,  0.1954,  0.0156,  0.1372],
         [-0.5640,  0.3621,  0.1949,  ...,  0.2696,  0.0105,  0.1501],
         [-0.5355,  0.3669,  0.2019,  ...,  0.2246,  0.0319,  0.1437],
         ...,
         [-0.5099,  0.3471,  0.2052,  ...,  0.2006,  0.0181,  0.1391],
         [-0.5051,  0.3417,  0.2025,  ...,  0.1998,  0.0165,  0.1374],
         [-0.5016,  0.3411,  0.2031,  ...,  0.1934,  0.0188,  0.1351]],

        [[-0.5958,  0.3998,  0.2388,  ...,  0.2080,  0.0119,  0.1450],
         [-0.6008,  0.3943,  0.2309,  ...,  0.2356,  0.0160,  0.1482],
         [-0.6080,  0.3970,  0.2299,  ...,  0.2437,  0.0152,  0.1512],
         ...,
         [-0.5894,  0.3960,  0.2359,  ...,  0.2095,  0.0150,  0.1462],
         [-0.5887,  0.3914,  0.2346,  ...,  0.2109,  0.0092,  0.1452],
         [-0.5935,  0.3959,  0.2324,  ...,  0.2166,  0.0157,  0.1476]]],
       grad_fn=<UnsafeViewBackward0>)

Finally, we want to implement the multi-layer perceptron and combine all our modules into the transformer block.

class MLP(nn.Module):
    def __init__(self, embed_dim = 768, bias = False, dropout = 0.1):
        super().__init__()
        self.c_fc = nn.Linear(embed_dim, embed_dim * 4, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(embed_dim * 4, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)

        return x
    
class Block(nn.Module):

    def __init__(self, embed_dim = 768, bias = False):
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim, bias=bias)
        self.attn = SelfAttention(embed_dim, bias=bias)
        self.ln_2 = nn.LayerNorm(embed_dim, bias=bias)
        self.mlp = MLP(embed_dim, bias=bias)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
    
    

The final element to consider in our model is the output. Unlike a language model such as GPT, which makes a prediction at every position, we want a single prediction per image over the 10 CIFAR-10 categories. To achieve this, we use the class token that was prepended to the input earlier. After this vector (of size embed_dim) has passed through every transformer block and the final layer norm, we feed it to a classification head that outputs one logit per class; a softmax over these logits gives a probability distribution across all ten image classes.

The original ViT paper uses an MLP with one hidden layer as the head during pre-training and a single linear layer during fine-tuning. For simplicity, we will use a single linear layer in the example below.

class ViT(nn.Module):

    def __init__(self, embed_dim = 768, num_layers = 4, out_dim = 10, bias = False, dropout = 0.1):
        super().__init__()

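        # pass embed_dim and bias through so that non-default values reach every submodule
        self.transformer = nn.ModuleDict(dict(
            pe = PatchEmbedding(embed_dim = embed_dim),
            drop = nn.Dropout(dropout),
            h = nn.ModuleList([Block(embed_dim, bias = bias) for _ in range(num_layers)]),
            ln_f = nn.LayerNorm(embed_dim)
        ))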
        self.head = nn.Linear(embed_dim, out_dim, bias=False)


        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def forward(self, x):
        emb = self.transformer.pe(x)
        x = self.transformer.drop(emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        class_token = x[:, 0]
        logits = self.head(class_token)
        return logits

vit = ViT()
vit(torch.stack([train_data[i][0] for i in range(10)]))
number of parameters: 28.42M





tensor([[-0.6155,  0.0833,  0.3612,  0.1262,  0.3042,  0.1746, -0.8091, -0.1743,
          0.7834, -0.0631],
        [-1.0184, -0.1414, -0.1417, -0.2963,  0.3737,  0.0205, -0.6493, -0.0109,
          0.8699, -0.1459],
        [-0.9966,  0.1542, -0.1182, -0.0219,  0.5765,  0.0459, -0.4171, -0.1138,
          0.9797, -0.0116],
        [-0.7114,  0.0771,  0.1259,  0.0709,  0.4262,  0.2157, -0.9031, -0.0468,
          0.7561, -0.0685],
        [-0.7684,  0.0140, -0.1220, -0.2881,  0.6543, -0.0601, -0.4210, -0.1563,
          0.8449,  0.0266],
        [-0.6628,  0.2121, -0.0296,  0.1316,  0.4504,  0.5865, -1.1234, -0.1563,
          0.6179,  0.0671],
        [-0.8797,  0.1741,  0.2067, -0.1092,  0.4820,  0.0855, -0.3633,  0.1939,
          0.8685, -0.4827],
        [-0.8408,  0.1524, -0.1209, -0.2138,  0.3352,  0.1860, -0.5956,  0.1675,
          0.9737,  0.0405],
        [-0.8931, -0.3047, -0.1372, -0.3863,  0.4625,  0.2419, -0.4026,  0.1199,
          0.7785,  0.1355],
        [-0.4177,  0.0746,  0.1027,  0.1930,  0.4062,  0.0012, -0.6935, -0.1216,
          1.0843, -0.0408]], grad_fn=<MmBackward0>)
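The values above are raw logits, not probabilities. As a small sketch, a softmax turns one such row into a distribution over the ten classes; the numbers are copied from the first row of the output above and come from an untrained model, so the prediction itself is meaningless.

```python
import torch
import torch.nn.functional as F

# First row of the logits printed above (untrained model).
logits = torch.tensor([[-0.6155, 0.0833, 0.3612, 0.1262, 0.3042,
                        0.1746, -0.8091, -0.1743, 0.7834, -0.0631]])

probs = F.softmax(logits, dim=-1)   # one probability per class
pred = probs.argmax(dim=-1)         # index of the most likely class

print(probs.sum(dim=-1))   # sums to 1 (up to floating-point error)
print(pred)                # tensor([8])
```

During training there is no need to apply the softmax yourself: nn.CrossEntropyLoss expects the raw logits.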

Finetuning larger ViTs on small datasets

Small ViT models struggle to learn the inductive biases necessary for image recognition on small datasets, such as CIFAR-10, which makes training them difficult. Instead, ViTs are often pretrained on very large image datasets before being finetuned to downstream tasks, similar to GPT-style LLMs. As an example, the original ViT paper uses ImageNet-21k with 21k classes and 14M images as one of its datasets.

More technically, fine-tuning a ViT model typically involves replacing the pre-training classification head with a single linear layer (a $D \times K$ matrix, where $D$ is the embedding dimension and $K$ is the number of downstream classes). Additionally, if the fine-tuning resolution differs from the pre-training resolution, the number of patches changes, so the position embeddings need to be interpolated in 2D to fit the new patch grid.
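A minimal sketch of the 2D interpolation step is shown below. The helper `resize_position_embedding` is hypothetical, and both the bicubic mode and the assumption that the class token's embedding comes first and is left untouched are choices made for this example rather than details taken from this post.

```python
import math
import torch
import torch.nn.functional as F

def resize_position_embedding(pos_embed, new_grid):
    """Hypothetical helper. pos_embed: (1, 1 + old_grid**2, D), class token first."""
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    old_grid = math.isqrt(patch_pos.shape[1])
    D = patch_pos.shape[-1]

    # (1, N, D) -> (1, D, old_grid, old_grid) so interpolate sees a 2D grid
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, D).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(
        patch_pos, size=(new_grid, new_grid), mode="bicubic", align_corners=False
    )
    # back to (1, new_grid**2, D)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, D)

    return torch.cat([cls_pos, patch_pos], dim=1)

pos_embed = torch.randn(1, 1 + 14 * 14, 768)                  # 224x224 input, 16x16 patches
resized = resize_position_embedding(pos_embed, new_grid=24)   # 384x384 input
print(resized.shape)   # torch.Size([1, 577, 768])
```

This step is not needed in the rest of this post, because the images are resized to the resolution the model was pre-trained at.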

Here, we will fine-tune ViT-B/16, the smallest model variant from the original ViT paper, using the pre-trained weights that ship with torchvision (these were trained on ImageNet-1k rather than ImageNet-21k). We begin by downloading the CIFAR-10 dataset and scaling the images up from 32x32 to 224x224, the resolution the model was pre-trained at. Typically, images are scaled up beyond the pre-training resolution at fine-tuning time to improve performance, but here we keep 224x224 for simplicity.

import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_set = torchvision.datasets.CIFAR10(root='./data/cifar-10-finetune', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data/cifar-10-finetune', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

We then load the ViT-B/16 model from the torchvision module and load it onto the GPU. We need to set the head of the model to fit our fine-tuning task, which in this case has 10 classes. A simple linear layer is most commonly used, and we make sure to set the first dimension to the final dimension of the previous layer. Then, we use an Adam optimizer with some learning rate to finetune. The original ViT paper discusses the learning rate schedules for fine-tuning in further depth, but a constant rate will suffice for demonstration.

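device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# pretrained=True is deprecated in recent torchvision releases; request the weights explicitly
model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)

# freeze the pre-trained backbone so that only the new head is trained
for param in model.parameters():
    param.requires_grad = False

# replace the 1000-class ImageNet head with a 10-class head for CIFAR-10
model.heads = nn.Linear(model.heads[0].in_features, 10)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.heads.parameters(), lr=0.003)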
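If a learning rate schedule is wanted instead of a constant rate, one common choice is linear warmup followed by cosine decay. The sketch below is an assumption-laden illustration, not the exact recipe from the ViT paper: the function `warmup_cosine`, the step counts, and the stand-in linear layer are all made up for this example.

```python
import math
import torch
from torch import nn

def warmup_cosine(step, warmup_steps, total_steps):
    """Multiplier for the base learning rate: linear warmup, then cosine decay."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

toy_head = nn.Linear(768, 10)   # stand-in for the classification head
optimizer = torch.optim.Adam(toy_head.parameters(), lr=0.003)

total_steps, warmup_steps = 100, 10   # illustrative values only
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda s: warmup_cosine(s, warmup_steps, total_steps)
)

lrs = []
for _ in range(total_steps):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()    # in real training this follows loss.backward()
    scheduler.step()    # one scheduler step per optimizer step

print(lrs[0], max(lrs), lrs[-1])
```

In a real training loop, scheduler.step() would be called once per batch, right after optimizer.step().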

To train and evaluate the model, we use simple and generic PyTorch training code.

def train(model, criterion, optimizer, loader, epochs=10):
    model.train()
    for epoch in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

During fine-tuning, very little training is actually required, and we can get decent performance from just 1 epoch.

train(model, criterion, optimizer, train_loader, epochs=1)
accuracy = evaluate(model, test_loader)
print(f"Test Accuracy: {accuracy}%")

With this simple fine-tuning approach, we achieve over 95% accuracy on CIFAR-10, a level that only became state of the art in 2014 with fractional max-pooling. Note that this fine-tuning was done with all layers frozen (requires_grad=False), except for the classification head. More advanced fine-tuning, involving larger pre-trained models, learning rate schedules, and image resolution adjustments, brings this to 99.5% top-1 accuracy, which has been state-of-the-art performance since 2020. This comes at the cost of having to train a huge model and needing extra training data. In contrast, the DeiT vision transformer models introduced in "Training data-efficient image transformers & distillation through attention" are much smaller than ViT-H/14, can be distilled from ConvNets, and achieve up to 99.1% accuracy on CIFAR-10.