A comprehensive distributed neural network training system that scales across multiple GPUs and machines.
- Multiple Parallelism Strategies: Data Parallel, Model Parallel, and Pipeline Parallel training
- Ring AllReduce: Efficient gradient synchronization using the Ring AllReduce algorithm
- Fault Tolerance: Automatic checkpointing and recovery from failures
- Real-time Monitoring: Live dashboard with GPU utilization, training metrics, and cluster health
- Elastic Scaling: Add or remove nodes dynamically
- Mock Implementation: Full mock database and simulated training for development
- Python 3.10+
- Node.js 18+
- npm or yarn
```bash
# Install Python dependencies
pip install -r requirements.txt

# Install Node.js dependencies
cd dashboard
npm install
```

```bash
# Set mock database environment variable
export MOCK_DB=true

# Start the API server
python -m uvicorn api.main:app --reload --host 0.0.0.0 --port 8000

# In another terminal, start the dashboard
cd dashboard
npm run dev
```

Open your browser to http://localhost:3000.
| Method | Endpoint | Description |
|---|---|---|
| POST | /api/v1/training/jobs | Create new training job |
| GET | /api/v1/training/jobs | List all training jobs |
| GET | /api/v1/training/jobs/:id | Get job details |
| PUT | /api/v1/training/jobs/:id/pause | Pause running job |
| PUT | /api/v1/training/jobs/:id/resume | Resume paused job |
| DELETE | /api/v1/training/jobs/:id | Cancel training job |
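As a sketch of calling the job-creation endpoint above: the method and path come from the table, but the payload fields and response shape are illustrative assumptions, not a documented schema.

```python
import json
from urllib import request

def build_job_request(base_url: str, payload: dict) -> request.Request:
    """Build a POST request for /api/v1/training/jobs.

    The endpoint is taken from the API table; the payload schema is an
    illustrative assumption.
    """
    return request.Request(
        f"{base_url}/api/v1/training/jobs",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

# Hypothetical payload -- field names are assumptions, not a documented schema.
job = {
    "name": "resnet-run-1",
    "strategy": "data_parallel",  # or "model_parallel" / "pipeline_parallel"
    "num_gpus": 8,
}
req = build_job_request("http://localhost:8000", job)
# response = json.loads(request.urlopen(req).read())  # requires a running server
```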
| Method | Endpoint | Description |
|---|---|---|
| GET | /api/v1/cluster/nodes | List cluster nodes |
| POST | /api/v1/cluster/nodes | Add worker node |
| DELETE | /api/v1/cluster/nodes/:id | Remove worker node |
| GET | /api/v1/cluster/gpu-status | Get all GPU statuses |
| Method | Endpoint | Description |
|---|---|---|
| GET | /api/v1/metrics/jobs/:id | Get job metrics |
| GET | /api/v1/metrics/cluster/summary | Cluster-wide metrics |
| WS | /api/v1/metrics/stream/:jobId | Real-time metrics stream |
| Method | Endpoint | Description |
|---|---|---|
| GET | /api/v1/checkpoints/jobs/:id | List job checkpoints |
| DELETE | /api/v1/checkpoints/:id | Delete checkpoint |
| POST | /api/v1/checkpoints/:id/restore | Restore from checkpoint |
| Method | Endpoint | Description |
|---|---|---|
| POST | /api/v1/experiments | Create benchmark experiment |
| GET | /api/v1/experiments | List experiments |
| POST | /api/v1/experiments/:id/analyze | Run analysis on results |
```
┌─────────────────────────────────────────────────────────────────┐
│                        User Interface                           │
│                 (CLI / Web Dashboard / API)                     │
└─────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌─────────────────────────────────────────────────────────────────┐
│                      Training Coordinator                       │
│               (Central Training Manager Service)                │
└─────────────────────────────────────────────────────────────────┘
                                │
                ┌───────────────┼───────────────┐
                │               │               │
                ▼               ▼               ▼
┌──────────────────┐  ┌──────────────────┐  ┌──────────────────┐
│  Worker Node 1   │  │  Worker Node 2   │  │  Worker Node N   │
│    (GPU x4)      │  │    (GPU x4)      │  │    (GPU xM)      │
└──────────────────┘  └──────────────────┘  └──────────────────┘
```
Each GPU gets a complete copy of the model and processes a different subset of the batch. Gradients are averaged via AllReduce.
```
GPU0 → model copy → batch subset 0 → gradients 0
GPU1 → model copy → batch subset 1 → gradients 1
GPU2 → model copy → batch subset 2 → gradients 2
GPU3 → model copy → batch subset 3 → gradients 3
                        ↓
                AllReduce(gradients)
```
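The data-parallel flow above can be simulated in pure Python. This is a minimal sketch of the Ring AllReduce pass, where each "GPU" is just a list of floats; a real implementation (e.g. NCCL) runs the same chunk schedule over device interconnects.

```python
def ring_allreduce(grads):
    """Simulate Ring AllReduce: N-1 reduce-scatter steps followed by
    N-1 all-gather steps. Every rank ends with the element-wise sum;
    data-parallel training then divides by world size to average."""
    n, dim = len(grads), len(grads[0])
    assert dim % n == 0, "vector length must divide evenly across ranks"
    chunk = dim // n
    bufs = [list(g) for g in grads]  # each rank's working buffer

    # Reduce-scatter: each step, rank i sends one chunk to rank i+1,
    # which accumulates it. Snapshot sends first so steps happen "simultaneously".
    for step in range(n - 1):
        sends = [(i, (i - step) % n) for i in range(n)]
        data = {i: bufs[i][c * chunk:(c + 1) * chunk] for i, c in sends}
        for i, c in sends:
            dst = (i + 1) % n
            for k in range(chunk):
                bufs[dst][c * chunk + k] += data[i][k]

    # All-gather: circulate the fully reduced chunks so every rank holds all of them.
    for step in range(n - 1):
        sends = [(i, (i + 1 - step) % n) for i in range(n)]
        data = {i: bufs[i][c * chunk:(c + 1) * chunk] for i, c in sends}
        for i, c in sends:
            bufs[(i + 1) % n][c * chunk:(c + 1) * chunk] = data[i]
    return bufs
```

Each rank sends and receives only `2 * (n-1) * dim / n` elements, which is why the ring schedule is bandwidth-optimal compared to a naive all-to-all exchange.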
The model is split across multiple GPUs, with each GPU holding a portion of the layers.
```
GPU0 → layers 0-3
GPU1 → layers 4-7
GPU2 → layers 8-11
GPU3 → layers 12-15
```
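The split above is a contiguous partition of the layer list. A minimal sketch, where `run_layer` stands in for whatever forward function the framework uses (its name and signature are assumptions for illustration):

```python
def partition_layers(layers, num_stages):
    """Split layers into contiguous, near-equal stages (one per GPU).
    Earlier stages take the remainder when the count doesn't divide evenly."""
    base, extra = divmod(len(layers), num_stages)
    stages, start = [], 0
    for s in range(num_stages):
        size = base + (1 if s < extra else 0)
        stages.append(layers[start:start + size])
        start += size
    return stages

def model_parallel_forward(stages, x, run_layer):
    """Run x through each stage in turn; in a real system the activation
    tensor is transferred between GPUs at each stage boundary."""
    for stage in stages:
        for layer in stage:
            x = run_layer(layer, x)
    return x
```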
Different GPUs process different pipeline stages with the 1F1B (One-Forward-One-Backward) schedule.
```
GPU0 → Stage 0 (layers 0-3)
GPU1 → Stage 1 (layers 4-7)
GPU2 → Stage 2 (layers 8-11)
GPU3 → Stage 3 (layers 12-15)
```
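The per-stage op order under 1F1B can be sketched as a schedule generator (op sequence only, no actual compute; `("F", i)` means the forward pass of micro-batch i):

```python
def one_f_one_b_schedule(stage, num_stages, num_microbatches):
    """Return the op sequence for one pipeline stage under 1F1B:
    a warmup of forwards, a steady state alternating one forward with
    one backward, then a cooldown draining the remaining backwards."""
    warmup = min(num_stages - stage - 1, num_microbatches)
    ops = [("F", i) for i in range(warmup)]
    for i in range(num_microbatches - warmup):
        ops.append(("F", warmup + i))  # one forward...
        ops.append(("B", i))           # ...then one backward
    ops += [("B", i) for i in range(num_microbatches - warmup, num_microbatches)]
    return ops
```

The last stage has no warmup and strictly alternates F and B; earlier stages hold at most `num_stages - stage` micro-batches of activations in flight, which is what keeps 1F1B's memory footprint below a naive fill-then-drain schedule.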
```bash
pytest tests/api/ -v --cov=api --cov-report=html
pytest tests/distributed/ -v
cd dashboard
npm run test
```

```
distributed-gpu-training/
├── api/                  # FastAPI application
│   ├── routes/           # API route handlers
│   ├── middleware/       # Authentication middleware
│   └── websocket/        # WebSocket handlers
├── cluster/              # Cluster management
├── distributed/          # Distributed training core
├── parallelism/          # Parallelism implementations
├── training/             # Training loop and checkpointing
├── monitoring/           # Metrics collection
├── db/                   # Database layer
├── dashboard/            # React frontend
├── tests/                # Test suites
└── scripts/              # Utility scripts
```
MIT License