Fast ImageNet training on the TU Delft HPC with PyTorch using TFRecords and DALI.
Tested with PyTorch 1.7.0 and NVIDIA DALI 0.27.0. Install instructions for DALI: https://www.github.com/NVIDIA/DALI.
Description of files
- `imagenet_tfrecord.py`: Python script containing the ImageNet dataloader. Use this in your own project.
- `main.py`: Ready-to-run ImageNet training script for ResNet18. Training finishes in about 24 hours.
- `imagenet.sbatch`: Sbatch script with recommended settings.
Performance
Our ResNet18 performs on par with the pre-trained torchvision ResNet18, with slightly lower top-1 error (29.99% vs. 30.24%) and top-5 error (10.79% vs. 10.92%).
Model | Top-1 error (%) | Top-5 error (%)
---|---|---
ResNet18 - DALI [ours] | 29.99 | 10.79
ResNet18 - Torchvision [link] | 30.24 | 10.92
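For reference, top-k error is the percentage of images whose ground-truth class is not among the model's k highest-scoring predictions (k = 1 and k = 5 in the table). The pure-Python sketch below illustrates that definition on toy scores; the function name `topk_error` and the data are hypothetical, and in a PyTorch training loop this would normally be computed on batched tensors (e.g. with `torch.topk`).

```python
from typing import Sequence


def topk_error(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Percentage of samples whose true label is not among the k highest scores."""
    misses = 0
    for row, label in zip(scores, labels):
        # Class indices sorted by descending score (ties broken by lower index).
        topk = sorted(range(len(row)), key=lambda c: (-row[c], c))[:k]
        if label not in topk:
            misses += 1
    return 100.0 * misses / len(labels)


# Toy example: 4 samples, 3 classes (hypothetical scores, not ImageNet data).
scores = [
    [0.7, 0.2, 0.1],  # true 0: top-1 hit
    [0.1, 0.3, 0.6],  # true 1: top-1 miss, top-2 hit
    [0.5, 0.4, 0.1],  # true 2: top-1 miss, top-2 miss
    [0.2, 0.5, 0.3],  # true 1: top-1 hit
]
labels = [0, 1, 2, 1]
print(topk_error(scores, labels, k=1))  # 50.0
print(topk_error(scores, labels, k=2))  # 25.0
```

Top-k accuracy is simply 100 minus the top-k error, so the table could equivalently be read as accuracies.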
Limitations
- As all JPEG decoding and data augmentation are performed on the GPU, less GPU memory is available for your network. In case of out-of-memory (OOM) errors, you can try to (1) use more GPUs, or (2) enable `dali_cpu` (possibly slower).
- I'm not sure exactly how batches are shuffled; the order might be less "random" than when loading individual JPEG files.
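DALI's image decoder can run with `device="mixed"` (GPU-accelerated JPEG decoding) or `device="cpu"`, and augmentation operators can run on `"gpu"` or `"cpu"`. The sketch below assumes `dali_cpu` is a boolean option that switches between these, as in NVIDIA's DALI ImageNet example; the `--dali_cpu` flag and the `pipeline_devices` helper are hypothetical names, so check `main.py` and `imagenet_tfrecord.py` for the actual option.

```python
import argparse
from typing import Dict


def pipeline_devices(dali_cpu: bool) -> Dict[str, str]:
    """Pick DALI device strings for the decode and augmentation stages.

    "mixed" decodes JPEGs on the GPU; "cpu" keeps decoding and augmentation
    in host memory, freeing GPU memory for the network (possibly slower).
    """
    if dali_cpu:
        return {"decoder": "cpu", "augment": "cpu"}
    return {"decoder": "mixed", "augment": "gpu"}


parser = argparse.ArgumentParser()
# Hypothetical flag name; the real scripts may expose this differently.
parser.add_argument("--dali_cpu", action="store_true",
                    help="run JPEG decoding and augmentation on the CPU")
args = parser.parse_args(["--dali_cpu"])
print(pipeline_devices(args.dali_cpu))  # {'decoder': 'cpu', 'augment': 'cpu'}
```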
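One likely reason for less "random" batches: TFRecord files are read sequentially, so shuffling typically happens through a fixed-size buffer rather than over the whole dataset. DALI's documentation describes the readers' `random_shuffle` option this way (a prefetch buffer of `initial_fill` samples from which batch samples are drawn at random). The pure-Python sketch below assumes such a buffer; it is not DALI code, and `buffered_shuffle` is a hypothetical name. It shows the consequence: a record can move at most `buffer_size - 1` positions earlier than its place in the file.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(stream: Iterable[T], buffer_size: int,
                     rng: random.Random) -> Iterator[T]:
    """Shuffle a sequential stream through a fixed-size buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    buffer: List[T] = []
    for item in stream:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Emit a random buffered element; swap it to the end so pop() is O(1).
            idx = rng.randrange(buffer_size)
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Stream exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


# Records 0..9999 read in file order, shuffled with a 100-element buffer.
order = list(buffered_shuffle(range(10_000), buffer_size=100, rng=random.Random(0)))
# How far each record moved towards the front of the epoch.
max_forward_jump = max(item - pos for pos, item in enumerate(order))
print(max_forward_jump <= 99)  # True: no record moves more than 99 slots earlier

# A full permutation has no such bound (the jump is typically close to the dataset size).
full = list(range(10_000))
random.Random(0).shuffle(full)
full_jump = max(item - pos for pos, item in enumerate(full))
```

Under this assumption, a larger buffer (in DALI, a larger `initial_fill`) widens the mixing window at the cost of more memory.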