 from functools import partial

 import numpy as np
+import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg
-from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler

 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -84,7 +84,7 @@ def build_dataloader(dataset,
                      seed=None,
                      drop_last=False,
                      pin_memory=True,
-                     dataloader_type='PoolDataLoader',
+                     persistent_workers=True,
                      **kwargs):
     """Build PyTorch DataLoader.

@@ -106,7 +106,11 @@ def build_dataloader(dataset,
             Default: False
         pin_memory (bool): Whether to use pin_memory in DataLoader.
             Default: True
-        dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+        persistent_workers (bool): If True, the data loader will not shut
+            down the worker processes after a dataset has been consumed
+            once, which keeps the worker Dataset instances alive. This
+            argument only takes effect with PyTorch>=1.7.0.
+            Default: True
         kwargs: any keyword argument to be used to initialize DataLoader

     Returns:
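
[Editor's aside on the docstring hunk above: `persistent_workers` maps directly onto the PyTorch `DataLoader` flag of the same name. A minimal standalone sketch of the effect in plain PyTorch (>=1.7.0); the toy dataset and batch settings are illustrative, not part of this commit:]

# Sketch: with persistent_workers=True the worker processes are spawned
# once and reused across epochs, so expensive Dataset.__init__ work is
# not repeated every time a new iterator is created.
import torch
from torch.utils.data import DataLoader, Dataset


class SlowInitDataset(Dataset):

    def __init__(self):
        self.payload = torch.arange(8)  # stand-in for expensive setup

    def __len__(self):
        return len(self.payload)

    def __getitem__(self, idx):
        return self.payload[idx]


if __name__ == '__main__':
    loader = DataLoader(
        SlowInitDataset(), batch_size=4, num_workers=2,
        persistent_workers=True)  # requires num_workers > 0
    for epoch in range(2):  # workers survive into the second epoch
        for batch in loader:
            pass
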
@@ -128,26 +132,31 @@ def build_dataloader(dataset,
         worker_init_fn, num_workers=num_workers, rank=rank,
         seed=seed) if seed is not None else None

-    assert dataloader_type in (
-        'DataLoader',
-        'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
-
-    if dataloader_type == 'PoolDataLoader':
-        dataloader = PoolDataLoader
-    elif dataloader_type == 'DataLoader':
-        dataloader = DataLoader
-
-    data_loader = dataloader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-        pin_memory=pin_memory,
-        shuffle=shuffle,
-        worker_init_fn=init_fn,
-        drop_last=drop_last,
-        **kwargs)
+    if torch.__version__ >= '1.7.0':
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=num_workers,
+            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+            pin_memory=pin_memory,
+            shuffle=shuffle,
+            worker_init_fn=init_fn,
+            drop_last=drop_last,
+            persistent_workers=persistent_workers,
+            **kwargs)
+    else:
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=num_workers,
+            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+            pin_memory=pin_memory,
+            shuffle=shuffle,
+            worker_init_fn=init_fn,
+            drop_last=drop_last,
+            **kwargs)

     return data_loader

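[Editor's aside on the version gate in the last hunk: `torch.__version__ >= '1.7.0'` compares strings lexicographically, so a later release such as '1.10.0' sorts before '1.7.0'. A sketch of a numeric check, assuming the third-party `packaging` module is available (mmcv's `digit_version` helper serves the same purpose); none of this is in the commit itself:]

# Sketch: why a string compare on torch.__version__ can misfire, and a
# numeric alternative.
import torch
from packaging import version

assert not ('1.10.0' >= '1.7.0')   # lexicographic: '1' < '7' at index 2
assert version.parse('1.10.0') >= version.parse('1.7.0')  # numeric: True

# Strip local build tags such as '+cu111' before parsing, to be safe.
torch_ge_170 = (version.parse(torch.__version__.split('+')[0])
                >= version.parse('1.7.0'))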
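[Finally, a hypothetical call site for the updated signature, assuming `build_dataloader` is imported from the module shown in this diff; the toy dataset and the samples_per_gpu/workers_per_gpu values are made up for illustration:]

# Hypothetical usage of the updated build_dataloader signature.
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return dict(img=torch.zeros(3, 32, 32))


loader = build_dataloader(
    ToyDataset(),
    samples_per_gpu=4,
    workers_per_gpu=2,
    dist=False,
    persistent_workers=True)  # silently unused on PyTorch<1.7.0 (else branch)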