barlevo/VICReg-model-for-SSL

VICReg: Representation Learning on CIFAR-10

This repository contains an implementation of VICReg (Variance-Invariance-Covariance Regularization), a self-supervised learning method for representation learning on CIFAR-10.

Overview

VICReg is a non-contrastive self-supervised learning method (it needs no negative pairs) that learns representations by optimizing three complementary objectives:

  1. Invariance: Making representations of augmented views of the same image similar
  2. Variance: Encouraging variance across features to prevent collapse
  3. Covariance: Decorrelating features to ensure information is distributed across dimensions
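The three terms above can be sketched as a single loss function. This is an illustrative PyTorch implementation of the VICReg objective, not necessarily identical to the repo's `vicreg_loss.py` (function and argument names here are my own; default weights follow the original VICReg paper):

```python
import torch
import torch.nn.functional as F

def vicreg_loss(z_a, z_b, lambda_=25.0, mu=25.0, nu=1.0, eps=1e-4):
    """Sketch of the VICReg objective for two (N, D) batches of projections."""
    N, D = z_a.shape
    # Invariance: mean-squared error between the two augmented views
    inv_loss = F.mse_loss(z_a, z_b)
    # Variance: hinge loss keeping each dimension's std above 1 (prevents collapse)
    std_a = torch.sqrt(z_a.var(dim=0) + eps)
    std_b = torch.sqrt(z_b.var(dim=0) + eps)
    var_loss = torch.relu(1.0 - std_a).mean() + torch.relu(1.0 - std_b).mean()
    # Covariance: penalize off-diagonal entries of the covariance matrix
    z_a_c = z_a - z_a.mean(dim=0)
    z_b_c = z_b - z_b.mean(dim=0)
    cov_a = (z_a_c.T @ z_a_c) / (N - 1)
    cov_b = (z_b_c.T @ z_b_c) / (N - 1)
    off_diag = lambda m: m.pow(2).sum() - m.diagonal().pow(2).sum()
    cov_loss = off_diag(cov_a) / D + off_diag(cov_b) / D
    return lambda_ * inv_loss + mu * var_loss + nu * cov_loss
```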

This implementation includes:

  • VICReg model training and evaluation
  • Linear probe evaluation for downstream task performance
  • Embedding visualization using PCA and t-SNE
  • Ablation studies (e.g., removing variance loss)
  • Clustering evaluation
  • Image retrieval evaluation

Project Structure

.
├── main.py                  # Main script with experiment pipelines
├── models.py                # Neural network architectures (Encoder, Projector, LinearProber)
├── train_functions.py       # Training and evaluation functions
├── vicreg_loss.py           # VICReg loss computation
├── cifar_loader.py          # Data loading utilities
├── augmentations.py         # Data augmentation transforms
├── plot_functions.py        # Visualization functions
├── clustering.py            # Clustering evaluation
├── config.py                # Configuration and hyperparameters
├── report.pdf               # Project report
└── README.md                # This file

Requirements

Install the required dependencies:

pip install torch torchvision numpy scikit-learn matplotlib tqdm

Usage

Basic Training

Train a VICReg model from scratch:

from main import run_vicreg_training_pipeline

vicreg_model, encoder, projector, train_loss, test_loss = run_vicreg_training_pipeline()

Embedding Visualization

Visualize learned representations using PCA and t-SNE:

from main import run_vicreg_embedding_visualization
from config import ENCODER_Q1_PATH

run_vicreg_embedding_visualization(checkpoint_path=str(ENCODER_Q1_PATH))

Linear Probe Evaluation

Evaluate representations using a linear probe:

from main import run_linear_prober_evaluation
from models import Encoder
from config import ENCODER_DIM, DEVICE, ENCODER_Q1_PATH
import torch

encoder = Encoder(D=ENCODER_DIM, device=DEVICE).to(DEVICE)
encoder.load_state_dict(torch.load(ENCODER_Q1_PATH, map_location=DEVICE))
encoder.eval()

lp_model, train_accuracies, test_acc = run_linear_prober_evaluation(encoder)

Running All Experiments

To run all experiments (Q1-Q7), execute:

python main.py

This will:

  1. Train a VICReg model (Q1)
  2. Visualize embeddings (Q2)
  3. Evaluate with linear probe (Q3)
  4. Run ablation study without variance loss (Q4)
  5. Train with nearest neighbor augmentation (Q5)
  6. Evaluate retrieval performance (Q7)

Clustering Evaluation

Evaluate clustering quality:

python clustering.py

Configuration

All hyperparameters can be modified in config.py:

  • Model Architecture: ENCODER_DIM, PROJ_DIM
  • Training: NUM_OF_EPOCHS, LEARNING_RATE, BATCH_SIZE
  • VICReg Loss: LAMBDA (invariance), MU (variance), NU (covariance)
  • Device: Automatically set to CUDA if available
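For orientation, a possible layout of `config.py` using the names listed above (the values shown are illustrative placeholders, not the repo's actual settings):

```python
import torch

# Model architecture
ENCODER_DIM = 128        # encoder output dimension
PROJ_DIM = 512           # projector output dimension

# Training
NUM_OF_EPOCHS = 30
LEARNING_RATE = 3e-4
BATCH_SIZE = 256

# VICReg loss weights
LAMBDA = 25.0            # invariance term
MU = 25.0                # variance term
NU = 1.0                 # covariance term

# Device: CUDA if available, otherwise CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```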

Key Features

Models

  • Encoder: ResNet18-based encoder modified for CIFAR-10 (32x32 images)
  • Projector: 3-layer MLP that maps encoder features to projection space
  • LinearProber: Linear classifier for evaluating learned representations

Data Augmentation

Strong augmentations for generating the two views used in self-supervised training:

  • Random resized crop
  • Random horizontal flip
  • Color jitter
  • Random grayscale
  • Gaussian blur

Evaluation Methods

  1. Linear Probing: Train a linear classifier on frozen representations
  2. Embedding Visualization: PCA and t-SNE plots to visualize learned representations
  3. Clustering: KMeans clustering with silhouette score evaluation
  4. Retrieval: Nearest neighbor retrieval using cosine similarity
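The retrieval step (item 4) boils down to ranking gallery embeddings by cosine similarity to a query. A minimal sketch, with names of my own choosing:

```python
import torch
import torch.nn.functional as F

def retrieve_nearest(query_emb, gallery_emb, k=5):
    """Return indices of the k most cosine-similar gallery embeddings."""
    q = F.normalize(query_emb, dim=-1)
    g = F.normalize(gallery_emb, dim=-1)
    sims = q @ g.T                      # (num_queries, num_gallery)
    return sims.topk(k, dim=-1).indices
```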

Results

Results and analysis are documented in report.pdf.

Author

Omer Barlev

License

[Research purposes]

About

Self-supervised learning (SSL) has been shown to be highly effective for many downstream tasks. This repository implements a state-of-the-art SSL method and experiments with several downstream applications of self-supervised learning, exploring its usefulness.
