@@ -309,6 +309,47 @@ def list_pictures(directory, ext='jpg|jpeg|bmp|png'):
309
309
if re .match ('([\w]+\.(?:' + ext + '))' , f )]
310
310
311
311
312
+ def make_balanced_batches (samples_size , batch_size , labels ):
313
+ #print("Making balanced batches...")
314
+ def chunkify (lst , n ):
315
+ if type (lst ) == xrange :
316
+ lst = list (lst )
317
+ return [lst [i ::n ] for i in xrange (n )]
318
+ assert (labels .shape [0 ] == samples_size )
319
+ nb_batch = int (np .ceil (samples_size / float (batch_size )))
320
+ l_classes = np .unique (labels )
321
+ n_classes = len (l_classes )
322
+ chunks_per_cl = chunkify (range (batch_size ), n_classes )
323
+
324
+ l_all_ind = []
325
+ for i_batch in range (nb_batch ):
326
+ perm_cl = np .random .permutation (l_classes )
327
+ l_ind = []
328
+ #print("Progress: %f" % (i_batch/float(nb_batch)))
329
+ for i_cl , cl in enumerate (perm_cl ):
330
+ n_chunks = len (chunks_per_cl [i_cl ])
331
+ # indexes in the label array corresponding to class cl
332
+ ind_of_class = np .where (labels == cl )[0 ]
333
+ # random permute indexes
334
+ perm_cl_ind = np .random .permutation (ind_of_class )
335
+ # and take only n_chunks
336
+ perm_cl_ind = list (perm_cl_ind [:n_chunks ])
337
+ # if n_chunks not available, keep adding 1 random number at a time until I reac
338
+ # n_chunks
339
+ while len (perm_cl_ind ) < n_chunks :
340
+ ind = np .random .permutation (ind_of_class )[0 ]
341
+ perm_cl_ind .append (ind )
342
+ l_ind .append (perm_cl_ind )
343
+ l_ind = tuple (sum (l_ind , []))
344
+ l_all_ind .append (l_ind )
345
+ assert (len (l_ind ) == batch_size )
346
+ l_all_ind = tuple (l_all_ind )
347
+ assert (len (l_all_ind ) == nb_batch )
348
+ l_all_ind = np .array (sum (l_all_ind , ()))
349
+
350
+ #print("Done making balanced batches...")
351
+ return l_all_ind
352
+
312
353
class ImageDataGenerator (object ):
313
354
"""Generate minibatches of image data with real-time data augmentation.
314
355
@@ -415,7 +456,8 @@ def __init__(self,
415
456
'Received arg: ' , zoom_range )
416
457
417
458
def flow (self , X , y = None , batch_size = 32 , shuffle = True , seed = None ,
418
- save_to_dir = None , save_prefix = '' , save_format = 'jpeg' ):
459
+ save_to_dir = None , save_prefix = '' , save_format = 'jpeg' ,
460
+ balanced_classes = False ):
419
461
return NumpyArrayIterator (
420
462
X , y , self ,
421
463
batch_size = batch_size ,
@@ -424,7 +466,8 @@ def flow(self, X, y=None, batch_size=32, shuffle=True, seed=None,
424
466
dim_ordering = self .dim_ordering ,
425
467
save_to_dir = save_to_dir ,
426
468
save_prefix = save_prefix ,
427
- save_format = save_format )
469
+ save_format = save_format ,
470
+ balanced_classes = balanced_classes )
428
471
429
472
def flow_from_directory (self , directory ,
430
473
target_size = (256 , 256 ), color_mode = 'rgb' ,
@@ -619,20 +662,28 @@ def fit(self, x,
619
662
620
663
class Iterator (object ):
621
664
622
- def __init__ (self , n , batch_size , shuffle , seed ):
665
+ def __init__ (self , n , batch_size , shuffle , seed , targets = None ):
623
666
self .n = n
624
667
self .batch_size = batch_size
625
668
self .shuffle = shuffle
626
669
self .batch_index = 0
627
670
self .total_batches_seen = 0
628
671
self .lock = threading .Lock ()
629
- self .index_generator = self ._flow_index (n , batch_size , shuffle , seed )
672
+ if targets is not None :
673
+ self .index_generator = self ._flow_index_balanced (n ,
674
+ batch_size ,
675
+ shuffle ,
676
+ seed ,
677
+ targets = targets )
678
+ else :
679
+ self .index_generator = self ._flow_index (n , batch_size , shuffle , seed )
630
680
631
681
def reset (self ):
632
682
self .batch_index = 0
633
683
634
684
def _flow_index (self , n , batch_size = 32 , shuffle = False , seed = None ):
635
685
# ensure self.batch_index is 0
686
+
636
687
self .reset ()
637
688
while 1 :
638
689
if seed is not None :
@@ -653,6 +704,29 @@ def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
653
704
yield (index_array [current_index : current_index + current_batch_size ],
654
705
current_index , current_batch_size )
655
706
707
+ def _flow_index_balanced (self , n , batch_size = 32 , shuffle = False , seed = None , targets = None ):
708
+ # ensure self.batch_index is 0
709
+ self .reset ()
710
+ while 1 :
711
+ if seed is not None :
712
+ np .random .seed (seed + self .total_batches_seen )
713
+ if self .batch_index == 0 :
714
+ index_array = np .arange (n )
715
+ if shuffle :
716
+ #index_array = np.random.permutation(n)
717
+ index_array = make_balanced_batches (n , batch_size , targets )
718
+
719
+ current_index = (self .batch_index * batch_size ) % n
720
+ if n >= current_index + batch_size :
721
+ current_batch_size = batch_size
722
+ self .batch_index += 1
723
+ else :
724
+ current_batch_size = n - current_index
725
+ self .batch_index = 0
726
+ self .total_batches_seen += 1
727
+ yield (index_array [current_index : current_index + current_batch_size ],
728
+ current_index , current_batch_size )
729
+
656
730
def __iter__ (self ):
657
731
# needed if we want to do something like:
658
732
# for x, y in data_gen.flow(...):
@@ -667,7 +741,8 @@ class NumpyArrayIterator(Iterator):
667
741
def __init__ (self , x , y , image_data_generator ,
668
742
batch_size = 32 , shuffle = False , seed = None ,
669
743
dim_ordering = 'default' ,
670
- save_to_dir = None , save_prefix = '' , save_format = 'jpeg' ):
744
+ save_to_dir = None , save_prefix = '' , save_format = 'jpeg' ,
745
+ balanced_classes = False ):
671
746
if y is not None and len (x ) != len (y ):
672
747
raise ValueError ('X (images tensor) and y (labels) '
673
748
'should have the same length. '
@@ -697,7 +772,18 @@ def __init__(self, x, y, image_data_generator,
697
772
self .save_to_dir = save_to_dir
698
773
self .save_prefix = save_prefix
699
774
self .save_format = save_format
700
- super (NumpyArrayIterator , self ).__init__ (x .shape [0 ], batch_size , shuffle , seed )
775
+
776
+ if balanced_classes :
777
+ super (NumpyArrayIterator , self ).__init__ (x .shape [0 ],
778
+ batch_size ,
779
+ shuffle ,
780
+ seed ,
781
+ targets = y .argmax (axis = 1 ))
782
+ else :
783
+ super (NumpyArrayIterator , self ).__init__ (x .shape [0 ],
784
+ batch_size ,
785
+ shuffle ,
786
+ seed )
701
787
702
788
def next (self ):
703
789
# for python 2.x.
@@ -725,6 +811,7 @@ def next(self):
725
811
if self .y is None :
726
812
return batch_x
727
813
batch_y = self .y [index_array ]
814
+ #print(batch_y.argmax(1))
728
815
return batch_x , batch_y
729
816
730
817
0 commit comments