Speed Up PyTorch Data Pre-Processing with DALI

DALI is a data pre-processing acceleration framework from NVIDIA; it can offload the data pre-processing work from the CPU onto the GPU.

It accepts input in the following formats: .flac, .ogg and .wav for audio, and .jpg, .jpeg, .png, .bmp, .tif, .tiff, .pnm, .ppm, .pgm and .pbm for images.

However, the nvJPEG decoding module that DALI relies on cannot handle GIF images.
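Since the decoder rejects GIFs, it can save trouble to filter them out before they ever reach the pipeline. A minimal sketch, assuming a list of candidate paths called image_files (the helper is_gif is my own, not part of DALI):

def is_gif(path):
    # GIF files start with the ASCII signature "GIF87a" or "GIF89a"
    with open(path, 'rb') as f:
        return f.read(6) in (b'GIF87a', b'GIF89a')

image_files = [p for p in image_files if not is_gif(p)]  # keep only files nvJPEG can decode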

1 Installation

pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/cuda/10.0 nvidia-dali
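After installing, a quick import confirms that the package is usable and shows which version was pulled in (a plain sanity check, not specific to any DALI feature):

import nvidia.dali
print(nvidia.dali.__version__)  # prints the installed DALI version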

2 Using DALI for acceleration in PyTorch

2.1 Training stage

Preliminary tests show that doing the data pre-processing with DALI cuts the training time to roughly 80% of the original.
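A simple way to check this on your own data is to time a full pass over the loader; the rough sketch below uses the train_loader that is built in the listing that follows:

import time

start = time.time()
for _, data in enumerate(train_loader):
    inputs = data[0]["data"]  # force the pipeline to actually produce each batch
print('one pass over the training data took %.1fs' % (time.time() - start))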

import os

from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIClassificationIterator

class HybridTrainPipe(Pipeline):
    def __init__(self, batch_size, data_dir, num_threads=4, device_id=0, crop=224, dali_cpu=False):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        self.input = ops.FileReader(file_root=data_dir, random_shuffle=True)

        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
        # without additional reallocations
        device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
        host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
        self.decode = ops.ImageDecoderRandomCrop(device=decoder_device, output_type=types.RGB,
                                                 device_memory_padding=device_memory_padding,
                                                 host_memory_padding=host_memory_padding,
                                                 random_aspect_ratio=[0.8, 1.25],
                                                 random_area=[0.6, 1.0],
                                                 num_attempts=100)
        self.res = ops.Resize(device=dali_device, resize_x=crop, resize_y=crop, interp_type=types.INTERP_TRIANGULAR)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                            std=[0.229 * 255,0.224 * 255,0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images.gpu(), mirror=rng)
        return [output, self.labels]

class HybridValPipe(Pipeline):
    def __init__(self, batch_size, data_dir, num_threads=4, device_id=0, crop=224, size=224):
        super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        self.input = ops.FileReader(file_root=data_dir, random_shuffle=False)
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(device="gpu", resize_shorter=size, interp_type=types.INTERP_TRIANGULAR)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                            std=[0.229 * 255,0.224 * 255,0.225 * 255])

    def define_graph(self):
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels]

      
def train(model, device, train_loader, NUM_TRAIN, epoch):
    # some code
    for batch_idx, data in enumerate(train_loader):
        inputs = data[0]["data"]                           # already a CUDA tensor, no .cuda() call needed
        labels = data[0]["label"].squeeze().cuda().long()  # labels are returned on the CPU
        outputs = model(inputs)
        
if __name__ == '__main__':
    # BATCH_SIZE and data_dir are assumed to be defined elsewhere in the script
    pipe = HybridTrainPipe(batch_size=BATCH_SIZE, data_dir=os.path.join(data_dir, 'train'))
    pipe.build()
    NUM_TRAIN = pipe.epoch_size("Reader")  # number of samples found by the reader
    train_loader = DALIClassificationIterator(pipe, size=NUM_TRAIN)

    pipe = HybridValPipe(batch_size=BATCH_SIZE, data_dir=os.path.join(data_dir, 'val'))
    pipe.build()
    NUM_VAL = pipe.epoch_size("Reader")
    val_loader = DALIClassificationIterator(pipe, size=NUM_VAL)
    print(NUM_TRAIN, NUM_VAL)
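One detail worth noting: DALI iterators do not rewind themselves when an epoch ends, so reset() has to be called explicitly before the next epoch starts. A rough sketch of the surrounding epoch loop, assuming model, device and EPOCHS are defined elsewhere:

for epoch in range(EPOCHS):
    train(model, device, train_loader, NUM_TRAIN, epoch)
    train_loader.reset()  # rewind the DALI iterator for the next epoch
    val_loader.reset()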

2.2 Inference stage

How much speed-up DALI brings at inference time depends on how much it relieves the high CPU utilization that pre-processing would otherwise cause.

from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import numpy as np

class MyPipe(Pipeline):
    def __init__(self, fn, batch_size=1, num_threads=1, device_id=0, crop=224, size=224):
        super(MyPipe, self).__init__(batch_size, num_threads, device_id)
        self.fn = fn                       # path of the image file to decode
        self.input = ops.ExternalSource()
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        # nearest-neighbour interpolation is the cheapest; TRIANGULAR/LINEAR trade speed for quality
        #self.res = ops.Resize(device="gpu", resize_shorter=size, interp_type=types.INTERP_TRIANGULAR)
        #self.res = ops.Resize(device="gpu", resize_shorter=size, interp_type=types.INTERP_LINEAR)
        self.res = ops.Resize(device="gpu", resize_shorter=size, interp_type=types.INTERP_NN)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                            std=[0.229 * 255,0.224 * 255,0.225 * 255])

    def define_graph(self):
        self.jpegs = self.input()
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return output

    def iter_setup(self):
        # read the still-encoded bytes and feed them to the ExternalSource node
        with open(self.fn, 'rb') as fp:
            images = [np.frombuffer(fp.read(), dtype=np.uint8)]
        self.feed_input(self.jpegs, images)
        
if __name__ == '__main__':
    model_path = constants.model_path  # model loading is omitted in this snippet
    # fn (the path of the image to predict on) is assumed to be defined elsewhere
    pipe = MyPipe(fn)
    pipe.build()
    try:
        dali_iterator = DALIGenericIterator(pipe, ['images'], 1)
        img = dali_iterator.next()[0]['images']  # normalized NCHW float tensor, already on the GPU
        #print(type(img))
    except Exception:
        pass  # skip files that nvJPEG cannot decode (e.g. GIF)
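The tensor coming out of the iterator is already a normalized NCHW float batch resident on the GPU, so it can be passed straight to the model. A minimal sketch of the predict step, assuming the model has already been loaded and moved to the GPU:

import torch

model.eval()
with torch.no_grad():
    outputs = model(img)          # img is the CUDA tensor produced by the DALI pipeline
    pred = outputs.argmax(dim=1)  # predicted class index for the single image
print(pred.item())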
        
