Commit 258b584b authored by Attila Lengyel

Updated to DALI 1.22, added WANDB logging.

parent 3e97ab6d

#!/bin/sh
#SBATCH --partition=general
#SBATCH --qos=medium
#SBATCH --time=36:00:00
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=16G
#SBATCH --mail-type=END
#SBATCH --gres=gpu:a40

module use /opt/insy/modulefiles
module load cuda/10.0 cudnn/10.0-7.4.2.24

export OUT_DIR=./output/imagenet
export WANDB_DIR=/tmp/wandb

srun python main.py --batch-size 256

from subprocess import call
import os.path
import os

try:
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
    from nvidia.dali.pipeline import pipeline_def
    import nvidia.dali.fn as fn
    import nvidia.dali.types as types
    import nvidia.dali.tfrecord as tfrec
except ImportError:
    raise ImportError(
        "Please install DALI from https://www.github.com/NVIDIA/DALI to run this example."
    )

@pipeline_def
def tfrecord_pipeline(
    tfrecord_path,
    tfrecord_idx_path,
    num_gpus,
    augment,
    shard_id,
    dali_cpu=False,
    crop=224,
    size=256,
):
    # Specify devices to use.
    dali_device = "cpu" if dali_cpu else "gpu"
    decoder_device = "cpu" if dali_cpu else "mixed"
    # This padding sets the size of the internal nvJPEG buffers to be able to
    # handle all images from full-sized ImageNet without additional reallocations.
    device_memory_padding = 211025920 if decoder_device == "mixed" else 0
    host_memory_padding = 140544512 if decoder_device == "mixed" else 0

    inputs = fn.readers.tfrecord(  # type: ignore
        path=tfrecord_path,
        index_path=tfrecord_idx_path,
        features={
            "image/encoded": tfrec.FixedLenFeature((), tfrec.string, ""),  # type: ignore
            "image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1),  # type: ignore
        },
        num_shards=num_gpus,
        shard_id=shard_id,
        name="Reader",
    )

    # Decoder and data augmentation.
    if augment:
        # Training preprocessing: random-resized crop, then fixed-size resize.
        images = fn.decoders.image_random_crop(  # type: ignore
            inputs["image/encoded"],
            device=decoder_device,
            output_type=types.RGB,  # type: ignore
            device_memory_padding=device_memory_padding,
            host_memory_padding=host_memory_padding,
            random_aspect_ratio=[0.8, 1.25],
            random_area=[0.1, 1.0],
            num_attempts=100,
        )
        images = fn.resize(  # type: ignore
            images,
            device=dali_device,
            resize_x=crop,
            resize_y=crop,
            interp_type=types.INTERP_TRIANGULAR,  # type: ignore
        )
        # Random horizontal flip for training.
        rng = fn.random.coin_flip()  # type: ignore
    else:
        # Validation preprocessing: plain decode, resize shorter side, center crop.
        images = fn.decoders.image(  # type: ignore
            inputs["image/encoded"],
            device=decoder_device,
            output_type=types.RGB,  # type: ignore
            device_memory_padding=device_memory_padding,
            host_memory_padding=host_memory_padding,
        )
        images = fn.resize(  # type: ignore
            images,
            device=dali_device,
            resize_shorter=size,
            interp_type=types.INTERP_TRIANGULAR,  # type: ignore
        )
        rng = False

    # Crop, optionally mirror, and normalize with the ImageNet mean and std
    # (scaled to the 0-255 range of the decoded images).
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    images = fn.crop_mirror_normalize(  # type: ignore
        images.gpu(),
        dtype=types.FLOAT,  # type: ignore
        output_layout=types.NCHW,  # type: ignore
        crop=(crop, crop),
        mean=mean,
        std=std,
        mirror=rng,
    )
    # TFRecord labels are 1-based; shift to 0-based for PyTorch.
    labels = inputs["image/class/label"] - 1
    return images, labels
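
# Note: @pipeline_def turns tfrecord_pipeline into a factory: the standard
# Pipeline arguments (batch_size, num_threads, device_id) are passed at call
# time alongside the function's own arguments. A minimal sketch, with
# hypothetical file paths:
#
#   pipe = tfrecord_pipeline(
#       ["train-00000"], ["train-00000_idx"],
#       num_gpus=1, augment=False, shard_id=0,
#       batch_size=32, num_threads=2, device_id=0,
#   )
#   pipe.build()
#   images, labels = pipe.run()
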
def ImageNet_TFRecord(
    root: str,
    split: str,
    batch_size: int,
    num_threads: int,
    device_id: int,
    num_gpus: int,
    dali_cpu: bool = False,
    augment: bool = False,
):
    """
    PyTorch dataloader for ImageNet TFRecord files.

    Args:
        ...
        will be divided over all subprocesses.
        num_gpus (int): Total number of GPUs available.
        dali_cpu (bool): Set True to perform part of data loading on CPU instead of GPU (default=False).
        augment (bool): Set True to use training preprocessing (random-resized crop,
            random horizontal flip; default=False).

    Returns:
        PyTorch dataloader.
    """
    # List all tfrecord files in directory.
    tf_files = sorted(os.listdir(os.path.join(root, split, "data")))

    # Create dir for idx files if not exists.
    idx_files_dir = os.path.join(root, split, "idx_files")
    if not os.path.exists(idx_files_dir):
        os.mkdir(idx_files_dir)

    tfrec_path_list = []
    idx_path_list = []
    n_samples = 0
    # Create idx files for all TFRecord files.
    for tf_file in tf_files:
        # Paths of tf_file and corresponding idx file.
        tfrec_path = os.path.join(root, split, "data", tf_file)
        tfrec_path_list.append(tfrec_path)
        idx_path = os.path.join(idx_files_dir, tf_file + "_idx")
        idx_path_list.append(idx_path)
        # Create idx file for tf_file by calling the tfrecord2idx script.
        if not os.path.isfile(idx_path):
            call(["tfrecord2idx", tfrec_path, idx_path])
        with open(idx_path, "r") as f:
            n_samples += len(f.readlines())

    # Create a single pipeline that reads from all TFRecord files.
    pipe = tfrecord_pipeline(
        tfrec_path_list,
        idx_path_list,
        augment=augment,
        device_id=device_id,  # type: ignore
        shard_id=device_id,
        num_gpus=num_gpus,
        batch_size=batch_size,  # type: ignore
        num_threads=num_threads,  # type: ignore
        dali_cpu=dali_cpu,
    )
    pipe.build()

    # With reader_name set, the iterator infers the shard size from the reader,
    # so the manual size computation from n_samples is no longer needed.
    dataloader = DALIClassificationIterator(
        pipelines=pipe,
        reader_name="Reader",
        last_batch_policy=LastBatchPolicy.PARTIAL,
    )
    return dataloader

if __name__ == "__main__":
    # Create dataloader.
    dataloader = ImageNet_TFRecord(
        root="/tudelft.net/staff-bulk/ewi/insy/CV-DataSets/imagenet/tfrecords",
        split="train",
        batch_size=64,
        num_threads=2,
        device_id=0,
        num_gpus=1,
        dali_cpu=False,
        augment=True,
    )
    # Get first batch and print shapes.
    print("Number of batches:", len(dataloader))
    data = next(iter(dataloader))
    print(data[0]["data"].shape, data[0]["label"].shape)  # type: ignore
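
    # Minimal consumption sketch: the iterator yields a list with one dict per
    # pipeline, and labels arrive as (N, 1) int tensors. A real training loop
    # (assumed here, not part of this file) would wrap this in epochs and call
    # reset() between them, or pass auto_reset=True to the iterator.
    for batch in dataloader:
        images = batch[0]["data"]                      # float NCHW batch on GPU
        labels = batch[0]["label"].squeeze(-1).long()  # (N, 1) -> (N,)
        break
    dataloader.reset()
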
Fast imagenet training on the TU Delft HPC with PyTorch using TFRecords and DALI.

Tested with PyTorch 1.12.1, CUDA 11.6 and NVIDIA DALI 1.22.0.
Install instructions for DALI: https://www.github.com/NVIDIA/DALI.

##### Description of files
...
* `main.py` Ready-to-run ImageNet training script for ResNet18. Will finish training in ~24 hours.
* `imagenet.sbatch` Sbatch script with recommended settings.

##### Usage
* Set the `OUT_DIR` and `WANDB_DIR` environment variables (for example in `imagenet.sbatch`).
* Run `sbatch imagenet.sbatch` to start training.
* Training ResNet18 requires 4x 1080Ti GPUs (batch size 64 each) or 1x A40 GPU (batch size 256).
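
A minimal sketch of how the logging setup might route files through `WANDB_DIR`; this is an assumption about `main.py` (whose diff is collapsed above), and the project name is hypothetical:

    import os
    import wandb

    # wandb also reads WANDB_DIR from the environment on its own; the explicit
    # dir argument just makes the assumption visible.
    run = wandb.init(project="imagenet", dir=os.environ.get("WANDB_DIR"))
    run.log({"loss": 0.0})  # placeholder metric
    run.finish()
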
##### Performance
Performance of ResNet18 is on par with the pre-trained torchvision model.