PyTorch re-implementation of PredNet (Lotter et al., 2017) trained on Moving MNIST. Two variants are provided: a baseline ConvLSTM-only model and an enhanced version augmented with a sparse Transformer layer.
- Original PredNet: Original_Prenet_Design
- Transformer redesign: TBF
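The baseline's recurrent unit is a ConvLSTM with optional peephole gating (see `conv_lstm_cell_x.py` in the layout below). A minimal sketch of such a cell — simplified for illustration, not the repo's exact implementation:

```python
import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    """Minimal ConvLSTM cell with optional peephole connections (illustrative)."""

    def __init__(self, in_ch, hid_ch, kernel_size=3, peephole=False):
        super().__init__()
        pad = kernel_size // 2
        # One convolution produces all four gate pre-activations at once.
        self.conv = nn.Conv2d(in_ch + hid_ch, 4 * hid_ch, kernel_size, padding=pad)
        self.hid_ch = hid_ch
        self.peephole = peephole
        if peephole:
            # Per-channel peephole weights, broadcast over spatial dims.
            self.w_ci = nn.Parameter(torch.zeros(1, hid_ch, 1, 1))
            self.w_cf = nn.Parameter(torch.zeros(1, hid_ch, 1, 1))
            self.w_co = nn.Parameter(torch.zeros(1, hid_ch, 1, 1))

    def forward(self, x, state=None):
        if state is None:
            b, _, hgt, wid = x.shape
            h = x.new_zeros(b, self.hid_ch, hgt, wid)
            c = x.new_zeros(b, self.hid_ch, hgt, wid)
        else:
            h, c = state
        i, f, g, o = torch.chunk(self.conv(torch.cat([x, h], dim=1)), 4, dim=1)
        if self.peephole:
            # Peepholes let the gates "see" the cell state directly.
            i = i + self.w_ci * c
            f = f + self.w_cf * c
        i, f = torch.sigmoid(i), torch.sigmoid(f)
        c = f * c + i * torch.tanh(g)
        if self.peephole:
            o = o + self.w_co * c
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c
```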
```
PredNet-with-Transformer/
├── original_prednet_pytorch/     # Baseline: PredNet (ConvLSTM only)
│   ├── prednet_x.py              # PredNet model
│   ├── conv_lstm_cell_x.py       # ConvLSTM cell with peephole / gating support
│   ├── mnist_data.py             # Dataset loader
│   ├── mnist_settings.py         # DATA_DIR path helper
│   ├── mnist_train.py            # Single-epoch quick training entry point
│   ├── mnist_train_all.py        # Full training loop (all epochs)
│   ├── history/                  # Saved param snapshots + prediction plots
│   └── models/                   # Saved .pt checkpoints
├── transformer_/                 # Enhanced: PredNet + sparse Transformer
│   └── TBF
├── mnist_npy_data/               # Shared data directory
│   ├── X_train.hkl               # 7000 training sequences
│   ├── X_val.hkl                 # 1000 validation sequences
│   ├── X_test.hkl                # 2000 test sequences
│   ├── process_mnist.py          # Convert .npy → .hkl
│   ├── mnist_test.py             # Sanity-check data loading
│   └── mnist_settings.py         # DATA_DIR (shared reference)
├── data_compare/                 # Cross-model comparison outputs
│   ├── loss_history/             # *.jsonl loss logs + plot_loss_comparison.py
│   └── test_plots/               # MSE-per-timestep JSONs + plot_test_sequence_mse.py
├── Dockerfile                    # CPU image (data prep only)
├── Dockerfile.gpu                # NVIDIA GPU image (training)
├── docker-compose.dev.yml        # Dev container (attach-to-running, SSH + workspace mounts)
├── docker-compose.gpu.yml        # One-shot GPU training run
├── .devcontainer/                # VS Code dev container config (builds Dockerfile.gpu)
├── requirements-docker-cpu.txt   # CPU dependency pins
├── requirements-docker-gpu.txt   # GPU dependency pins (CUDA 11.1 + PyTorch)
└── prednet-env-windows.yml       # Optional conda YAML snapshot
```
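`process_mnist.py` converts the raw Moving MNIST `.npy` into the three `.hkl` splits listed above. The core split logic can be sketched as follows — a simplified stand-in, assuming the standard raw layout of `(frames, sequences, 64, 64)`; the real script additionally writes the arrays with `hickle`:

```python
import numpy as np

def split_moving_mnist(raw, n_train=7000, n_val=1000, n_test=2000):
    """Split raw Moving MNIST (frames, sequences, H, W) into
    (sequences, frames, H, W) train/val/test arrays."""
    seqs = raw.transpose(1, 0, 2, 3)  # -> (sequences, frames, H, W)
    assert seqs.shape[0] >= n_train + n_val + n_test, "not enough sequences"
    X_train = seqs[:n_train]
    X_val = seqs[n_train:n_train + n_val]
    X_test = seqs[n_train + n_val:n_train + n_val + n_test]
    return X_train, X_val, X_test

# The real process_mnist.py would then persist each split, e.g. with hickle:
# hickle.dump(X_train, "mnist_npy_data/X_train.hkl")
```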
Download Moving MNIST (if not already present) and run the preprocessor once:

```bash
# Place the raw file: mnist_npy_data/mnist_test_seq.npy
python mnist_npy_data/process_mnist.py
# Produces: X_train.hkl X_val.hkl X_test.hkl inside mnist_npy_data/
```

Use Python 3.8 and a venv named `envv`:
```bash
python3.8 -m venv envv
# Windows CMD:        envv\Scripts\activate.bat
# Windows PowerShell: envv\Scripts\Activate.ps1
# Linux/macOS:        source envv/bin/activate
pip install --upgrade pip setuptools wheel
pip install -r requirements-docker-cpu.txt   # CPU / data prep
# or
pip install -r requirements-docker-gpu.txt   # NVIDIA GPU training
```

Build and run the CPU image (data prep only):

```bash
docker build -t prednet:cpu .
docker run --rm -it -v "$(pwd):/workspace" -w /workspace prednet:cpu
python mnist_npy_data/process_mnist.py
```

GPU training requires the NVIDIA Container Toolkit:
```bash
docker build -f Dockerfile.gpu -t prednet:gpu .

# Baseline training:
docker run --rm -it --gpus all -v "$(pwd):/workspace" -w /workspace/original_mnist prednet:gpu python mnist_train_all.py

# Enhanced training:
docker run --rm -it --gpus all -v "$(pwd):/workspace" -w /workspace/transformer_mnist_enhanced prednet:gpu python mnist_train_all_sparse.py

# Or via compose:
docker compose -f docker-compose.gpu.yml run --rm prednet
```

`.devcontainer/devcontainer.json` builds Dockerfile.gpu and passes `--gpus all`. Open the repo folder in VS Code and choose "Reopen in Container".
For a persistent dev session (attach workflow):
```powershell
# PowerShell (WSL2 engine)
$env:PREDNET_WORKSPACE = "/mnt/s/Projects/#ProjectNTU"   # adjust to your path
docker compose -f docker-compose.dev.yml up -d
# VS Code → Ctrl+Shift+P → "Dev Containers: Attach to Running Container…" → prednet-dev
# Open folder: /workspace
```

SSH keys are mounted read-only from `${HOME}/.ssh` so `git@A:…` remotes work inside the container.
```bash
cd original_mnist
python mnist_train_all.py
```

| Parameter | Value |
|---|---|
| `num_epochs` | 1 |
| `batch_size` | 8 |
| `lr` | 0.001 (halved at epoch `num_epochs // 2`) |
| `nt` | 20 frames |
| `n_train_seq` | 7000 |
| `n_val_seq` | 1000 |
| `loss_mode` | L_all |
| `A_channels` / `R_channels` | (3, 48, 96, 192) |
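The halving rule in the table can be expressed as a tiny helper. This is an illustration of the stated schedule, not the repo's training code, and `num_epochs=10` here is just for demonstration:

```python
def lr_at(epoch, base_lr=1e-3, num_epochs=10):
    """Learning rate per the table: base_lr, halved once at epoch num_epochs // 2."""
    return base_lr * 0.5 if epoch >= num_epochs // 2 else base_lr
```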
Outputs:
- original_mnist/models/prednet-L_all-mul-peepFalse-tbiasFalse-best.pt
- original_mnist/history/prednet-L_all-mul-peepFalse-tbiasFalse-param_history.jsonl
- original_mnist/history/prediction_plots/ (frame prediction PNGs every 100 steps)
- data_compare/loss_history/original_mnist-prednet-L_all-mul-peepFalse-tbiasFalse-loss_history.jsonl
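The `.jsonl` logs hold one JSON record per line, so they can be inspected without the plotting scripts. A sketch of a loader — the record field names (e.g. `"loss"`) are assumptions; check the actual logs:

```python
import json

def load_jsonl(path):
    """Read a .jsonl file: one JSON object per non-empty line."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

# Hypothetical usage -- key names depend on what the training scripts log:
# records = load_jsonl("data_compare/loss_history/"
#                      "original_mnist-prednet-L_all-mul-peepFalse-tbiasFalse-loss_history.jsonl")
# losses = [r["loss"] for r in records]
```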
TBF
```bash
python data_compare/loss_history/plot_loss_comparison.py
python data_compare/test_plots/plot_test_sequence_mse.py
```

- Lotter, W., Kreiman, G., & Cox, D. (2017). Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning. ICLR 2017.
- This repository (excluding incorporated code) is licensed under the Apache License 2.0. See LICENSE for details.
- Incorporated code from coxlab/prednet, which is licensed under the MIT License. See LICENSE-COXLAB for details.