Vision Transformer (ViT) from Scratch on MNIST

This repository contains a from-scratch implementation of the Vision Transformer (ViT) model in PyTorch. The model is built using the core concepts from the original paper, "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", and is trained on the classic MNIST dataset.

The entire implementation, from data loading to model definition and training, is contained in the main Jupyter Notebook.

📜 Reference Paper

This code is a simplified implementation based on the architecture described in the original Google Research paper:

"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, et al.
arXiv:2010.11929

🚀 Getting Started

Follow these steps to set up your environment and run the code.

1. Environment Setup

This project uses Python 3.13. It is highly recommended to use a virtual environment.

# Create a virtual environment
python3 -m venv vit-env

# Activate the virtual environment
# On macOS/Linux:
source vit-env/bin/activate
# On Windows:
.\vit-env\Scripts\activate

2. Install Dependencies

Install the necessary Python libraries using pip:

pip install torch torchvision numpy jupyter

🏃‍♂️ Running the Code

All the code lives in notebook.ipynb.

  1. Launch Jupyter Notebook from your activated environment:
    jupyter notebook
  2. Open notebook.ipynb.
  3. Run all cells to download the MNIST dataset, build the model, and execute the training process.

💡 Code Explained

The following are the key components of the ViT architecture.

1. PatchEmbedding

This is the most critical part of the ViT. It converts a 2D image into a 1D sequence of token embeddings, just as sentences are converted into sequences of word embeddings in NLP.

We cleverly use a 2D Convolutional layer (nn.Conv2d) to achieve this. By setting the kernel_size and stride to be equal to the patch_size, the nn.Conv2d layer performs one operation for each image patch, effectively splitting the image and applying a linear projection at the same time.

Shape Transformation:

This is the key data flow within the PatchEmbedding module:

  1. Input: (B, C, H, W)

    • B = Batch Size (e.g., 128)
    • C = Channels (1 for MNIST)
    • H = Image Height (32)
    • W = Image Width (32)
  2. self.linear_project(x) -> (B, d_model, P_col, P_row)

    • The nn.Conv2d layer projects the patches.
    • d_model = Embedding dimension (9)
    • P_col = Patches in height (32 / 16 = 2)
    • P_row = Patches in width (32 / 16 = 2)
    • Shape becomes: (128, 9, 2, 2)
  3. x.flatten(2) -> (B, d_model, P)

    • Flattens the patch grid into a single dimension.
    • P = Total patches (2 * 2 = 4)
    • Shape becomes: (128, 9, 4)
  4. x.transpose(1, 2) -> (B, P, d_model)

    • Transposes the dimensions to get the desired (Batch, Sequence_Length, Embedding_Dim) format.
    • Shape becomes: (128, 4, 9)
    • This sequence of 4 patch embeddings is now ready to be fed into the Transformer.
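The shape flow above can be sketched as a small module. This is a minimal illustration, not the notebook's exact code; the class and parameter names (`PatchEmbedding`, `d_model`, `patch_size`, etc.) are assumed from the walkthrough:

```python
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    """Splits an image into patches and linearly projects each one."""
    def __init__(self, d_model=9, patch_size=16, n_channels=1):
        super().__init__()
        # kernel_size == stride == patch_size: one conv step per patch,
        # so splitting and linear projection happen in a single operation.
        self.linear_project = nn.Conv2d(
            n_channels, d_model, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.linear_project(x)  # (B, C, H, W) -> (B, d_model, P_col, P_row)
        x = x.flatten(2)            # -> (B, d_model, P)
        x = x.transpose(1, 2)       # -> (B, P, d_model)
        return x
```

With a (128, 1, 32, 32) MNIST batch, a patch size of 16, and d_model = 9, the output is the (128, 4, 9) sequence described above.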

2. PositionalEncoding

This module adds two crucial components:

  • [CLS] Token: A learnable parameter nn.Parameter that is prepended to the sequence of patch embeddings. The Transformer's output for this single token will be used for the final classification.
  • Positional Embeddings: Standard sine/cosine positional embeddings are added to the patch embeddings. This is essential because the Transformer's self-attention mechanism is permutation-invariant, so we must provide explicit information about the positions of the patches.
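A sketch of both pieces, assuming a learnable [CLS] token plus fixed sine/cosine embeddings as described above (the module and argument names are illustrative, not necessarily those in the notebook):

```python
import torch
from torch import nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model=9, max_seq_length=5):
        super().__init__()
        # Learnable [CLS] token, prepended to every sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Fixed sine/cosine positional embeddings (not learned).
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(max_seq_length, dtype=torch.float)
        for i in range(0, d_model, 2):
            pe[:, i] = torch.sin(position / (10000 ** (i / d_model)))
            if i + 1 < d_model:  # guard for odd d_model (e.g. 9)
                pe[:, i + 1] = torch.cos(position / (10000 ** (i / d_model)))
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_seq_length, d_model)

    def forward(self, x):
        # Prepend [CLS]: (B, P, d_model) -> (B, P+1, d_model)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)
        return x + self.pe[:, : x.size(1)]
```

A (128, 4, 9) patch sequence comes out as (128, 5, 9): four patches plus the [CLS] token.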

3. TransformerEncoder

This is a standard Transformer encoder block, consisting of:

  1. Layer Normalization
  2. Multi-Head Self-Attention (MHA)
  3. Residual (skip) Connection
  4. Layer Normalization
  5. MLP (Feed-Forward) Network
  6. Residual (skip) Connection
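The six steps above map onto a pre-norm encoder block like this sketch (hyperparameters such as n_heads = 3 and the MLP expansion factor are assumptions, not values confirmed by the notebook):

```python
import torch
from torch import nn

class TransformerEncoder(nn.Module):
    def __init__(self, d_model=9, n_heads=3, r_mlp=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * r_mlp),
            nn.GELU(),
            nn.Linear(d_model * r_mlp, d_model),
        )

    def forward(self, x):
        # Steps 1-3: LayerNorm -> Multi-Head Self-Attention -> residual
        h = self.ln1(x)
        attn_out, _ = self.mha(h, h, h)
        x = x + attn_out
        # Steps 4-6: LayerNorm -> MLP -> residual
        x = x + self.mlp(self.ln2(x))
        return x
```

The residual connections add each sublayer's output back onto its input, which keeps gradients flowing through deep stacks of blocks.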

4. VisionTransformer

This final module assembles all the pieces:

  1. It passes the input images through PatchEmbedding.
  2. It prepends the [CLS] token and adds PositionalEncoding.
  3. It passes the resulting sequence through a stack of n_layers (e.g., 3) of TransformerEncoder blocks.
  4. It takes only the output corresponding to the [CLS] token (x[:, 0]) and passes it to a final Classifier (an nn.Linear layer) to get the class logits.
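End to end, the four steps look roughly like the sketch below. To keep it self-contained it uses learnable positional embeddings and PyTorch's built-in nn.TransformerEncoderLayer as stand-ins for the custom modules described above, so it is an approximation of the architecture, not the notebook's code:

```python
import torch
from torch import nn

class VisionTransformer(nn.Module):
    def __init__(self, d_model=9, n_classes=10, img_size=32,
                 patch_size=16, n_channels=1, n_heads=3, n_layers=3):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        # Step 1: patch embedding via a strided conv (see PatchEmbedding)
        self.patch_embedding = nn.Conv2d(
            n_channels, d_model, kernel_size=patch_size, stride=patch_size
        )
        # Step 2: [CLS] token and (here, learnable) positional embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embedding = nn.Parameter(torch.randn(1, n_patches + 1, d_model))
        # Step 3: stack of n_layers encoder blocks (built-in layer as a stand-in)
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model,
            batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        # Step 4: classification head on the [CLS] output
        self.classifier = nn.Linear(d_model, n_classes)

    def forward(self, x):
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)  # (B, P, d_model)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embedding     # (B, P+1, d_model)
        x = self.encoder(x)
        return self.classifier(x[:, 0])                         # class logits
```

Indexing x[:, 0] selects only the [CLS] position, so the classifier sees a single d_model-dimensional summary vector per image.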

💻 Hardware

The model was trained on the following high-performance GPU:

  • GPU: NVIDIA H100 NVL
  • CUDA Version: 12.4

📊 Results

The model was trained for 10 epochs with a learning rate of 0.005.

Training Loss

Epoch Loss
1 1.768
2 1.634
3 1.622
4 1.613
5 1.605
6 1.559
7 1.542
8 1.536
9 1.535
10 1.530

Final Accuracy

The model achieved 93% accuracy on the MNIST test set.
