This repository contains a from-scratch implementation of the Vision Transformer (ViT) model in PyTorch. The model is built using the core concepts from the original paper, "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", and is trained on the classic MNIST dataset.
The entire implementation, from data loading to model definition and training, is contained in the main Jupyter Notebook.
This code is a simplified implementation based on the architecture described in the original Google Research paper:
> *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.* Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, et al. arXiv:2010.11929
Follow these steps to set up your environment and run the code.
This project uses Python 3.13. It is highly recommended to use a virtual environment.
```bash
# Create a virtual environment
python3 -m venv vit-env

# Activate the virtual environment
# On macOS/Linux:
source vit-env/bin/activate
# On Windows:
.\vit-env\Scripts\activate
```

Install the necessary Python libraries using pip:

```bash
pip install torch torchvision numpy jupyter
```

All the code is in the `notebook.ipynb` notebook.
- Launch Jupyter Notebook from your activated environment:

  ```bash
  jupyter notebook
  ```

- Open the `notebook.ipynb` notebook.
- Run all cells to download the MNIST dataset, build the model, and execute the training process.
The following are the key components of the ViT architecture.
This is the most critical part of the ViT. It is responsible for converting a 2D image into a 1D sequence of token embeddings, analogous to a sequence of word embeddings in NLP.
We use a 2D convolutional layer (`nn.Conv2d`) to achieve this. By setting both `kernel_size` and `stride` equal to `patch_size`, the `nn.Conv2d` layer performs exactly one operation per image patch, effectively splitting the image and applying a linear projection in a single step.
This is the key data flow within the `PatchEmbedding` module:

1. **Input:** `(B, C, H, W)`
   - `B` = Batch Size (e.g., 128)
   - `C` = Channels (1 for MNIST)
   - `H` = Image Height (32)
   - `W` = Image Width (32)
2. **`self.linear_project(x)` -> `(B, d_model, P_col, P_row)`**
   - The `nn.Conv2d` layer projects the patches.
   - `d_model` = Embedding dimension (9)
   - `P_col` = Patches in height (32 / 16 = 2)
   - `P_row` = Patches in width (32 / 16 = 2)
   - Shape becomes: `(128, 9, 2, 2)`
3. **`x.flatten(2)` -> `(B, d_model, P)`**
   - Flattens the patch grid into a single dimension.
   - `P` = Total patches (2 * 2 = 4)
   - Shape becomes: `(128, 9, 4)`
4. **`x.transpose(1, 2)` -> `(B, P, d_model)`**
   - Transposes the dimensions to get the desired `(Batch, Sequence_Length, Embedding_Dim)` format.
   - Shape becomes: `(128, 4, 9)`

This sequence of 4 patch embeddings is now ready to be fed into the Transformer.
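The data flow above can be sketched as a module like the following. The class and argument names (`d_model`, `patch_size`, etc.) mirror the shapes discussed here, but the notebook's exact code may differ:

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Splits an image into patches and linearly projects each one."""

    def __init__(self, d_model=9, patch_size=16, n_channels=1):
        super().__init__()
        # kernel_size == stride == patch_size: the conv takes exactly one
        # step per patch, splitting and projecting in a single operation.
        self.linear_project = nn.Conv2d(
            n_channels, d_model, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.linear_project(x)  # (B, C, H, W) -> (B, d_model, P_col, P_row)
        x = x.flatten(2)            # -> (B, d_model, P)
        return x.transpose(1, 2)    # -> (B, P, d_model)
```

With the defaults above, a `(128, 1, 32, 32)` MNIST batch comes out as `(128, 4, 9)`, matching step 4 of the data flow.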
This module adds two crucial components:

- **[CLS] Token:** A learnable parameter (`nn.Parameter`) that is prepended to the sequence of patch embeddings. The Transformer's output for this single token is used for the final classification.
- **Positional Embeddings:** Standard sine/cosine positional embeddings are added to the patch embeddings. This is essential because the Transformer's self-attention mechanism is permutation-invariant, so we must provide explicit information about the relative positions of the patches.
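A minimal sketch of such a module is shown below, assuming the shapes used throughout this README (`d_model=9`, 4 patches + 1 [CLS] token). The handling of the odd embedding dimension in the sine/cosine table is an implementation detail of this sketch, not necessarily the notebook's:

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Prepends a learnable [CLS] token and adds fixed sine/cosine
    positional embeddings to the patch sequence."""

    def __init__(self, d_model=9, max_seq_length=5):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Standard sine/cosine table: sin on even indices, cos on odd ones.
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_seq_length, d_model)

    def forward(self, x):
        # x: (B, P, d_model) -> prepend [CLS] -> (B, P + 1, d_model)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls, x), dim=1)
        return x + self.pe  # broadcast the positional table over the batch
```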
This is a standard Transformer encoder block, consisting of:
- Layer Normalization
- Multi-Head Self-Attention (MHA)
- Residual (skip) Connection
- Layer Normalization
- MLP (Feed-Forward) Network
- Residual (skip) Connection
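The six steps above can be sketched as a pre-norm encoder block like the one below. The head count (`n_heads=3`, which divides `d_model=9`) and the 4x MLP expansion are assumptions for illustration:

```python
import torch
import torch.nn as nn

class TransformerEncoder(nn.Module):
    """Pre-norm encoder block: LN -> MHA -> residual, then LN -> MLP -> residual."""

    def __init__(self, d_model=9, n_heads=3, r_mlp=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * r_mlp),
            nn.GELU(),
            nn.Linear(d_model * r_mlp, d_model),
        )

    def forward(self, x):
        h = self.ln1(x)
        attn_out, _ = self.mha(h, h, h)  # self-attention: Q = K = V
        x = x + attn_out                 # first residual connection
        x = x + self.mlp(self.ln2(x))    # second residual connection
        return x
```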
This final module assembles all the pieces:

- It passes the input images through `PatchEmbedding`.
- It prepends the `[CLS]` token and adds `PositionalEncoding`.
- It passes the resulting sequence through a stack of `n_layers` (e.g., 3) `TransformerEncoder` blocks.
- It takes only the output corresponding to the `[CLS]` token (`x[:, 0]`) and passes it to a final `Classifier` (an `nn.Linear` layer) to get the class logits.
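An end-to-end sketch of this assembly is shown below. For brevity it uses PyTorch's built-in `nn.TransformerEncoderLayer` and a learnable positional table rather than the hand-written modules and sine/cosine embeddings described above; the notebook's custom modules would slot into the same places:

```python
import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    """Compact ViT sketch: patchify -> [CLS] + positions -> encoder stack -> classify."""

    def __init__(self, d_model=9, n_classes=10, img_size=32, patch_size=16,
                 n_channels=1, n_heads=3, n_layers=3):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(n_channels, d_model,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        # Learnable positional table (a simplification of the sine/cosine scheme).
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, d_model))
        layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=4 * d_model,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.classifier = nn.Linear(d_model, n_classes)

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, P, d_model)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls, x), dim=1) + self.pos_embed     # (B, P + 1, d_model)
        x = self.encoder(x)
        return self.classifier(x[:, 0])                     # logits from the [CLS] token
```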
The model was trained on the following high-performance GPU:
- GPU: NVIDIA H100 NVL
- CUDA Version: 12.4
The model was trained for 10 epochs with a learning rate of 0.005.
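A minimal sketch of such a loop is shown below, using those hyperparameters (10 epochs, `lr=0.005`). The stand-in `nn.Sequential` model, the synthetic batch, and the choice of Adam are assumptions for illustration; in the notebook the assembled ViT and a real MNIST `DataLoader` take their places:

```python
import torch
import torch.nn as nn

# Stand-in model for illustration; the notebook trains the assembled ViT.
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)  # optimizer choice is an assumption
criterion = nn.CrossEntropyLoss()

# One synthetic MNIST-sized batch; the notebook iterates over a DataLoader.
images = torch.randn(128, 1, 32, 32)
labels = torch.randint(0, 10, (128,))

for epoch in range(10):
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1}: loss {loss.item():.3f}")
```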
| Epoch | Loss |
|---|---|
| 1 | 1.768 |
| 2 | 1.634 |
| 3 | 1.622 |
| 4 | 1.613 |
| 5 | 1.605 |
| 6 | 1.559 |
| 7 | 1.542 |
| 8 | 1.536 |
| 9 | 1.535 |
| 10 | 1.530 |
The model achieved 93% accuracy on the MNIST test set.