Skip to content

Commit 1ca48b6

Browse files
committed
Add a flag to save the mask for the generated image
1 parent d045e10 commit 1ca48b6

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

keras/preprocessing/image.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,14 +500,16 @@ def flow_from_directory(self, directory,
500500
save_prefix='',
501501
save_format='png',
502502
follow_links=False,
503-
interpolation='nearest'):
503+
interpolation='nearest',
504+
save_mask=False):
504505
return DirectoryIterator(
505506
directory, self,
506507
target_size=target_size, color_mode=color_mode,
507508
classes=classes, class_mode=class_mode,
508509
data_format=self.data_format,
509510
batch_size=batch_size, shuffle=shuffle, seed=seed,
510511
save_to_dir=save_to_dir,
512+
save_mask=save_mask,
511513
save_prefix=save_prefix,
512514
save_format=save_format,
513515
follow_links=follow_links,
@@ -1018,7 +1020,8 @@ def __init__(self, directory, image_data_generator,
10181020
batch_size=32, shuffle=True, seed=None,
10191021
data_format=None, save_to_dir=None,
10201022
save_prefix='', save_format='png',
1021-
follow_links=False, interpolation='nearest'):
1023+
follow_links=False, interpolation='nearest',
1024+
save_mask=False):
10221025
if data_format is None:
10231026
data_format = K.image_data_format()
10241027
self.directory = directory
@@ -1051,6 +1054,7 @@ def __init__(self, directory, image_data_generator,
10511054
self.save_prefix = save_prefix
10521055
self.save_format = save_format
10531056
self.interpolation = interpolation
1057+
self.save_mask = save_mask
10541058

10551059
white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', 'ppm'}
10561060

@@ -1117,6 +1121,18 @@ def _get_batches_of_transformed_samples(self, index_array):
11171121
hash=np.random.randint(1e7),
11181122
format=self.save_format)
11191123
img.save(os.path.join(self.save_to_dir, fname))
1124+
1125+
if self.save_mask:
1126+
# Create the mask
1127+
mask = np.zeros(self.image_shape[:2])
1128+
for px, layer_x in enumerate(mask):
1129+
for py, layer_y in enumerate(layer_x):
1130+
mask[px][py] = not all(batch_x[i][px][py] == [1, 1, 1])
1131+
1132+
# Save the mask
1133+
mask_fname = '{}.mask'.format(os.path.join(self.save_to_dir, fname))
1134+
np.save(mask_fname, mask, allow_pickle=True)
1135+
11201136
# build batch of labels
11211137
if self.class_mode == 'input':
11221138
batch_y = batch_x.copy()

0 commit comments

Comments
 (0)