Distributed LLM pretraining with PyTorch FSDP.
Research-grade distributed training framework: FSDP ZeRO-3, activation checkpointing, Flash Attention 2, MFU tracking, and FSDP-aware checkpointing — scales from a single GPU up to multi-node clusters via SLURM.
```
train.py              Main training loop
config.py             Typed dataclass config + YAML loading
model.py              FSDP wrapping + activation checkpointing
dataset.py            Streaming tokenization + sequence packing
benchmark.py          Throughput/MFU/VRAM benchmark suite
utils/
  distributed.py      Distributed init, NCCL/GLOO auto-select
  checkpointing.py    Full + sharded FSDP checkpointing
  logging.py          MFU tracking, W&B integration
configs/
  qwen_1b.yaml        Qwen2.5-1.5B, 1–4 GPUs (16 GB VRAM)
  qwen_7b.yaml        Qwen2.5-7B, 4–8 GPUs (24 GB VRAM)
scripts/
  launch.sh           torchrun single-node launcher
  slurm.sh            SLURM multi-node launcher (2+ nodes)
```
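`utils/distributed.py` is described only as "distributed init, NCCL/GLOO auto-select". The following is a minimal sketch of what that selection usually involves, assuming the standard environment variables that `torchrun` exports (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`). `pick_backend`, `read_dist_env` and `DistEnv` are hypothetical names, not the repository's API. Falling back to rank 0 and world size 1 is what allows a plain `python train.py` run without a launcher.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class DistEnv:
    rank: int        # global rank across all nodes
    local_rank: int  # GPU index on this node
    world_size: int  # total number of processes


def pick_backend(cuda_available: bool) -> str:
    # Hypothetical helper: NCCL for GPU collectives, Gloo as the CPU fallback.
    return "nccl" if cuda_available else "gloo"


def read_dist_env(env: Mapping[str, str] = os.environ) -> DistEnv:
    # torchrun sets RANK / LOCAL_RANK / WORLD_SIZE; a plain `python train.py`
    # run sets none of them, so default to a single-process "world".
    return DistEnv(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )


if __name__ == "__main__":
    # In real code the backend would be passed to
    # torch.distributed.init_process_group(backend=...).
    print(pick_backend(False), read_dist_env({}))
```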
```bash
# Install
pip install -r requirements.txt

# Flash Attention 2 (Linux, Ampere or newer GPUs; significant speedup):
pip install flash-attn --no-build-isolation

# Single GPU (dev / local RTX):
python train.py --config configs/qwen_1b.yaml

# 4x GPU (single node):
bash scripts/launch.sh configs/qwen_1b.yaml 4

# 2-node, 4 GPUs/node (via SLURM):
sbatch scripts/slurm.sh
```

| Strategy | Shards | Memory/GPU | Comm. volume | Use case |
|---|---|---|---|---|
| `NO_SHARD` | none (DDP) | full params + grads + optim state | baseline (all-reduce grads) | DDP reference |
| `SHARD_GRAD_OP` | grads + optim (ZeRO-2) | grad + optim state ÷ N GPUs | ≈ DDP (all-gather params in fwd, reduce-scatter grads) | 2–4 GPU |
| `FULL_SHARD` | params + grads + optim (ZeRO-3) | lowest (everything ÷ N GPUs) | ~1.5× DDP (all-gather params in fwd and bwd, reduce-scatter grads) | ≥4 GPU, large models |
| `HYBRID_SHARD` | full shard within each node, replicated across nodes | as `FULL_SHARD` with N = GPUs per node | all-gathers stay intra-node; only the grad all-reduce crosses nodes | multi-node |
Memory and comm. volume are the standard ZeRO-stage trade-offs (more sharding → less memory, more communication). Throughput is workload/interconnect-dependent — see measured numbers below rather than a generic percentage.
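The Memory/GPU column can be made concrete with the per-parameter accounting from the ZeRO paper for mixed-precision Adam: 2 bytes of bf16/fp16 weights, 2 bytes of gradients and 12 bytes of fp32 optimizer state (master weights, momentum, variance), 16 bytes in total. The sketch below is an estimate under that assumption only: it ignores activations, temporary all-gather buffers and allocator fragmentation, and PyTorch FSDP's exact bookkeeping differs. `estimate_state_bytes` is a hypothetical helper, not part of this repository.

```python
# Bytes per parameter for mixed-precision Adam (ZeRO paper accounting).
PARAM_BYTES = 2   # bf16/fp16 working weights
GRAD_BYTES = 2    # bf16/fp16 gradients
OPTIM_BYTES = 12  # fp32 master weights + Adam momentum + variance


def estimate_state_bytes(n_params: float, strategy: str, world_size: int,
                         gpus_per_node: int = 8) -> float:
    """Rough per-GPU bytes for params + grads + optimizer state (no activations)."""
    total = PARAM_BYTES + GRAD_BYTES + OPTIM_BYTES
    if strategy == "NO_SHARD":
        per_param = total                                    # fully replicated (DDP)
    elif strategy == "SHARD_GRAD_OP":
        per_param = PARAM_BYTES + (GRAD_BYTES + OPTIM_BYTES) / world_size
    elif strategy == "FULL_SHARD":
        per_param = total / world_size                       # everything sharded
    elif strategy == "HYBRID_SHARD":
        shard_group = min(gpus_per_node, world_size)
        per_param = total / shard_group                      # sharded within a node only
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return n_params * per_param


if __name__ == "__main__":
    for s in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD"):
        gib = estimate_state_bytes(1.54e9, s, world_size=8, gpus_per_node=4) / 2**30
        print(f"{s:14s} {gib:6.2f} GiB/GPU")
```

The `HYBRID_SHARD` branch shows why its per-GPU memory sits between `NO_SHARD` and `FULL_SHARD`: state is divided only by the GPUs in one node, in exchange for keeping parameter all-gathers off the inter-node link.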
Documents are concatenated (with EOS tokens as separators) and split into fixed max_seq_len chunks. No padding waste — every token in every batch is a real training signal.
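A minimal sketch of that packing step, operating on already-tokenized documents (lists of token ids). It runs as a generator, so only the leftover tokens between chunks are buffered. The EOS id and the choice to drop the final partial chunk are assumptions, not necessarily what `dataset.py` does; `pack_sequences` is a hypothetical name.

```python
from typing import Iterable, Iterator, List


def pack_sequences(docs: Iterable[List[int]], eos_id: int,
                   max_seq_len: int) -> Iterator[List[int]]:
    """Concatenate documents with EOS separators and yield fixed-length chunks."""
    buffer: List[int] = []
    for doc in docs:
        buffer.extend(doc)
        buffer.append(eos_id)            # EOS marks the document boundary
        while len(buffer) >= max_seq_len:
            yield buffer[:max_seq_len]   # full chunk: no padding needed
            buffer = buffer[max_seq_len:]
    # A remainder shorter than max_seq_len is dropped here (assumption).


if __name__ == "__main__":
    docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
    for chunk in pack_sequences(docs, eos_id=0, max_seq_len=4):
        print(chunk)
```

Note that without a document-aware attention mask, tokens after an EOS can still attend to the previous document in the same chunk; many pretraining setups accept this in exchange for simplicity.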
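MFU = actual FLOP/s ÷ theoretical peak FLOP/s.
actual FLOP/s = 6 × N_params × tokens_per_second
Factor 6 = 2 FLOPs per parameter per token in the forward pass (one multiply and one add) × 3, because the backward pass costs about twice the forward. The estimate ignores attention-score FLOPs and does not credit activation-checkpointing recomputation (counting that would give hardware FLOPs utilization, HFU), so MFU falls in step with tokens/s when checkpointing is on. Reference points at 7B scale with Flash Attention 2: ~35–45% MFU on A100 SXM, ~50–60% on H100.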
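The formula above as a small function. `peak_flops_per_gpu` should be the dense tensor-core peak for the dtype actually used. The function name, its parameters, and the convention that `tokens_per_second` is the global rate across all GPUs are illustrative assumptions, not the API of `utils/logging.py`.

```python
def model_flops_utilization(n_params: float, tokens_per_second: float,
                            peak_flops_per_gpu: float, num_gpus: int) -> float:
    """MFU = (6 * N * tokens/s) / (aggregate peak FLOP/s).

    tokens_per_second is the global rate across all GPUs, so the peak is
    summed over the same GPUs. Attention-score FLOPs and activation
    recomputation are deliberately not counted.
    """
    achieved = 6.0 * n_params * tokens_per_second  # forward (2N) + backward (4N)
    return achieved / (peak_flops_per_gpu * num_gpus)


if __name__ == "__main__":
    # Hypothetical numbers: 1B params, 10k tokens/s, 4 GPUs at 100 TFLOP/s each.
    print(f"{model_flops_utilization(1e9, 1e4, 100e12, 4):.2%}")
```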
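Recomputes activations during the backward pass instead of storing them, applied per FSDP unit (each transformer block). The price is roughly one extra forward pass per step: about 33% more compute, i.e. ~25% lower throughput when compute-bound. The peak-VRAM saving depends on sequence length and batch size; it can reach ~60% on activation-heavy configurations, and the benchmark below measured 34%.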
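A minimal sketch of per-block checkpointing using the generic `torch.utils.checkpoint.checkpoint` API, with a toy MLP block standing in for a transformer layer. `Block` and `Stack` are hypothetical; inside FSDP the same effect is usually obtained by wrapping each transformer block, which this sketch does not show. It demonstrates only the mechanics: each checkpointed block keeps its input, discards its internal activations, and reruns its forward during backward, yielding the same gradients as the uncheckpointed model.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy stand-in for a transformer block (hypothetical)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


class Stack(nn.Module):
    def __init__(self, dim: int, depth: int, use_checkpoint: bool) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Only the block input is saved; the 4*dim hidden activations
                # are recomputed during backward (non-reentrant implementation).
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


if __name__ == "__main__":
    torch.manual_seed(0)
    plain = Stack(dim=64, depth=4, use_checkpoint=False)
    ckpt = Stack(dim=64, depth=4, use_checkpoint=True)
    ckpt.load_state_dict(plain.state_dict())  # identical weights
    x = torch.randn(8, 64)
    plain(x).sum().backward()
    ckpt(x).sum().backward()
    same = all(torch.allclose(a.grad, b.grad, atol=1e-6)
               for a, b in zip(plain.parameters(), ckpt.parameters()))
    print("gradients match:", same)
```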
Measured on 2× NVIDIA T4 (Kaggle), Qwen2.5-1.5B (1.54B params), FSDP FULL_SHARD, Flash Attention 2, batch size 1 per GPU, sequence length 1024, bf16.
Raw data: benchmarks/results.json.
| Config | Grad ckpt | Tokens/s | Peak VRAM/GPU | Step time | MFU |
|---|---|---|---|---|---|
| FULL_SHARD | off | 406.0 | 12.26 GB | 5.04 s | 1.88% |
| FULL_SHARD | on | 303.7 | 8.14 GB | 6.74 s | 1.41% |
Reading these honestly:
- The activation-checkpointing trade-off is real and measured here: enabling it cut peak VRAM from 12.26 to 8.14 GB (−34%) at a 25% throughput cost (406 → 304 tok/s). That headroom is what lets a larger model or a longer sequence fit.
- MFU is low (~1.9%), and that is expected on this hardware rather than a bug. The T4 is a 2018 Turing part with no bf16 tensor-core acceleration, 16 GB of memory and PCIe only (no NVLink), and batch size 1 leaves the GPU under-fed. MFU here is limited by the small batch and by the parameter all-gathers over PCIe, not by the kernels. The framework, sharding, packing and MFU accounting are validated end-to-end; absolute MFU should rise with a faster interconnect and larger batches (see the A100/H100 reference figures above).
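The headline numbers can be cross-checked from the table alone. Each step processes 2 GPUs × batch 1 × 1024 tokens = 2048 tokens, so tokens/s should be close to 2048 ÷ step time. Under the 6N accounting, recomputing one forward pass (2N) on top of forward + backward (6N) predicts a throughput ratio of 6/8 if step time scaled purely with compute, i.e. a 25% drop. The values below are copied from the table; nothing new is measured.

```python
TOKENS_PER_STEP = 2 * 1 * 1024  # GPUs × per-GPU batch × sequence length

# Values copied from the benchmark table above.
runs = {
    "ckpt_off": {"tok_s": 406.0, "vram_gb": 12.26, "step_s": 5.04},
    "ckpt_on": {"tok_s": 303.7, "vram_gb": 8.14, "step_s": 6.74},
}

for name, r in runs.items():
    # Tokens/s should agree with tokens-per-step / step time (up to rounding).
    print(name, round(TOKENS_PER_STEP / r["step_s"], 1), r["tok_s"])

vram_saving = 1 - runs["ckpt_on"]["vram_gb"] / runs["ckpt_off"]["vram_gb"]
throughput_cost = 1 - runs["ckpt_on"]["tok_s"] / runs["ckpt_off"]["tok_s"]
predicted_cost = 1 - 6 / 8  # 6N per token without recompute, 8N with it
print(f"VRAM saving {vram_saving:.0%}, throughput cost {throughput_cost:.0%}, "
      f"predicted {predicted_cost:.0%}")
```

The close match between measured and predicted cost should not be over-read: as noted above, this run is far from compute-bound, so part of the agreement may be coincidental.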
Reproduce: `torchrun --nproc_per_node=2 benchmark.py --model Qwen/Qwen2.5-1.5B`
- Benchmark on A100/H100 with NVLink + larger batch (validate MFU scaling)
- LoRA / QLoRA fine-tuning mode
- DPO training loop
- Mixture of Experts (MoE) FSDP wrapping
| Config | Min GPUs | VRAM/GPU | Notes |
|---|---|---|---|
| qwen_1b (dev) | 1 | 16 GB | Single GPU local |
| qwen_1b (full) | 4 | 16 GB | Single node |
| qwen_7b | 4 | 24 GB | A10G or better |
| qwen_7b (2-node) | 8 | 24 GB | SLURM |
```bibtex
@software{prometheus2026,
  author = {Antonio},
  title  = {PROMETHEUS: Distributed LLM Training with PyTorch FSDP},
  year   = {2026},
  url    = {https://github.com/QuantumDrizzy/prometheus}
}
```