Skip to content

Commit a20eb0b

Browse files
committed
added balanced classes generation in the preprocessing module for images.
1 parent 4fa7e5d commit a20eb0b

File tree

1 file changed

+93
-6
lines changed

1 file changed

+93
-6
lines changed

keras/preprocessing/image.py

Lines changed: 93 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,47 @@ def list_pictures(directory, ext='jpg|jpeg|bmp|png'):
309309
if re.match('([\w]+\.(?:' + ext + '))', f)]
310310

311311

312+
def make_balanced_batches(samples_size, batch_size, labels):
313+
#print("Making balanced batches...")
314+
def chunkify(lst, n):
315+
if type(lst) == xrange:
316+
lst = list(lst)
317+
return [lst[i::n] for i in xrange(n)]
318+
assert(labels.shape[0] == samples_size)
319+
nb_batch = int(np.ceil(samples_size/float(batch_size)))
320+
l_classes = np.unique(labels)
321+
n_classes = len(l_classes)
322+
chunks_per_cl = chunkify(range(batch_size), n_classes)
323+
324+
l_all_ind = []
325+
for i_batch in range(nb_batch):
326+
perm_cl = np.random.permutation(l_classes)
327+
l_ind = []
328+
#print("Progress: %f" % (i_batch/float(nb_batch)))
329+
for i_cl, cl in enumerate(perm_cl):
330+
n_chunks = len(chunks_per_cl[i_cl])
331+
# indexes in the label array corresponding to class cl
332+
ind_of_class = np.where(labels == cl)[0]
333+
# random permute indexes
334+
perm_cl_ind = np.random.permutation(ind_of_class)
335+
# and take only n_chunks
336+
perm_cl_ind = list(perm_cl_ind[:n_chunks])
337+
# if n_chunks not available, keep adding 1 random number at a time until I reac
338+
# n_chunks
339+
while len(perm_cl_ind) < n_chunks:
340+
ind = np.random.permutation(ind_of_class)[0]
341+
perm_cl_ind.append(ind)
342+
l_ind.append(perm_cl_ind)
343+
l_ind = tuple(sum(l_ind, []))
344+
l_all_ind.append(l_ind)
345+
assert(len(l_ind) == batch_size)
346+
l_all_ind = tuple(l_all_ind)
347+
assert(len(l_all_ind) == nb_batch)
348+
l_all_ind = np.array(sum(l_all_ind, ()))
349+
350+
#print("Done making balanced batches...")
351+
return l_all_ind
352+
312353
class ImageDataGenerator(object):
313354
"""Generate minibatches of image data with real-time data augmentation.
314355
@@ -415,7 +456,8 @@ def __init__(self,
415456
'Received arg: ', zoom_range)
416457

417458
def flow(self, X, y=None, batch_size=32, shuffle=True, seed=None,
418-
save_to_dir=None, save_prefix='', save_format='jpeg'):
459+
save_to_dir=None, save_prefix='', save_format='jpeg',
460+
balanced_classes = False):
419461
return NumpyArrayIterator(
420462
X, y, self,
421463
batch_size=batch_size,
@@ -424,7 +466,8 @@ def flow(self, X, y=None, batch_size=32, shuffle=True, seed=None,
424466
dim_ordering=self.dim_ordering,
425467
save_to_dir=save_to_dir,
426468
save_prefix=save_prefix,
427-
save_format=save_format)
469+
save_format=save_format,
470+
balanced_classes = balanced_classes)
428471

429472
def flow_from_directory(self, directory,
430473
target_size=(256, 256), color_mode='rgb',
@@ -619,20 +662,28 @@ def fit(self, x,
619662

620663
class Iterator(object):
621664

622-
def __init__(self, n, batch_size, shuffle, seed):
665+
def __init__(self, n, batch_size, shuffle, seed, targets = None):
623666
self.n = n
624667
self.batch_size = batch_size
625668
self.shuffle = shuffle
626669
self.batch_index = 0
627670
self.total_batches_seen = 0
628671
self.lock = threading.Lock()
629-
self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
672+
if targets is not None:
673+
self.index_generator = self._flow_index_balanced(n,
674+
batch_size,
675+
shuffle,
676+
seed,
677+
targets = targets)
678+
else:
679+
self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
630680

631681
def reset(self):
632682
self.batch_index = 0
633683

634684
def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
635685
# ensure self.batch_index is 0
686+
636687
self.reset()
637688
while 1:
638689
if seed is not None:
@@ -653,6 +704,29 @@ def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
653704
yield (index_array[current_index: current_index + current_batch_size],
654705
current_index, current_batch_size)
655706

707+
def _flow_index_balanced(self, n, batch_size=32, shuffle=False, seed=None, targets = None):
708+
# ensure self.batch_index is 0
709+
self.reset()
710+
while 1:
711+
if seed is not None:
712+
np.random.seed(seed + self.total_batches_seen)
713+
if self.batch_index == 0:
714+
index_array = np.arange(n)
715+
if shuffle:
716+
#index_array = np.random.permutation(n)
717+
index_array = make_balanced_batches(n, batch_size, targets)
718+
719+
current_index = (self.batch_index * batch_size) % n
720+
if n >= current_index + batch_size:
721+
current_batch_size = batch_size
722+
self.batch_index += 1
723+
else:
724+
current_batch_size = n - current_index
725+
self.batch_index = 0
726+
self.total_batches_seen += 1
727+
yield (index_array[current_index: current_index + current_batch_size],
728+
current_index, current_batch_size)
729+
656730
def __iter__(self):
657731
# needed if we want to do something like:
658732
# for x, y in data_gen.flow(...):
@@ -667,7 +741,8 @@ class NumpyArrayIterator(Iterator):
667741
def __init__(self, x, y, image_data_generator,
668742
batch_size=32, shuffle=False, seed=None,
669743
dim_ordering='default',
670-
save_to_dir=None, save_prefix='', save_format='jpeg'):
744+
save_to_dir=None, save_prefix='', save_format='jpeg',
745+
balanced_classes = False):
671746
if y is not None and len(x) != len(y):
672747
raise ValueError('X (images tensor) and y (labels) '
673748
'should have the same length. '
@@ -697,7 +772,18 @@ def __init__(self, x, y, image_data_generator,
697772
self.save_to_dir = save_to_dir
698773
self.save_prefix = save_prefix
699774
self.save_format = save_format
700-
super(NumpyArrayIterator, self).__init__(x.shape[0], batch_size, shuffle, seed)
775+
776+
if balanced_classes:
777+
super(NumpyArrayIterator, self).__init__(x.shape[0],
778+
batch_size,
779+
shuffle,
780+
seed,
781+
targets = y.argmax(axis = 1))
782+
else:
783+
super(NumpyArrayIterator, self).__init__(x.shape[0],
784+
batch_size,
785+
shuffle,
786+
seed)
701787

702788
def next(self):
703789
# for python 2.x.
@@ -725,6 +811,7 @@ def next(self):
725811
if self.y is None:
726812
return batch_x
727813
batch_y = self.y[index_array]
814+
#print(batch_y.argmax(1))
728815
return batch_x, batch_y
729816

730817

0 commit comments

Comments
 (0)