Vision Transformers for Computer Vision
Mike Wang, John Inacay, and Wiley Wang (All authors contributed equally)
When Transformers were first introduced, they were applied mainly to Natural Language Processing (NLP). Transformers are one of the main building blocks of neural network architectures such as BERT and GPT-3. Following their success in NLP, many researchers have begun applying the Transformer architecture to other fields such as Computer Vision. The Vision Transformer, created by Google Research, Brain Team, is one such architecture: it uses a Transformer-based model to tackle the problem of Image Classification. For our code changes, we started from the original Vit-Pytorch Cats and Dogs code here.
Quick Recap on Transformers
We have previously written about Transformers here. To briefly recap, a transformer is a network architecture that takes in a sequence of tokens, mixes them, and outputs a new token sequence in which each individual token carries “context” information from the rest of the sequence. The trick with the Vision Transformer is adapting an image into such a token sequence.
What is the Vision Transformer?
As you may know, transformers are sequence-to-sequence models. However, a single RGB frame isn’t typically considered a sequence. So how do you apply a sequence-to-sequence model to a 224 x 224 pixel image? The Vision Transformer solves this problem by breaking the input image into a 16x16 grid of patches. Each patch is a 14x14 pixel subsection of the image, which we flatten and then linearly project to an embedding. We then also flatten the grid itself to create a “sequence” of 256 patches, where each patch’s 2D grid position maps to a 1D position in the sequence. This 256-patch sequence can then be fed into the Transformer. In NLP terms, each patch in the image problem is analogous to a word in the language problem. Using this sequence of patches as input, the Vision Transformer then outputs a predicted class for the whole image.
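The patching step can be sketched with a plain reshape and permute in PyTorch. This is a minimal illustration using the numbers above (224 x 224 image, 14x14 patches, 256 patches); the helper name `image_to_patches` and the embedding width of 768 are hypothetical choices of ours, not taken from the original code.

```python
import torch
import torch.nn as nn

def image_to_patches(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split (batch, channels, H, W) images into (batch, num_patches, patch_size * patch_size * channels)."""
    b, c, h, w = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch size"
    gh, gw = h // patch_size, w // patch_size
    # (b, c, gh, p, gw, p): separate grid rows/cols from the pixels inside each patch
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    # (b, gh, gw, p, p, c): bring the grid coordinates to the front
    x = x.permute(0, 2, 4, 3, 5, 1)
    # flatten the 2D grid row-major into a 1D sequence, and each patch into one vector
    return x.reshape(b, gh * gw, patch_size * patch_size * c)

patch_size, dim = 14, 768                         # dim is a hypothetical embedding width
images = torch.randn(2, 3, 224, 224)              # a batch of two RGB images
patches = image_to_patches(images, patch_size)
print(patches.shape)                              # torch.Size([2, 256, 588])

to_embedding = nn.Linear(patch_size * patch_size * 3, dim)  # the linear embedding
tokens = to_embedding(patches)
print(tokens.shape)                               # torch.Size([2, 256, 768])
```

The row-major flattening puts the top-left patch first and the bottom-right patch last, which is the 2D-to-1D mapping described above.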
What is the Class Token?
One of the interesting things about the Vision Transformer is that the architecture uses a Class Token. The Class Token is a randomly initialized, learnable token that is prepended to the beginning of the input sequence. What is the reason for this Class Token and what does it do? Because the Class Token is randomly initialized, it doesn’t contain any useful information on its own. However, it accumulates information from the other tokens in the sequence as it passes through successive Transformer layers, and a deeper Transformer gives it more opportunities to do so. When the Vision Transformer performs the final classification, it uses an MLP head that looks only at the last layer’s output for the Class Token and at no other information. This suggests that the Class Token acts as a placeholder that stores information extracted from the other tokens in the sequence. By allocating an empty token for this purpose, the Vision Transformer seems less likely to bias the final output towards or against any single one of the other tokens.
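A minimal PyTorch sketch of this mechanism, under our own assumptions: the module name `ClassTokenReadout`, the depth and head counts, and the single linear layer standing in for the MLP head are all hypothetical, and positional encodings are left out to keep the focus on the Class Token.

```python
import torch
import torch.nn as nn

class ClassTokenReadout(nn.Module):
    """Hypothetical module: prepend a learnable Class Token, encode, classify from that token only."""

    def __init__(self, dim: int, num_classes: int, depth: int = 2, heads: int = 4):
        super().__init__()
        # randomly initialized, then learned during training
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(dim, num_classes)  # simplified stand-in for the MLP head

    def forward(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        b = patch_tokens.shape[0]
        cls = self.cls_token.expand(b, -1, -1)       # one copy of the token per image
        x = torch.cat([cls, patch_tokens], dim=1)    # (b, 1 + num_patches, dim)
        x = self.encoder(x)                          # attention lets the token gather context
        return self.head(x[:, 0])                    # read out the Class Token only

torch.manual_seed(0)
model = ClassTokenReadout(dim=64, num_classes=2)
patch_tokens = torch.randn(3, 256, 64)               # 256 patch embeddings per image
logits = model(patch_tokens)
print(logits.shape)                                  # torch.Size([3, 2])
```

Slicing `x[:, 0]` is what makes the head ignore every patch token directly; any patch information must reach the classifier through the Class Token.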
Positional Encoding
When researchers use transformers to build language models, they typically encode each word’s position within the input sequence by adding a positional encoding to that word’s embedding. Likewise, the Vision Transformer adds a positional encoding to each patch embedding. In this way, the Vision Transformer can tell that the top-left patch is the first token and the bottom-right patch is the last token.
The Program Setup
We downloaded the Dogs vs Cats dataset from here. We split the data 80/20 into training and validation sets. In our example code, we modified the original Vision Transformer implementation, which uses Linformer, to use PyTorch’s built-in Transformer instead.
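Both setup steps can be sketched in a few lines of PyTorch. This is a minimal illustration, not our exact code: the file list is a hypothetical stand-in for the Dogs vs Cats images, and the dimension, head count, depth, and sequence length are placeholder values. We assume `torch.utils.data.random_split` for the 80/20 split and `nn.TransformerEncoder` as the built-in Transformer.

```python
import torch
import torch.nn as nn
from torch.utils.data import random_split

# Hypothetical list standing in for the Dogs vs Cats training files
files = [f"img_{i}.jpg" for i in range(1000)]
n_train = int(0.8 * len(files))
train_set, valid_set = random_split(
    files,
    [n_train, len(files) - n_train],
    generator=torch.Generator().manual_seed(42),  # reproducible split
)
print(len(train_set), len(valid_set))  # 800 200

# PyTorch's built-in encoder in place of Linformer; hyperparameters are placeholders
dim, heads, depth = 128, 8, 6
encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

x = torch.randn(4, 50, dim)  # (batch, sequence length, dim); 50 tokens is a placeholder
print(transformer(x).shape)  # torch.Size([4, 50, 128])
```

One trade-off of this swap: Linformer approximates attention to reduce its cost on long sequences, while the built-in encoder computes full attention, which is exact but scales quadratically with the number of patches.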
For our Vision Transformer implementation, we tried out several layer configurations for the MLP classification head and chose two fully connected layers.
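A head with two fully connected layers might look like the sketch below. The hidden width, the `LayerNorm` in front, and the `GELU` activation between the layers are our assumptions for illustration; the text above only fixes the number of fully connected layers.

```python
import torch
import torch.nn as nn

dim, hidden, num_classes = 128, 256, 2  # hidden width is a hypothetical choice

# Two fully connected layers applied to the final Class Token embedding
mlp_head = nn.Sequential(
    nn.LayerNorm(dim),             # assumed normalization before the head
    nn.Linear(dim, hidden),        # first fully connected layer
    nn.GELU(),                     # assumed nonlinearity between the two layers
    nn.Linear(hidden, num_classes),  # second fully connected layer: cat vs dog logits
)

cls_output = torch.randn(4, dim)   # Class Token outputs for a batch of 4 images
logits = mlp_head(cls_output)
print(logits.shape)                # torch.Size([4, 2])
```

Without the nonlinearity, the two linear layers would collapse into a single linear map, so the activation is what makes the second layer worth having.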
Conclusion
Why is the Vision Transformer important? The Vision Transformer allows us to apply a neural network architecture traditionally used for building language models to the field of Computer Vision. It also lets us formulate image recognition as a sequence-to-sequence problem. Instead of interpreting an image as a matrix of pixels, Vision Transformers show us that we can interpret an image as a sequence of patches. By formulating image recognition in this way, we can apply the relatively new Transformer architecture to the much older problem of image recognition.