88 import nvidia .dali .ops as ops
99 import nvidia .dali .types as types
1010except ImportError :
11- raise ImportError ("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example." )
11+ raise ImportError (
12+ "Please install DALI from https://www.github.com/NVIDIA/DALI to run this example." )
1213
1314
1415def get_data (dataset , data_path , cutout_length , auto_augmentation ):
@@ -29,11 +30,14 @@ def get_data(dataset, data_path, cutout_length, auto_augmentation):
2930 n_classes = 1000
3031 else :
3132 raise ValueError (dataset )
32- trn_transform , val_transform = preproc .data_transforms (dataset , cutout_length , auto_augmentation )
33+ trn_transform , val_transform = preproc .data_transforms (
34+ dataset , cutout_length , auto_augmentation )
3335 if 'imagenet' in dataset :
34- trn_data = dset_cls (root = os .path .join (data_path , 'train' ), transform = trn_transform )
36+ trn_data = dset_cls (root = os .path .join (
37+ data_path , 'train' ), transform = trn_transform )
3538 else :
36- trn_data = dset_cls (root = data_path , train = True , download = True , transform = trn_transform )
39+ trn_data = dset_cls (root = data_path , train = True ,
40+ download = True , transform = trn_transform )
3741
3842 # assuming shape is NHW or NHWC
3943 if 'imagenet' in dataset :
@@ -56,9 +60,11 @@ def get_data(dataset, data_path, cutout_length, auto_augmentation):
5660 input_size = shape [1 ]
5761 ret = [input_size , input_channels , n_classes , trn_data ]
5862 if 'imagenet' in dataset :
59- ret .append (dset_cls (root = os .path .join (data_path , 'val' ), transform = val_transform ))
63+ ret .append (dset_cls (root = os .path .join (
64+ data_path , 'val' ), transform = val_transform ))
6065 else :
61- ret .append (dset_cls (root = data_path , train = False , download = True , transform = val_transform ))
66+ ret .append (dset_cls (root = data_path , train = False ,
67+ download = True , transform = val_transform ))
6268 return ret
6369
6470
@@ -71,7 +77,7 @@ def get_data_dali(dataset, data_path, batch_size=256, num_threads=4):
7177 train_loader = cifar10 .get_cifar_iter_dali (type = 'train' , image_dir = data_path ,
7278 batch_size = batch_size , num_threads = num_threads )
7379 val_loader = cifar10 .get_cifar_iter_dali (type = 'val' , image_dir = data_path ,
74- batch_size = batch_size , num_threads = num_threads )
80+ batch_size = batch_size , num_threads = num_threads )
7581 elif dataset == 'imagenet' :
7682 input_size = 224
7783 input_channels = 3
@@ -80,8 +86,8 @@ def get_data_dali(dataset, data_path, batch_size=256, num_threads=4):
8086 batch_size = batch_size , num_threads = num_threads ,
8187 crop = 224 , val_size = 256 )
8288 val_loader = imagenet .get_imagenet_iter_dali (type = 'val' , image_dir = data_path ,
83- batch_size = batch_size , num_threads = num_threads ,
84- crop = 224 , val_size = 256 )
89+ batch_size = batch_size , num_threads = num_threads ,
90+ crop = 224 , val_size = 256 )
8591 elif dataset == 'imagenet112' :
8692 input_size = 112
8793 input_channels = 3
@@ -119,13 +125,16 @@ def get_data_dali(dataset, data_path, batch_size=256, num_threads=4):
119125
120126class HybridTrainPipe (Pipeline ):
121127 def __init__ (self , batch_size , num_threads , device_id , data_dir , crop , dali_cpu = False ):
122- super (HybridTrainPipe , self ).__init__ (batch_size , num_threads , device_id , seed = 12 + device_id )
123- self .input = ops .FileReader (file_root = data_dir , shard_id = 0 , num_shards = 1 , random_shuffle = True )
124- #let user decide which pipeline works him bets for RN version he runs
128+ super (HybridTrainPipe , self ).__init__ (batch_size ,
129+ num_threads , device_id , seed = 12 + device_id )
130+ self .input = ops .FileReader (
131+ file_root = data_dir , shard_id = 0 , num_shards = 1 , random_shuffle = True )
132+ # let user decide which pipeline works him bets for RN version he runs
125133 if dali_cpu :
126134 dali_device = "cpu"
127135 self .decode = ops .HostDecoderRandomCrop (device = dali_device , output_type = types .RGB ,
128- random_aspect_ratio = [0.8 , 1.25 ],
136+ random_aspect_ratio = [
137+ 0.8 , 1.25 ],
129138 random_area = [0.1 , 1.0 ],
130139 num_attempts = 100 )
131140 else :
@@ -135,17 +144,20 @@ def __init__(self, batch_size, num_threads, device_id, data_dir, crop, dali_cpu=
135144 self .decode = ops .nvJPEGDecoderRandomCrop (device = "mixed" , output_type = types .RGB ,
136145 device_memory_padding = 211025920 ,
137146 host_memory_padding = 140544512 ,
138- random_aspect_ratio = [0.8 , 1.25 ],
147+ random_aspect_ratio = [
148+ 0.8 , 1.25 ],
139149 random_area = [0.1 , 1.0 ],
140150 num_attempts = 100 )
141- self .res = ops .Resize (device = dali_device , resize_x = crop , resize_y = crop , interp_type = types .INTERP_TRIANGULAR )
151+ self .res = ops .Resize (device = dali_device , resize_x = crop ,
152+ resize_y = crop , interp_type = types .INTERP_TRIANGULAR )
142153 self .cmnp = ops .CropMirrorNormalize (device = "gpu" ,
143154 output_dtype = types .FLOAT ,
144155 output_layout = types .NCHW ,
145156 crop = (crop , crop ),
146157 image_type = types .RGB ,
147- mean = [0.485 * 255 ,0.456 * 255 ,0.406 * 255 ],
148- std = [0.229 * 255 ,0.224 * 255 ,0.225 * 255 ])
158+ mean = [0.485 * 255 , 0.456 *
159+ 255 , 0.406 * 255 ],
160+ std = [0.229 * 255 , 0.224 * 255 , 0.225 * 255 ])
149161 self .coin = ops .CoinFlip (probability = 0.5 )
150162 # self.color_jitter = [ops.Brightness(device="gpu", brightness=0.4),
151163 # ops.Contrast(device="gpu", contrast=0.4),
@@ -166,16 +178,20 @@ def define_graph(self):
166178
167179class HybridValPipe (Pipeline ):
168180 def __init__ (self , batch_size , num_threads , device_id , data_dir , crop , size ):
169- super (HybridValPipe , self ).__init__ (batch_size , num_threads , device_id , seed = 12 + device_id )
170- self .input = ops .FileReader (file_root = data_dir , shard_id = 0 , num_shards = 1 , random_shuffle = False )
181+ super (HybridValPipe , self ).__init__ (batch_size ,
182+ num_threads , device_id , seed = 12 + device_id )
183+ self .input = ops .FileReader (
184+ file_root = data_dir , shard_id = 0 , num_shards = 1 , random_shuffle = False )
171185 self .decode = ops .nvJPEGDecoder (device = "mixed" , output_type = types .RGB )
172- self .res = ops .Resize (device = "gpu" , resize_shorter = size , interp_type = types .INTERP_TRIANGULAR )
186+ self .res = ops .Resize (device = "gpu" , resize_shorter = size ,
187+ interp_type = types .INTERP_TRIANGULAR )
173188 self .cmnp = ops .CropMirrorNormalize (device = "gpu" ,
174189 output_dtype = types .FLOAT ,
175190 output_layout = types .NCHW ,
176191 crop = (crop , crop ),
177192 image_type = types .RGB ,
178- mean = [0.485 * 255 , 0.456 * 255 , 0.406 * 255 ],
193+ mean = [0.485 * 255 , 0.456 *
194+ 255 , 0.406 * 255 ],
179195 std = [0.229 * 255 , 0.224 * 255 , 0.225 * 255 ])
180196
181197 def define_graph (self ):
@@ -189,13 +205,12 @@ def define_graph(self):
189205def get_dali_imagenet_pipeline (batch_size , num_threads , data_path , train_cpu = False ,
190206 crop = 224 , size = 256 ):
191207 train_pipe = HybridTrainPipe (batch_size = batch_size , num_threads = num_threads , device_id = 0 ,
192- data_dir = os .path .join (data_path , 'train' ),
193- crop = crop , dali_cpu = train_cpu )
208+ data_dir = os .path .join (data_path , 'train' ),
209+ crop = crop , dali_cpu = train_cpu )
194210 train_pipe .build ()
195211
196212 val_pipe = HybridValPipe (batch_size = batch_size , num_threads = num_threads , device_id = 0 ,
197- data_dir = os .path .join (data_path , 'val' ),
198- crop = crop , size = size )
213+ data_dir = os .path .join (data_path , 'val' ),
214+ crop = crop , size = size )
199215 val_pipe .build ()
200216 return [train_pipe , val_pipe ]
201-
0 commit comments