AIPEAC/PredNet-with-Transformer

PredNet PyTorch Implementation

PyTorch re-implementation of PredNet (Lotter et al., 2017) trained on Moving MNIST. Two variants are provided: a baseline ConvLSTM-only model and an enhanced version augmented with a sparse Transformer layer.
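As in the original paper, each PredNet layer compares a target activation A against a prediction Ahat and propagates both rectified differences as the error signal. A minimal pure-Python sketch of that error unit (the names `a` / `a_hat` are illustrative; the repo operates on 4-D tensors, not flat lists):

```python
def prednet_error(a, a_hat):
    """PredNet error unit: concatenation of the positive and negative
    rectified differences between target a and prediction a_hat.
    Pure-Python sketch over flat lists for illustration only."""
    pos = [max(x - y, 0.0) for x, y in zip(a, a_hat)]  # ReLU(A - Ahat)
    neg = [max(y - x, 0.0) for x, y in zip(a, a_hat)]  # ReLU(Ahat - A)
    return pos + neg

# A perfect prediction yields an all-zero error signal.
print(prednet_error([1.0, 0.5], [1.0, 0.5]))  # [0.0, 0.0, 0.0, 0.0]
```

Training minimizes a weighted sum of these error activations across layers and timesteps (the `loss_mode L_all` setting below weights all layers, not just the pixel layer).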


Explanations


Repository structure

PredNet-with-Transformer/
├── original_prednet_pytorch/      # Baseline: PredNet (ConvLSTM only)
│   ├── prednet_x.py               # PredNet model
│   ├── conv_lstm_cell_x.py        # ConvLSTM cell with peephole / gating support
│   ├── mnist_data.py              # Dataset loader
│   ├── mnist_settings.py          # DATA_DIR path helper
│   ├── mnist_train.py             # Single-epoch quick training entry point
│   ├── mnist_train_all.py         # Full training loop (all epochs)
│   ├── history/                   # Saved param snapshots + prediction plots
│   └── models/                    # Saved .pt checkpoints
├── transformer_/                  # Enhanced: PredNet + sparse Transformer
│   └── TBF
├── mnist_npy_data/                # Shared data directory
│   ├── X_train.hkl                # 7000 training sequences
│   ├── X_val.hkl                  # 1000 validation sequences
│   ├── X_test.hkl                 # 2000 test sequences
│   ├── process_mnist.py           # Convert .npy → .hkl
│   ├── mnist_test.py              # Sanity-check data loading
│   └── mnist_settings.py          # DATA_DIR (shared reference)
├── data_compare/                  # Cross-model comparison outputs
│   ├── loss_history/              # *.jsonl loss logs + plot_loss_comparison.py
│   └── test_plots/                # MSE-per-timestep JSONs + plot_test_sequence_mse.py
├── Dockerfile                     # CPU image (data prep only)
├── Dockerfile.gpu                 # NVIDIA GPU image (training)
├── docker-compose.dev.yml         # Dev container (attach-to-running, SSH + workspace mounts)
├── docker-compose.gpu.yml         # One-shot GPU training run
├── .devcontainer/                 # VS Code dev container config (builds Dockerfile.gpu)
├── requirements-docker-cpu.txt    # CPU dependency pins
├── requirements-docker-gpu.txt    # GPU dependency pins (CUDA 11.1 + PyTorch)
└── prednet-env-windows.yml        # Optional conda YAML snapshot
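`conv_lstm_cell_x.py` supports peephole connections, where the previous cell state feeds directly into the gates. A scalar sketch of that idea (all weight names here are illustrative defaults, not the module's parameters; the actual cell applies the same formula with convolutions over feature maps):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def peephole_input_gate(x, h, c, w_x=0.0, w_h=0.0, w_c=0.0, b=0.0):
    """Scalar sketch of an LSTM input gate with a peephole term:
    the previous cell state c contributes w_c * c to the gate pre-activation,
    on top of the usual input (x) and hidden-state (h) terms."""
    return sigmoid(w_x * x + w_h * h + w_c * c + b)
```

With `w_c = 0` this reduces to a standard (non-peephole) gate, which is what the `peepFalse` tag in the checkpoint filenames below refers to.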

Data setup

Download Moving MNIST (if not already present) and run the preprocessor once:

# Place the raw file:  mnist_npy_data/mnist_test_seq.npy
python mnist_npy_data/process_mnist.py
# Produces: X_train.hkl  X_val.hkl  X_test.hkl  inside mnist_npy_data/
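For orientation: the raw Moving MNIST file holds 10,000 sequences of 20 frames (64×64), and the three `.hkl` counts above sum to exactly that. A hedged sketch of a contiguous split consistent with those counts (the actual slicing in `process_mnist.py` may differ):

```python
# Hypothetical contiguous split matching the 7000/1000/2000 counts above.
# mnist_test_seq.npy: 10,000 sequences x 20 frames of 64x64 pixels.
n_train, n_val, n_test = 7_000, 1_000, 2_000
n_total = n_train + n_val + n_test           # 10,000 sequences in the raw file

train_idx = range(0, n_train)                # -> X_train.hkl
val_idx = range(n_train, n_train + n_val)    # -> X_val.hkl
test_idx = range(n_train + n_val, n_total)   # -> X_test.hkl
```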

Environment

Use Python 3.8 and a venv named envv.

python3.8 -m venv envv
# Windows CMD:        envv\Scripts\activate.bat
# Windows PowerShell: envv\Scripts\Activate.ps1
# Linux/macOS:        source envv/bin/activate
pip install --upgrade pip setuptools wheel
pip install -r requirements-docker-cpu.txt   # CPU / data prep
# or
pip install -r requirements-docker-gpu.txt   # NVIDIA GPU training

Docker

CPU image — data prep only

docker build -t prednet:cpu .
docker run --rm -it -v "$(pwd):/workspace" -w /workspace prednet:cpu
python mnist_npy_data/process_mnist.py

GPU image — training

Requires NVIDIA Container Toolkit.

docker build -f Dockerfile.gpu -t prednet:gpu .
# Baseline training:
docker run --rm -it --gpus all -v "$(pwd):/workspace" -w /workspace/original_mnist prednet:gpu python mnist_train_all.py
# Enhanced training:
docker run --rm -it --gpus all -v "$(pwd):/workspace" -w /workspace/transformer_mnist_enhanced prednet:gpu python mnist_train_all_sparse.py
# Or via compose:
docker compose -f docker-compose.gpu.yml run --rm prednet

Dev container (VS Code)

.devcontainer/devcontainer.json builds Dockerfile.gpu and passes --gpus all. Open the repo folder in VS Code and choose Reopen in Container.

For a persistent dev session (attach workflow):

# PowerShell (WSL2 engine)
$env:PREDNET_WORKSPACE = "/mnt/s/Projects/#ProjectNTU"   # adjust to your path
docker compose -f docker-compose.dev.yml up -d
# VS Code → Ctrl+Shift+P → "Dev Containers: Attach to Running Container…" → prednet-dev
# Open folder: /workspace

SSH keys are mounted read-only from ${HOME}/.ssh so git@A:… remotes work inside the container.


Training

Baseline — original_mnist/

cd original_mnist
python mnist_train_all.py
| Parameter               | Value                                  |
|-------------------------|----------------------------------------|
| num_epochs              | 1                                      |
| batch_size              | 8                                      |
| lr                      | 0.001 (halved at epoch num_epochs//2)  |
| nt                      | 20 frames                              |
| n_train_seq             | 7000                                   |
| n_val_seq               | 1000                                   |
| loss_mode               | L_all                                  |
| A_channels / R_channels | (3, 48, 96, 192)                       |
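The `lr` entry means the learning rate drops once, halfway through training. A small sketch of that step schedule (function name and signature are illustrative; the training script may use a `torch.optim` scheduler instead):

```python
def lr_at_epoch(epoch, base_lr=0.001, num_epochs=10):
    """Step schedule: the base learning rate is halved from
    epoch num_epochs // 2 onward. Illustrative sketch only."""
    return base_lr * 0.5 if epoch >= num_epochs // 2 else base_lr
```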

Outputs:

  • original_mnist/models/prednet-L_all-mul-peepFalse-tbiasFalse-best.pt
  • original_mnist/history/prednet-L_all-mul-peepFalse-tbiasFalse-param_history.jsonl
  • original_mnist/history/prediction_plots/ — frame prediction PNGs every 100 steps
  • data_compare/loss_history/original_mnist-prednet-L_all-mul-peepFalse-tbiasFalse-loss_history.jsonl

PredNet with Transformer — transformer_mnist/

tbf

tbf

Outputs: tbf

Comparison plots

python data_compare/loss_history/plot_loss_comparison.py
python data_compare/test_plots/plot_test_sequence_mse.py
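Both plot scripts consume `.jsonl` logs: one JSON object per line. A minimal reader sketch (the `step`/`loss` field names in the example are assumptions about the log schema, not confirmed by the repo):

```python
import json

def read_jsonl(lines):
    """Parse .jsonl content: one JSON object per non-empty line.
    `lines` is any iterable of strings (an open file, a list, ...)."""
    return [json.loads(line) for line in lines if line.strip()]

# Example with hypothetical field names:
rows = read_jsonl(['{"step": 100, "loss": 0.5}', '{"step": 200, "loss": 0.3}'])
print(rows[0]["loss"])  # 0.5
```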

References

  • Lotter, W., Kreiman, G., Cox, D. Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning. ICLR 2017.

License

  • This repository (excluding incorporated code) is licensed under the Apache License 2.0. See LICENSE for details.
  • Incorporated code from coxlab/prednet, which is licensed under the MIT License. See LICENSE-COXLAB for details.
