Skip to content

Commit 657b9fb

Browse files
committed
Merge pull request keras-team#1651 from tboquet/fit_gen_tensorb
Support + tests for fit_generator + tensorboard
2 parents cae797b + d68e331 commit 657b9fb

File tree

3 files changed

+97
-25
lines changed

3 files changed

+97
-25
lines changed

keras/callbacks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,14 +456,15 @@ def __init__(self, log_dir='./logs', histogram_freq=0):
456456
'with the TensorFlow backend.')
457457
self.log_dir = log_dir
458458
self.histogram_freq = histogram_freq
459+
self.merged = None
459460

460461
def _set_model(self, model):
461462
import tensorflow as tf
462463
import keras.backend.tensorflow_backend as KTF
463464

464465
self.model = model
465466
self.sess = KTF._get_session()
466-
if self.histogram_freq:
467+
if self.histogram_freq and not self.merged:
467468
mod_type = self.model.get_config()['name']
468469
if mod_type == 'Sequential':
469470
layers = {l.get_config()['name']: l for l in self.model.layers}
@@ -515,7 +516,7 @@ def on_epoch_end(self, epoch, logs={}):
515516

516517
all_values = self.totals.copy()
517518
all_values.update(logs)
518-
519+
519520
for name, value in all_values.items():
520521
if name in ['batch', 'size']:
521522
continue

keras/models.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -971,8 +971,17 @@ def input_validation(generator_output):
971971
_stop.set()
972972
raise Exception('The generator output tuple must have '
973973
'2 or 3 elements.')
974+
975+
sample_weight = standardize_weights(y, sample_weight=sample_weight,
976+
sample_weight_mode=self.sample_weight_mode)
974977
return X, y, sample_weight
975978

979+
if do_validation:
980+
X_val, y_val, sample_weight_val = input_validation(validation_data)
981+
self.validation_data = X_val + [y_val, sample_weight_val]
982+
else:
983+
self.validation_data = None
984+
976985
# start generator thread storing batches into a queue
977986
generator_queue = queue.Queue()
978987
_stop = threading.Event()
@@ -1044,10 +1053,9 @@ def generator_task():
10441053
raise NotImplementedError()
10451054
else:
10461055
# input validation
1047-
X, y, sample_weight = input_validation(validation_data)
1048-
val_outs = self.evaluate(X, y,
1056+
val_outs = self.evaluate(X_val, y_val,
10491057
show_accuracy=show_accuracy,
1050-
sample_weight=sample_weight,
1058+
sample_weight=sample_weight_val,
10511059
verbose=0)
10521060
if type(val_outs) != list:
10531061
val_outs = [val_outs]
@@ -1435,8 +1443,19 @@ def input_validation(generator_output):
14351443
[len(sample_weight[name]) for name in sample_weight.keys()])) != 1:
14361444
raise Exception('All input arrays and target arrays must have '
14371445
'the same number of samples.')
1446+
sample_weight = {name: standardize_weights(data[name],
1447+
sample_weight=sample_weight.get(name),
1448+
sample_weight_mode=self.sample_weight_modes.get(name)) for name in self.output_order}
14381449
return data, sample_weight
14391450

1451+
if do_validation:
1452+
data_val, sample_weight_val = input_validation(validation_data)
1453+
sample_weight_val_l = [sample_weight_val[name] for name in self.output_order]
1454+
y_val = [standardize_y(data_val[name]) for name in self.output_order]
1455+
self.validation_data = [data_val[name] for name in self.input_order] + y_val + sample_weight_val_l
1456+
else:
1457+
self.validation_data = None
1458+
14401459
# start generator thread storing batches into a queue
14411460
generator_queue = queue.Queue()
14421461
_stop = threading.Event()
@@ -1502,10 +1521,8 @@ def generator_task():
15021521
_stop.set()
15031522
raise NotImplementedError()
15041523
else:
1505-
# input validation
1506-
data, sample_weight = input_validation(validation_data)
1507-
val_outs = self.evaluate(data,
1508-
sample_weight=sample_weight,
1524+
val_outs = self.evaluate(data_val,
1525+
sample_weight=sample_weight_val,
15091526
verbose=0)
15101527
if type(val_outs) != list:
15111528
val_outs = [val_outs]

tests/keras/test_callbacks.py

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -136,23 +136,30 @@ def test_TensorBoard():
136136
nb_class=nb_class)
137137
y_test = np_utils.to_categorical(y_test)
138138
y_train = np_utils.to_categorical(y_train)
139-
# case 1 Sequential wo accuracy
140-
with tf.Graph().as_default():
141-
session = tf.Session('')
142-
KTF._set_session(session)
143-
model = Sequential()
144-
model.add(Dense(nb_hidden, input_dim=input_dim, activation='relu'))
145-
model.add(Dense(nb_class, activation='softmax'))
146-
model.compile(loss='categorical_crossentropy', optimizer='sgd')
147139

148-
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
149-
cbks = [tsb]
150-
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True,
151-
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
152-
assert os.path.exists(filepath)
153-
shutil.rmtree(filepath)
140+
def data_generator(train):
141+
if train:
142+
max_batch_index = len(X_train) // batch_size
143+
else:
144+
max_batch_index = len(X_test) // batch_size
145+
i = 0
146+
while 1:
147+
if train:
148+
yield (X_train[i * batch_size: (i + 1) * batch_size], y_train[i * batch_size: (i + 1) * batch_size])
149+
else:
150+
yield (X_test[i * batch_size: (i + 1) * batch_size], y_test[i * batch_size: (i + 1) * batch_size])
151+
i += 1
152+
i = i % max_batch_index
153+
154+
def data_generator_graph(train):
155+
while 1:
156+
if train:
157+
yield {'X_vars': X_train, 'output': y_train}
158+
else:
159+
yield {'X_vars': X_test, 'output': y_test}
160+
161+
# case 1 Sequential
154162

155-
# case 2 Sequential w accuracy
156163
with tf.Graph().as_default():
157164
session = tf.Session('')
158165
KTF._set_session(session)
@@ -163,12 +170,42 @@ def test_TensorBoard():
163170

164171
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
165172
cbks = [tsb]
173+
174+
# fit with validation data
175+
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=False,
176+
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
177+
178+
# fit with validation data and accuracy
166179
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True,
167180
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
181+
182+
# fit generator with validation data
183+
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
184+
show_accuracy=False,
185+
validation_data=(X_test, y_test),
186+
callbacks=cbks)
187+
188+
# fit generator without validation data
189+
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
190+
show_accuracy=False,
191+
callbacks=cbks)
192+
193+
# fit generator with validation data and accuracy
194+
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
195+
show_accuracy=True,
196+
validation_data=(X_test, y_test),
197+
callbacks=cbks)
198+
199+
# fit generator without validation data and accuracy
200+
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
201+
show_accuracy=True,
202+
callbacks=cbks)
203+
168204
assert os.path.exists(filepath)
169205
shutil.rmtree(filepath)
170206

171-
# case 3 Graph
207+
# case 2 Graph
208+
172209
with tf.Graph().as_default():
173210
session = tf.Session('')
174211
KTF._set_session(session)
@@ -185,10 +222,27 @@ def test_TensorBoard():
185222

186223
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
187224
cbks = [tsb]
225+
226+
# fit with validation
188227
model.fit({'X_vars': X_train, 'output': y_train},
189228
batch_size=batch_size,
190229
validation_data={'X_vars': X_test, 'output': y_test},
191230
callbacks=cbks, nb_epoch=2)
231+
232+
# fit wo validation
233+
model.fit({'X_vars': X_train, 'output': y_train},
234+
batch_size=batch_size,
235+
callbacks=cbks, nb_epoch=2)
236+
237+
# fit generator with validation
238+
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
239+
validation_data={'X_vars': X_test, 'output': y_test},
240+
callbacks=cbks)
241+
242+
# fit generator wo validation
243+
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
244+
callbacks=cbks)
245+
192246
assert os.path.exists(filepath)
193247
shutil.rmtree(filepath)
194248

0 commit comments

Comments
 (0)