@@ -103,6 +103,7 @@ def __init__(self, loader):
103103 self .sampler = loader .sampler
104104 self .num_workers = loader .num_workers
105105 self .pin_memory = loader .pin_memory
106+ self .drop_last = loader .drop_last
106107 self .done_event = threading .Event ()
107108
108109 self .samples_remaining = len (self .sampler )
@@ -141,11 +142,15 @@ def __init__(self, loader):
141142 self ._put_indices ()
142143
def __len__(self):
    """Number of batches this iterator will yield.

    When ``drop_last`` is set, the trailing partial batch is discarded,
    so the count is the floor of ``len(sampler) / batch_size``; otherwise
    the partial batch is counted too (ceiling division).
    """
    full_batches, leftover = divmod(len(self.sampler), self.batch_size)
    if self.drop_last or leftover == 0:
        return full_batches
    return full_batches + 1
145149
146150 def __next__ (self ):
147- if self .num_workers == 0 :
148- # same-process loading
151+ if self .num_workers == 0 : # same-process loading
152+ if self .drop_last and self .samples_remaining < self .batch_size :
153+ raise StopIteration
149154 if self .samples_remaining == 0 :
150155 raise StopIteration
151156 indices = self ._next_indices ()
@@ -187,9 +192,12 @@ def _next_indices(self):
187192 def _put_indices (self ):
# Push the next batch of sample indices onto the worker index queue.
# Invariant: at most two batches may be outstanding per worker process.
188193 assert self .batches_outstanding < 2 * self .num_workers
189194 if self .samples_remaining > 0 :
190- self .index_queue .put ((self .send_idx , self ._next_indices ()))
191- self .batches_outstanding += 1
192- self .send_idx += 1
# With drop_last, a final batch smaller than batch_size is still consumed
# from the sampler (to advance samples_remaining) but deliberately never
# queued, so no worker produces the short batch and send_idx stays in
# step with the batches that will actually be received.
195+ if self .samples_remaining < self .batch_size and self .drop_last :
196+ self ._next_indices ()
197+ else :
198+ self .index_queue .put ((self .send_idx , self ._next_indices ()))
199+ self .batches_outstanding += 1
200+ self .send_idx += 1
193201
194202 def _process_next_batch (self , batch ):
195203 self .rcvd_idx += 1
@@ -236,15 +244,20 @@ class DataLoader(object):
236244 (default: 0)
237245 collate_fn (callable, optional)
238246 pin_memory (bool, optional)
247+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
248+ if the dataset size is not divisible by the batch size. If ``False`` and
249+ the size of the dataset is not divisible by the batch size, then the last
250+ batch will be smaller. (default: False)
239251 """
240252
241- def __init__ (self , dataset , batch_size = 1 , shuffle = False , sampler = None ,
242- num_workers = 0 , collate_fn = default_collate , pin_memory = False ):
253+ def __init__ (self , dataset , batch_size = 1 , shuffle = False , sampler = None , num_workers = 0 ,
254+ collate_fn = default_collate , pin_memory = False , drop_last = False ):
243255 self .dataset = dataset
244256 self .batch_size = batch_size
245257 self .num_workers = num_workers
246258 self .collate_fn = collate_fn
247259 self .pin_memory = pin_memory
260+ self .drop_last = drop_last
248261
249262 if sampler is not None :
250263 self .sampler = sampler
@@ -257,4 +270,7 @@ def __iter__(self):
257270 return DataLoaderIter (self )
258271
def __len__(self):
    """Return the number of batches per epoch.

    Drops the remainder batch when ``drop_last`` is set; otherwise the
    final short batch is included, i.e. the division rounds up.
    """
    num_samples = len(self.sampler)
    batch = self.batch_size
    if self.drop_last:
        return num_samples // batch
    # Ceiling division without floating point: -(-a // b) == ceil(a / b).
    return -(-num_samples // batch)
0 commit comments