This repository contains an implementation of VICReg (Variance-Invariance-Covariance Regularization), a self-supervised learning method for representation learning on CIFAR-10.
VICReg is a non-contrastive self-supervised learning method that learns representations through three regularization terms:
- Invariance: Making representations of augmented views of the same image similar
- Variance: Encouraging variance across features to prevent collapse
- Covariance: Decorrelating features to ensure information is distributed across dimensions
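The three terms above can be sketched in a few lines. This is an illustrative re-implementation, not the code in vicreg_loss.py; the `lam`, `mu`, and `nu` arguments mirror the LAMBDA, MU, and NU weights configured in config.py:

```python
import torch
import torch.nn.functional as F

def vicreg_loss_sketch(z_a, z_b, lam=25.0, mu=25.0, nu=1.0, eps=1e-4):
    """Illustrative VICReg loss for two batches of projections, each (N, D)."""
    N, D = z_a.shape

    # Invariance: mean squared error between the two views' projections
    inv_loss = F.mse_loss(z_a, z_b)

    # Variance: hinge loss pushing each dimension's std dev toward at least 1,
    # which prevents the representations from collapsing to a constant
    std_a = torch.sqrt(z_a.var(dim=0) + eps)
    std_b = torch.sqrt(z_b.var(dim=0) + eps)
    var_loss = torch.mean(F.relu(1.0 - std_a)) + torch.mean(F.relu(1.0 - std_b))

    # Covariance: penalize squared off-diagonal entries of the covariance
    # matrix, decorrelating the feature dimensions
    z_a_c = z_a - z_a.mean(dim=0)
    z_b_c = z_b - z_b.mean(dim=0)
    cov_a = (z_a_c.T @ z_a_c) / (N - 1)
    cov_b = (z_b_c.T @ z_b_c) / (N - 1)

    def off_diag(c):
        return c.pow(2).sum() - c.pow(2).diagonal().sum()

    cov_loss = off_diag(cov_a) / D + off_diag(cov_b) / D

    return lam * inv_loss + mu * var_loss + nu * cov_loss
```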
This implementation includes:
- VICReg model training and evaluation
- Linear probe evaluation for downstream task performance
- Embedding visualization using PCA and t-SNE
- Ablation studies (e.g., removing variance loss)
- Clustering evaluation
- Image retrieval evaluation
```
.
├── main.py              # Main script with experiment pipelines
├── models.py            # Neural network architectures (Encoder, Projector, LinearProber)
├── train_functions.py   # Training and evaluation functions
├── vicreg_loss.py       # VICReg loss computation
├── cifar_loader.py      # Data loading utilities
├── augmentations.py     # Data augmentation transforms
├── plot_functions.py    # Visualization functions
├── clustering.py        # Clustering evaluation
├── config.py            # Configuration and hyperparameters
├── report.pdf           # Project report
└── README.md            # This file
```
Install the required dependencies:
```bash
pip install torch torchvision numpy scikit-learn matplotlib tqdm
```

Train a VICReg model from scratch:

```python
from main import run_vicreg_training_pipeline

vicreg_model, encoder, projector, train_loss, test_loss = run_vicreg_training_pipeline()
```

Visualize learned representations using PCA and t-SNE:
```python
from main import run_vicreg_embedding_visualization
from config import ENCODER_Q1_PATH

run_vicreg_embedding_visualization(checkpoint_path=str(ENCODER_Q1_PATH))
```

Evaluate representations using a linear probe:
```python
import torch

from main import run_linear_prober_evaluation
from models import Encoder
from config import ENCODER_DIM, DEVICE, ENCODER_Q1_PATH

encoder = Encoder(D=ENCODER_DIM, device=DEVICE).to(DEVICE)
encoder.load_state_dict(torch.load(ENCODER_Q1_PATH, map_location=DEVICE))
encoder.eval()

lp_model, train_accuracies, test_acc = run_linear_prober_evaluation(encoder)
```

To run all experiments (Q1-Q7), execute:

```bash
python main.py
```

This will:
- Train a VICReg model (Q1)
- Visualize embeddings (Q2)
- Evaluate with linear probe (Q3)
- Run ablation study without variance loss (Q4)
- Train with nearest neighbor augmentation (Q5)
- Evaluate retrieval performance (Q7)
Evaluate clustering quality:
```bash
python clustering.py
```

All hyperparameters can be modified in config.py:
- Model architecture: ENCODER_DIM, PROJ_DIM
- Training: NUM_OF_EPOCHS, LEARNING_RATE, BATCH_SIZE
- VICReg loss weights: LAMBDA (invariance), MU (variance), NU (covariance)
- Device: automatically set to CUDA if available
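For orientation, a config.py along these lines is what the code above imports from. The values below are hypothetical placeholders; check the actual file for the values used in the experiments:

```python
# Hypothetical sketch of config.py; real names match the list above,
# but the numeric values here are placeholders, not the repo's settings.
import torch

# Model architecture
ENCODER_DIM = 128      # encoder output dimension
PROJ_DIM = 512         # projector output dimension

# Training
NUM_OF_EPOCHS = 30
LEARNING_RATE = 3e-4
BATCH_SIZE = 256

# VICReg loss weights
LAMBDA = 25.0  # invariance term
MU = 25.0      # variance term
NU = 1.0       # covariance term

# Device: CUDA if available, otherwise CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```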
- Encoder: ResNet18-based encoder modified for CIFAR-10 (32x32 images)
- Projector: 3-layer MLP that maps encoder features to projection space
- LinearProber: Linear classifier for evaluating learned representations
Strong augmentations applied during self-supervised training:
- Random resized crop
- Random horizontal flip
- Color jitter
- Random grayscale
- Gaussian blur
- Linear Probing: Train a linear classifier on frozen representations
- Embedding Visualization: PCA and t-SNE plots to visualize learned representations
- Clustering: KMeans clustering with silhouette score evaluation
- Retrieval: Nearest neighbor retrieval using cosine similarity
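The retrieval step can be sketched as a cosine-similarity top-k lookup. This is an illustrative helper, not the repo's own retrieval code, which may differ in detail:

```python
import torch
import torch.nn.functional as F

def retrieve_nearest(query_emb, gallery_emb, k=5):
    """Indices of the k nearest gallery embeddings by cosine similarity.

    query_emb: (Q, D) query embeddings; gallery_emb: (N, D) gallery embeddings.
    """
    # L2-normalize so the dot product equals cosine similarity
    q = F.normalize(query_emb, dim=1)
    g = F.normalize(gallery_emb, dim=1)
    sims = q @ g.T                     # (Q, N) similarity matrix
    return sims.topk(k, dim=1).indices # (Q, k) nearest-neighbor indices
```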
Results and analysis are documented in report.pdf.
Omer Barlev
[Research purposes]