@@ -500,14 +500,16 @@ def flow_from_directory(self, directory,
500
500
save_prefix = '' ,
501
501
save_format = 'png' ,
502
502
follow_links = False ,
503
- interpolation = 'nearest' ):
503
+ interpolation = 'nearest' ,
504
+ save_mask = False ):
504
505
return DirectoryIterator (
505
506
directory , self ,
506
507
target_size = target_size , color_mode = color_mode ,
507
508
classes = classes , class_mode = class_mode ,
508
509
data_format = self .data_format ,
509
510
batch_size = batch_size , shuffle = shuffle , seed = seed ,
510
511
save_to_dir = save_to_dir ,
512
+ save_mask = save_mask ,
511
513
save_prefix = save_prefix ,
512
514
save_format = save_format ,
513
515
follow_links = follow_links ,
@@ -1018,7 +1020,8 @@ def __init__(self, directory, image_data_generator,
1018
1020
batch_size = 32 , shuffle = True , seed = None ,
1019
1021
data_format = None , save_to_dir = None ,
1020
1022
save_prefix = '' , save_format = 'png' ,
1021
- follow_links = False , interpolation = 'nearest' ):
1023
+ follow_links = False , interpolation = 'nearest' ,
1024
+ save_mask = False ):
1022
1025
if data_format is None :
1023
1026
data_format = K .image_data_format ()
1024
1027
self .directory = directory
@@ -1051,6 +1054,7 @@ def __init__(self, directory, image_data_generator,
1051
1054
self .save_prefix = save_prefix
1052
1055
self .save_format = save_format
1053
1056
self .interpolation = interpolation
1057
+ self .save_mask = save_mask
1054
1058
1055
1059
white_list_formats = {'png' , 'jpg' , 'jpeg' , 'bmp' , 'ppm' }
1056
1060
@@ -1117,6 +1121,18 @@ def _get_batches_of_transformed_samples(self, index_array):
1117
1121
hash = np .random .randint (1e7 ),
1118
1122
format = self .save_format )
1119
1123
img .save (os .path .join (self .save_to_dir , fname ))
1124
+
1125
+ if self .save_mask :
1126
+ # Create the mask
1127
+ mask = np .zeros (self .image_shape [:2 ])
1128
+ for px , layer_x in enumerate (mask ):
1129
+ for py , layer_y in enumerate (layer_x ):
1130
+ mask [px ][py ] = not all (batch_x [i ][px ][py ] == [1 , 1 , 1 ])
1131
+
1132
+ # Save the mask
1133
+ mask_fname = '{}.mask' .format (os .path .join (self .save_to_dir , fname ))
1134
+ np .save (mask_fname , mask , allow_pickle = True )
1135
+
1120
1136
# build batch of labels
1121
1137
if self .class_mode == 'input' :
1122
1138
batch_y = batch_x .copy ()
0 commit comments