
rishit2003/RNA-Prediction-CNN


RNA Contact Map Prediction using Dilated U-Net

This repository implements a biologically informed Dilated U-Net to predict RNA base-pair contact maps, achieving ~0.70 F1 on the Kaggle private leaderboard under strict data and architectural constraints.

Project Overview

This project tackles RNA secondary structure prediction as a pairwise contact map classification problem, where the goal is to predict which nucleotide positions (i, j) form base pairs.

The task is formulated as a pixel-wise binary classification problem over an (L × L) contact map, where L is the RNA sequence length. A Dilated U-Net CNN is trained to predict base-pair probabilities using biologically motivated pairwise features.
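To make the (L × L) target concrete, here is a minimal sketch of turning a list of base pairs into a binary contact map. It assumes 0-indexed positions and a symmetric map (each pair marked at both (i, j) and (j, i)); the helper name `pairs_to_contact_map` is hypothetical and not taken from the repository.

```python
import numpy as np

def pairs_to_contact_map(length, pairs):
    """Build a symmetric (L, L) binary contact map from 0-indexed (i, j) base pairs.

    Hypothetical helper; symmetry and 0-indexing are assumptions.
    """
    contact = np.zeros((length, length), dtype=np.float32)
    for i, j in pairs:
        contact[i, j] = 1.0
        contact[j, i] = 1.0  # pairing is symmetric, so mark both (i, j) and (j, i)
    return contact

# 8-nt hairpin-style example: positions 0-7 and 1-6 pair.
cmap = pairs_to_contact_map(8, [(0, 7), (1, 6)])
print(cmap.shape, int(cmap.sum()))  # (8, 8) 4
```

Almost every cell of such a map is 0, which is the class imbalance the loss function has to deal with.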

Performance is evaluated using F1-score on the Kaggle Private Leaderboard.


Constraints

This project was developed under the following constraints:

  • No external datasets or pre-trained models
  • Fixed Kaggle evaluation protocol
  • Variable-length RNA sequences
  • Batch size constrained to 1 due to variable-length sequences

Model Architecture

The model is a Dilated U-Net designed to capture:

  • Local RNA interactions near the main diagonal of the contact map
  • Long-range dependencies between distant nucleotide pairs

This is achieved using the following architectural components:

  • Encoder–decoder U-Net backbone
  • Dilated convolutions in the encoder with exponentially increasing dilation (2^d) to expand the receptive field without loss of resolution
  • Skip connections between encoder and decoder layers to preserve spatial detail
  • Group Normalization, chosen for stability with batch size = 1 and variable-length inputs
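Why Group Normalization suits batch size 1 can be seen from what it computes: statistics are taken per sample over groups of channels and all spatial positions, so nothing depends on other items in the batch or on a fixed L. The NumPy sketch below mirrors the normalization step of `torch.nn.GroupNorm` without its learnable scale and shift; `num_groups=8` is an assumed value, not one stated in this README.

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5):
    """Group Normalization of an (N, C, H, W) array, without the learnable scale/shift.

    Mean and variance are taken per sample over each group of C // num_groups
    channels and all spatial positions, so no statistic is shared across the batch.
    """
    n, c, h, w = x.shape
    if c % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    return ((grouped - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)

rng = np.random.default_rng(0)
single = rng.normal(size=(1, 64, 37, 37))  # batch size 1, sequence length L = 37
out_single = group_norm(single, num_groups=8)  # num_groups=8 is an assumption

# Normalizing the same sample inside a larger batch gives identical output.
batch = np.concatenate([single, 5.0 * rng.normal(size=(1, 64, 37, 37))])
out_batch = group_norm(batch, num_groups=8)
print(np.allclose(out_batch[:1], out_single))  # True
```

Batch Normalization, by contrast, would estimate its statistics from a single (1, C, L, L) map at every step, which is noisy and changes with sequence length.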

The final model configuration used for training:

  • in_channels = 10
  • base_channels = 64
  • depth = 4
  • dropout = 0.2
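A quick way to see what exponentially increasing dilation buys at depth = 4 is to compute the receptive field of stacked stride-1 dilated convolutions. This sketch assumes 3×3 kernels, dilation 2^d at level d, and ignores any pooling between levels; the real encoder may differ, so treat the numbers as illustrative only.

```python
def receptive_field(depth, kernel_size=3, convs_per_level=1):
    """Receptive field (pixels per axis) of stacked stride-1 dilated convolutions.

    Level d (0-indexed) uses dilation 2**d; each conv adds (kernel_size - 1) * dilation.
    Kernel size, convs per level, and the absence of pooling are assumptions.
    """
    rf = 1
    for d in range(depth):
        dilation = 2 ** d
        rf += convs_per_level * (kernel_size - 1) * dilation
    return rf

for depth in range(1, 5):
    print(depth, receptive_field(depth))
# 1 3
# 2 7
# 3 15
# 4 31
```

With four undilated 3×3 layers the receptive field would be only 9 pixels; doubling the dilation at each level grows it roughly as 2^depth while keeping the feature maps at full L × L resolution.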

Input Features

For each RNA sequence, a (C, L, L) tensor is constructed with 10 channels:

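| Channel | Description |
|---------|-------------|
| 0–3 | One-hot encoding of the base at position i |
| 4–7 | One-hot encoding of the base at position j |
| 8 | Base-pairing compatibility of (i, j): AU, CG, or GU |
| 9 | Normalized distance between positions i and j |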

These channels explicitly inject biological prior knowledge (nucleotide identity, pairing rules, and sequence separation) into the model.
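The channel layout above can be sketched in NumPy as follows. The base order `ACGU` for the one-hot channels and the normalization |i − j| / L for channel 9 are assumptions (the README only says "normalized distance"), and `build_features` is a hypothetical name rather than the repository's function.

```python
import numpy as np

BASES = "ACGU"  # assumed one-hot channel order
PAIRABLE = {("A", "U"), ("U", "A"), ("C", "G"), ("G", "C"), ("G", "U"), ("U", "G")}

def build_features(seq):
    """Return a (10, L, L) float32 feature tensor for an RNA sequence (hypothetical helper)."""
    L = len(seq)
    onehot = np.zeros((L, 4), dtype=np.float32)
    for pos, base in enumerate(seq):
        if base in BASES:  # non-ACGU symbols (e.g. "N") stay all-zero
            onehot[pos, BASES.index(base)] = 1.0

    feats = np.zeros((10, L, L), dtype=np.float32)
    feats[0:4] = onehot.T[:, :, None]  # channels 0-3: base at i, repeated along j
    feats[4:8] = onehot.T[:, None, :]  # channels 4-7: base at j, repeated along i
    for i in range(L):                 # channel 8: AU / CG / GU compatibility
        for j in range(L):
            if (seq[i], seq[j]) in PAIRABLE:
                feats[8, i, j] = 1.0
    idx = np.arange(L)
    feats[9] = np.abs(idx[:, None] - idx[None, :]) / L  # channel 9: |i - j| / L (assumed)
    return feats

x = build_features("GGGAAACCC")
print(x.shape)  # (10, 9, 9)
```

Broadcasting the per-position one-hot vectors along the other axis is what lets a 2D CNN see both nucleotides of a candidate pair at every (i, j) cell.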


Loss Function

To handle extreme class imbalance (many more non-pairs than pairs), a hybrid loss is used:

Loss = BCEWithLogitsLoss(pos_weight) + DiceLoss

Dice loss complements BCE by directly optimizing the overlap between predicted and true contact maps. Because true contacts are so sparse, this keeps the model from collapsing to trivial all-zero predictions.

  • pos_weight is computed from the training set and capped at 12
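A NumPy sketch of the hybrid loss is shown below. The BCE term follows the same formula as PyTorch's `BCEWithLogitsLoss` with `pos_weight` and mean reduction; the Dice smoothing constant (1.0) and computing pos_weight as the negative-to-positive pixel ratio are assumptions, since the README only states that pos_weight comes from the training set and is capped at 12.

```python
import numpy as np

def compute_pos_weight(label_maps, cap=12.0):
    """Negative-to-positive pixel ratio over the training contact maps, capped at `cap` (assumed formula)."""
    positives = sum(float(m.sum()) for m in label_maps)
    total = sum(m.size for m in label_maps)
    return min((total - positives) / max(positives, 1.0), cap)

def bce_with_logits(logits, targets, pos_weight):
    """Mean-reduced weighted BCE on raw logits (same formula as BCEWithLogitsLoss)."""
    # -log(sigmoid(x)) = softplus(-x) and -log(1 - sigmoid(x)) = softplus(x), computed stably
    loss = (pos_weight * targets * np.logaddexp(0.0, -logits)
            + (1.0 - targets) * np.logaddexp(0.0, logits))
    return float(loss.mean())

def soft_dice_loss(logits, targets, smooth=1.0):
    """1 - soft Dice coefficient between sigmoid probabilities and the true map."""
    probs = 0.5 * (1.0 + np.tanh(0.5 * logits))  # numerically stable sigmoid
    intersection = float((probs * targets).sum())
    denom = float(probs.sum() + targets.sum())
    return 1.0 - (2.0 * intersection + smooth) / (denom + smooth)

def hybrid_loss(logits, targets, pos_weight):
    return bce_with_logits(logits, targets, pos_weight) + soft_dice_loss(logits, targets)

# 20 x 20 map with 4 true contacts (2 symmetric pairs).
target = np.zeros((20, 20))
for i, j in [(2, 15), (3, 14)]:
    target[i, j] = target[j, i] = 1.0
pw = compute_pos_weight([target])                # 396 / 4 = 99, capped to 12.0
all_zero = np.full_like(target, -10.0)           # "predict no pairs"
correct = np.where(target == 1.0, 10.0, -10.0)   # confident, correct map
print(pw, hybrid_loss(all_zero, target, pw) > hybrid_loss(correct, target, pw))  # 12.0 True
```

The all-zero prediction is penalized by both terms: pos_weight inflates the BCE cost of every missed contact, and the Dice term stays near its maximum because the overlap is essentially zero.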

Threshold Selection

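The model's raw outputs are logits (as required by BCEWithLogitsLoss); a sigmoid turns them into per-pair probabilities. After each epoch, a threshold sweep on the validation set selects the probability cutoff that maximizes F1-score.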

Thresholds tested (0.05 to 0.95 in steps of 0.05):

np.linspace(0.05, 0.95, 19)

The best threshold and F1-score are logged and saved to metrics.csv.
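The sweep itself can be sketched as follows. Averaging F1 per sequence (rather than pooling all pixels) is an assumption, as is scoring an empty prediction with no true positives as 0; `sweep_thresholds` is a hypothetical name.

```python
import numpy as np

def f1_score(pred, target):
    """F1 for boolean arrays; returns 0.0 when there are no true positives."""
    tp = np.logical_and(pred, target).sum()
    fp = np.logical_and(pred, ~target).sum()
    fn = np.logical_and(~pred, target).sum()
    return 0.0 if tp == 0 else float(2 * tp / (2 * tp + fp + fn))

def sweep_thresholds(prob_maps, label_maps, thresholds=None):
    """Pick the threshold that maximizes mean per-sequence F1 on validation maps."""
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.95, 19)
    best_t, best_f1 = None, -1.0
    for t in thresholds:
        scores = [f1_score(p >= t, y.astype(bool)) for p, y in zip(prob_maps, label_maps)]
        mean_f1 = float(np.mean(scores))
        if mean_f1 > best_f1:  # strict ">" keeps the lowest threshold on ties
            best_t, best_f1 = float(t), mean_f1
    return best_t, best_f1

probs = [np.array([[0.22, 0.87], [0.87, 0.22]])]
labels = [np.array([[0, 1], [1, 0]])]
t_best, f1_best = sweep_thresholds(probs, labels)
print(round(t_best, 2), f1_best)  # 0.25 1.0
```

Thresholds at or below 0.2 keep the two false positives (F1 ≈ 0.67), while anything from 0.25 up to 0.85 recovers exactly the true pairs.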


Training Details

  • Optimizer: Adam
  • Learning rate: 3e-4
  • Weight decay: 1e-4
  • Scheduler: ReduceLROnPlateau
  • Batch size: 1 (variable-length sequences)
  • Train/validation split: 80 / 20
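An 80 / 20 split can be made reproducible with a seeded shuffle, as in this sketch. The shuffling strategy and seed are assumptions; the README does not say how sequences are assigned to each side.

```python
import random

def train_val_split(ids, val_fraction=0.2, seed=42):
    """Shuffle sequence IDs reproducibly and split them into train / validation lists."""
    ids = list(ids)
    random.Random(seed).shuffle(ids)  # local RNG so the global random state is untouched
    n_val = round(len(ids) * val_fraction)
    return ids[n_val:], ids[:n_val]

train_ids, val_ids = train_val_split(range(100))
print(len(train_ids), len(val_ids))  # 80 20
```

Splitting by sequence ID (rather than by contact-map pixels) keeps every pair of a given RNA on the same side, so validation F1 reflects performance on unseen sequences.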

Best model weights are saved automatically to:

best_model.pth

Training metrics are saved to:

metrics.csv


Repository Structure

COEN432-RNA-Contact-Prediction/
├── train.csv
├── test.csv
├── best_model.pth
├── metrics.csv
├── submission.csv
├── coen432-project.ipynb / main.py
├── plots/
│   ├── losses.png
│   ├── val_f1.png
│   ├── thresholds.png
│   └── visualize_sample.png
├── logs/
├── README.md
└── report/
    └── COEN432_Final_Report.pdf

How to Run

1. Train the model

python main.py --mode train --epochs 100

This will:

  • Train the model
  • Save best_model.pth
  • Save metrics.csv
  • Generate plots
  • Run inference
  • Create submission.csv


2. Run inference only

python main.py --mode infer

Uses the saved model weights (best_model.pth) and metrics (metrics.csv) to generate:

submission.csv


3. Visualize predictions

python main.py --mode visualize

Displays:

  • Ground truth contact map
  • Predicted probabilities
  • Binarized prediction

4. Plot training curves

python main.py --mode plot
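The four modes above suggest a command-line interface along the lines of this `argparse` sketch. Only `--mode` (with the values train, infer, visualize, plot) and `--epochs` appear in this README; the help text and the default of 100 epochs are assumptions, and the actual `main.py` may parse its arguments differently.

```python
import argparse

def build_parser():
    """Hypothetical parser mirroring the documented main.py commands."""
    parser = argparse.ArgumentParser(description="RNA contact map prediction (sketch)")
    parser.add_argument("--mode", required=True,
                        choices=["train", "infer", "visualize", "plot"],
                        help="pipeline stage to run")
    # The default of 100 is an assumption; the documented command passes it explicitly.
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of training epochs (train mode only)")
    return parser

args = build_parser().parse_args(["--mode", "train", "--epochs", "100"])
print(args.mode, args.epochs)  # train 100
```

Restricting `--mode` with `choices` makes argparse reject typos such as `--mode inference` with a usage error instead of silently doing nothing.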

Kaggle Submission

The generated submission.csv follows the required format:

id,pairs
SEQ_ID, i-j i-j i-j ...

Empty predictions are replaced with a space " " to avoid Kaggle null-value errors.
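Converting a binarized contact map into the `pairs` field can be sketched as below. Reading only the upper triangle lists each symmetric pair once and skips self-pairs; whether Kaggle expects 1-indexed or 0-indexed positions is not stated here, so `one_indexed` is an assumed parameter and `format_pairs` a hypothetical name.

```python
import numpy as np

def format_pairs(binary_map, one_indexed=True):
    """Turn a binary (L, L) contact map into an 'i-j i-j ...' string for submission.csv."""
    # Upper triangle only (k=1): each symmetric pair appears once and i == j is excluded.
    i_idx, j_idx = np.nonzero(np.triu(binary_map, k=1))
    offset = 1 if one_indexed else 0  # indexing convention is an assumption
    pairs = [f"{i + offset}-{j + offset}" for i, j in zip(i_idx, j_idx)]
    # Empty predictions become a single space so Kaggle does not read a null value.
    return " ".join(pairs) if pairs else " "

m = np.zeros((6, 6), dtype=int)
m[0, 5] = m[5, 0] = 1
m[1, 4] = m[4, 1] = 1
print(format_pairs(m))                     # 1-6 2-5
print(repr(format_pairs(np.zeros((4, 4)))))  # ' '
```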


Results

  • Public Leaderboard F1: ~0.703
  • Private Leaderboard F1: ~0.699

The small gap between the public and private scores suggests that the model, aided by its explicit biological priors, generalizes well to unseen sequences even under strict data and architectural constraints.


About

Deep learning RNA contact map prediction using a Dilated U-Net. Applies dense pairwise classification, dilated CNNs, and hybrid BCE–Dice loss to achieve ~0.70 F1 under strict data constraints.
