Until vision transformers came along, the dominant architecture in computer vision was the convolutional neural network (CNN), developed from 1989 onward by researchers including Yann LeCun and Yoshua Bengio. In 2017, Google introduced the transformer, which took natural language processing by storm, but it was not successfully adapted to computer vision until 2020, when Google introduced the vision transformer (ViT). Since then, vision transformers have played a major role in computer vision as the architecture behind top models such as DINOv2, DeiT and CLIP. To understand this important architecture, in this post we go back to the original vision transformer paper, titled “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”.
If you prefer a video format, check out our video, linked in the references at the end of this post.
Can we use a Transformer as-is for images?
Before diving into how vision transformers work, let’s first ask: why can’t we simply feed an image’s pixels into a transformer? At the core of the transformer model is self-attention, where each token attends to every other token in the input sequence. So, for an input sequence of 12 tokens, we get a 12 X 12 matrix, called the attention matrix. The size of this matrix grows quadratically with the sequence length, which makes it difficult to scale up the context length.
If we feed an image into a transformer pixel by pixel, each pixel needs to attend to all other pixels. Consider a low-quality 256 X 256 image. Such an image has about 65k pixels, which results in a 65k X 65k attention matrix. A 512 X 512 image, which is still low quality, has about 262k pixels and a 262k X 262k attention matrix. Models that can handle such long sequences, such as LongNet, have only recently started to appear, and we are only talking about low-quality images here. So handling images as-is results in sequences that are too long, which makes the approach neither practical nor scalable. Ok, so how are vision transformers able to deal with this issue?
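To make the quadratic growth concrete, here is a small sketch of the arithmetic above. The function name and the 2-bytes-per-entry figure (16-bit floats, a single attention head in a single layer) are assumptions made for illustration only.

```python
def attention_matrix_entries(height, width):
    """Return (sequence_length, attention_entries) when every pixel is one token."""
    seq_len = height * width
    return seq_len, seq_len * seq_len


for side in (256, 512):
    seq_len, entries = attention_matrix_entries(side, side)
    # Assumption for illustration: 2 bytes per entry, one head, one layer.
    gib = entries * 2 / 2**30
    print(f"{side} X {side}: {seq_len:,} tokens, {entries:,} entries, ~{gib:.0f} GiB")
```

Under these assumptions, a single attention matrix takes roughly 8 GiB for the 256 X 256 image and roughly 128 GiB for the 512 X 512 image.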
How Do Vision Transformers Work?
In the image above we can see the Vision Transformer architecture, which includes the following stages:
- Breaking the image into patches – The main idea is that instead of feeding the image pixels in a sequence to the transformer, we break the image into patches, as we see in the example above on the bottom left. We then create a sequence of patches in order to feed that sequence to the transformer, similar to how we would feed a transformer with a sequence of tokens.
- Adapting the patches to a transformer input – To bring the patches to a dimension the transformer can work with, we flatten each patch and pass it through a trainable linear projection layer, which yields a vector for each patch with a dimension that matches the transformer.
- Positional embeddings – Before feeding these vectors to the transformer, in order to retain the information about each patch location in the original image, positional embeddings are added to the linear projection output, which gives us the final patch embeddings for the input image.
After these steps, we have embeddings for the patches which we can pass to the transformer, and from this point on the transformer behaves as a regular transformer. In addition to the patch embeddings, we also add a special learnable embedding at the beginning of the sequence, and we use the transformer output for that embedding to predict the class of the input image. So, this token learns to capture global information about the image.
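The stages above can be sketched in plain Python as follows. This is a toy illustration, not the paper's implementation: all function names and sizes are hypothetical, the weights are random instead of trained, the projection has no bias term, and the positional embeddings are added after the class embedding is prepended, so that the class embedding gets a position too (as we understand the paper's formulation).

```python
import random


def patchify(image, patch):
    """Split an H x W x C image (nested lists) into flattened patches, row by row."""
    height, width = len(image), len(image[0])
    assert height % patch == 0 and width % patch == 0, "image must divide evenly"
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            flat = []
            for y in range(top, top + patch):
                for x in range(left, left + patch):
                    flat.extend(image[y][x])  # append all channels of this pixel
            patches.append(flat)
    return patches


def random_matrix(rows, cols, rng):
    """Stand-in for a trainable parameter matrix (random, not learned)."""
    return [[rng.gauss(0.0, 0.02) for _ in range(cols)] for _ in range(rows)]


def project(vector, weights):
    """Linear projection: a length-P vector times a P x D weight matrix."""
    dim = len(weights[0])
    return [sum(v * row[d] for v, row in zip(vector, weights)) for d in range(dim)]


def vit_input_sequence(image, patch, weights, class_token, positions):
    """Build the transformer input: class embedding, then projected patches, plus positions."""
    tokens = [class_token] + [project(p, weights) for p in patchify(image, patch)]
    return [[t + p for t, p in zip(tok, pos)] for tok, pos in zip(tokens, positions)]


rng = random.Random(0)
H, W, C, P, D = 8, 8, 3, 4, 6  # toy sizes chosen for illustration
image = [[[rng.random() for _ in range(C)] for _ in range(W)] for _ in range(H)]
num_patches = (H // P) * (W // P)
weights = random_matrix(P * P * C, D, rng)
class_token = [0.0] * D
positions = random_matrix(num_patches + 1, D, rng)

sequence = vit_input_sequence(image, P, weights, class_token, positions)
print(len(sequence), len(sequence[0]))  # 4 patches + 1 class embedding, each of dimension 6
```

With the 16 X 16 patches from the paper's title, the 256 X 256 image from the previous section becomes a sequence of only 256 patches (257 with the class embedding), instead of about 65k pixel tokens.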
Inductive Bias
An interesting observation about the vision transformer is its reduced inductive bias compared to convolutional neural networks. CNNs have a strong inductive bias. For a given input, each unit in the first layer has a local view of the input image, with different units looking at different parts of the image. In the following layer, each unit has a local view of the previous layer’s output, but its receptive field over the input image is larger. As we move toward the last layers, each unit has a very wide receptive field, so it can take into account information from the entire image. In this way, we give the CNN model a lot of guidance about how to process the image. Vision transformers are different: thanks to self-attention, each patch token can attend to any other token in every layer, so the first layers, and not only the last ones, can take into account information from the entire image. So, with ViT we provide much less guidance to the model. Even the positional embeddings are trained, so the model needs to learn them from scratch.
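As a rough illustration of how the receptive field of a CNN grows layer by layer, the sketch below assumes a simplified stack of stride-1, undilated convolutions with no pooling. Real CNNs use strides and pooling, which widen the receptive field faster, so treat the numbers as a lower bound for intuition only. The function name is hypothetical.

```python
def conv_receptive_field(num_layers, kernel_size=3):
    """Receptive field (pixels per side) of stacked stride-1, undilated convolutions."""
    return 1 + num_layers * (kernel_size - 1)


for layers in (1, 2, 4, 8):
    side = conv_receptive_field(layers)
    print(f"after {layers} conv layer(s): each unit sees a {side} X {side} region")
```

In this simplified setting, a unit sees only a 3 X 3 region after one layer and a 17 X 17 region after eight layers, whereas a ViT patch token can attend to every other patch token already in the first layer.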
References & Links
- Paper page – https://arxiv.org/abs/2010.11929
- Video – https://youtu.be/NetSJM590Lo
- Join our newsletter to receive concise 1 minute summaries of the papers we review – https://aipapersacademy.com/newsletter/
All credit for the research goes to the researchers who wrote the paper we covered in this post.