Commit 9f2a5d8

zuoxingdong authored and soumith committed
Add a flag to fix when dataset size is not divisible by batch size. (pytorch#1133)
1 parent aa506fa commit 9f2a5d8

File tree (2 files changed: +64 −9 lines)

    test/test_utils.py
    torch/utils/data/dataloader.py

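In short, DataLoader gains a drop_last flag that controls what happens when the dataset size is not a multiple of the batch size. A minimal usage sketch, mirroring the 5-sample / batch-size-3 setup used in the tests below:

    import torch
    import torch.utils.data

    dataset = torch.randn(5, 3, 3, 2)   # 5 samples; 5 is not divisible by 3
    keep = torch.utils.data.DataLoader(dataset, batch_size=3, drop_last=False)
    drop = torch.utils.data.DataLoader(dataset, batch_size=3, drop_last=True)

    print(len(keep))   # 2 batches: sizes 3 and 2 (the last batch is smaller)
    print(len(drop))   # 1 batch: the incomplete final batch is dropped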

test/test_utils.py

Lines changed: 39 additions & 0 deletions
@@ -8,6 +8,7 @@
 import unittest
 import traceback
 import torch
+import torch.utils.data
 import torch.cuda
 import warnings
 from torch.autograd import Variable
@@ -107,6 +108,44 @@ def __len__(self):
         return 10


+class TestDataLoader(TestCase):
+    def setUp(self):
+        self.dataset = torch.randn(5, 3, 3, 2)
+        self.batch_size = 3
+
+    def test_single_keep(self):
+        dataloader = torch.utils.data.DataLoader(self.dataset,
+                                                 batch_size=self.batch_size,
+                                                 num_workers=0,
+                                                 drop_last=False)
+        dataiter = iter(dataloader)
+        self.assertEqual(len(list(dataiter)), 2)
+
+    def test_single_drop(self):
+        dataloader = torch.utils.data.DataLoader(self.dataset,
+                                                 batch_size=self.batch_size,
+                                                 num_workers=0,
+                                                 drop_last=True)
+        dataiter = iter(dataloader)
+        self.assertEqual(len(list(dataiter)), 1)
+
+    def test_multi_keep(self):
+        dataloader = torch.utils.data.DataLoader(self.dataset,
+                                                 batch_size=self.batch_size,
+                                                 num_workers=2,
+                                                 drop_last=False)
+        dataiter = iter(dataloader)
+        self.assertEqual(len(list(dataiter)), 2)
+
+    def test_multi_drop(self):
+        dataloader = torch.utils.data.DataLoader(self.dataset,
+                                                 batch_size=self.batch_size,
+                                                 num_workers=2,
+                                                 drop_last=True)
+        dataiter = iter(dataloader)
+        self.assertEqual(len(list(dataiter)), 1)
+
+
 class TestTrainer(TestCase):

     intervals = [
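The expected values follow from the arithmetic the new __len__ uses: with 5 samples and a batch size of 3, keeping the last batch gives ceil(5 / 3) = 2 batches, while dropping it gives floor(5 / 3) = 1. An illustrative recomputation (not part of the commit):

    batch_size = 3
    num_samples = 5
    # integer ceiling, as in the drop_last=False branch of __len__
    print((num_samples + batch_size - 1) // batch_size)   # 2
    # integer floor, as in the drop_last=True branch
    print(num_samples // batch_size)                      # 1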

torch/utils/data/dataloader.py

Lines changed: 25 additions & 9 deletions
@@ -103,6 +103,7 @@ def __init__(self, loader):
         self.sampler = loader.sampler
         self.num_workers = loader.num_workers
         self.pin_memory = loader.pin_memory
+        self.drop_last = loader.drop_last
         self.done_event = threading.Event()

         self.samples_remaining = len(self.sampler)
@@ -141,11 +142,15 @@ def __init__(self, loader):
             self._put_indices()

     def __len__(self):
-        return int(math.ceil(len(self.sampler) / float(self.batch_size)))
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        else:
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

     def __next__(self):
-        if self.num_workers == 0:
-            # same-process loading
+        if self.num_workers == 0:  # same-process loading
+            if self.drop_last and self.samples_remaining < self.batch_size:
+                raise StopIteration
             if self.samples_remaining == 0:
                 raise StopIteration
             indices = self._next_indices()
@@ -187,9 +192,12 @@ def _next_indices(self):
     def _put_indices(self):
         assert self.batches_outstanding < 2 * self.num_workers
         if self.samples_remaining > 0:
-            self.index_queue.put((self.send_idx, self._next_indices()))
-            self.batches_outstanding += 1
-            self.send_idx += 1
+            if self.samples_remaining < self.batch_size and self.drop_last:
+                self._next_indices()
+            else:
+                self.index_queue.put((self.send_idx, self._next_indices()))
+                self.batches_outstanding += 1
+                self.send_idx += 1

     def _process_next_batch(self, batch):
         self.rcvd_idx += 1
@@ -236,15 +244,20 @@ class DataLoader(object):
             (default: 0)
         collate_fn (callable, optional)
         pin_memory (bool, optional)
+        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+            if the dataset size is not divisible by the batch size. If False and
+            the size of dataset is not divisible by the batch size, then the last batch
+            will be smaller. (default: False)
     """

-    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
-                 num_workers=0, collate_fn=default_collate, pin_memory=False):
+    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0,
+                 collate_fn=default_collate, pin_memory=False, drop_last=False):
         self.dataset = dataset
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.collate_fn = collate_fn
         self.pin_memory = pin_memory
+        self.drop_last = drop_last

         if sampler is not None:
             self.sampler = sampler
@@ -257,4 +270,7 @@ def __iter__(self):
         return DataLoaderIter(self)

     def __len__(self):
-        return int(math.ceil(len(self.sampler) / float(self.batch_size)))
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        else:
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
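Both __len__ bodies swap the old float-based ceiling for pure integer arithmetic: (n + batch_size - 1) // batch_size is the standard integer-ceiling idiom and agrees with int(math.ceil(n / float(batch_size))) for positive operands. A quick illustrative check (not part of the commit):

    import math

    for n in range(1, 50):
        for b in range(1, 10):
            assert (n + b - 1) // b == int(math.ceil(n / float(b)))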
