Skip to content

Commit ecbf73f

Browse files
gabrieldemarmiessefchollet
authored andcommitted
[RELNOTES] [P] Write to TensorBoard every x samples. (keras-team#11152)
* Working on improving tensor flow callbacks * Adding batch level TensorBoard logging (implementing the `on_batch_end` method to the TensorBoard class * Interim commit -- added notes. * Corrected stylistic issues -- brought to compliance w/ PEP8 * Added the missing argument in the test suite. * Added the possibility to choose how frequently tensorboard should log the metrics and losses. * Fixed the issue of the validation data not being displayed. * Fixed the issue about the callback not remembering when was the last time it wrote to the logs. * Removed the error check. * Used update_freq instead of write_step. * Forgot to change the constructor call.
1 parent ae6474d commit ecbf73f

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

keras/callbacks.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,12 @@ class TensorBoard(Callback):
719719
input) or list of Numpy arrays (if the model has multiple inputs).
720720
Learn [more about embeddings]
721721
(https://www.tensorflow.org/programmers_guide/embedding).
722+
update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`, writes
723+
the losses and metrics to TensorBoard after each batch. The same
724+
applies for `'epoch'`. If using an integer, let's say `10000`,
725+
the callback will write the metrics and losses to TensorBoard every
726+
10000 samples. Note that writing too frequently to TensorBoard
727+
can slow down your training.
722728
"""
723729

724730
def __init__(self, log_dir='./logs',
@@ -730,7 +736,8 @@ def __init__(self, log_dir='./logs',
730736
embeddings_freq=0,
731737
embeddings_layer_names=None,
732738
embeddings_metadata=None,
733-
embeddings_data=None):
739+
embeddings_data=None,
740+
update_freq='epoch'):
734741
super(TensorBoard, self).__init__()
735742
global tf, projector
736743
try:
@@ -769,6 +776,13 @@ def __init__(self, log_dir='./logs',
769776
self.embeddings_metadata = embeddings_metadata or {}
770777
self.batch_size = batch_size
771778
self.embeddings_data = embeddings_data
779+
if update_freq == 'batch':
780+
# It is the same as writing as frequently as possible.
781+
self.update_freq = 1
782+
else:
783+
self.update_freq = update_freq
784+
self.samples_seen = 0
785+
self.samples_seen_at_last_write = 0
772786

773787
def set_model(self, model):
774788
self.model = model
@@ -968,6 +982,13 @@ def on_epoch_end(self, epoch, logs=None):
968982

969983
i += self.batch_size
970984

985+
if self.update_freq == 'epoch':
986+
index = epoch
987+
else:
988+
index = self.samples_seen
989+
self._write_logs(logs, index)
990+
991+
def _write_logs(self, logs, index):
971992
for name, value in logs.items():
972993
if name in ['batch', 'size']:
973994
continue
@@ -978,12 +999,20 @@ def on_epoch_end(self, epoch, logs=None):
978999
else:
9791000
summary_value.simple_value = value
9801001
summary_value.tag = name
981-
self.writer.add_summary(summary, epoch)
1002+
self.writer.add_summary(summary, index)
9821003
self.writer.flush()
9831004

9841005
def on_train_end(self, _):
9851006
self.writer.close()
9861007

1008+
def on_batch_end(self, batch, logs=None):
1009+
if self.update_freq != 'epoch':
1010+
self.samples_seen += logs['size']
1011+
samples_seen_since = self.samples_seen - self.samples_seen_at_last_write
1012+
if samples_seen_since >= self.update_freq:
1013+
self._write_logs(logs, self.samples_seen)
1014+
self.samples_seen_at_last_write = self.samples_seen
1015+
9871016

9881017
class ReduceLROnPlateau(Callback):
9891018
"""Reduce learning rate when a metric has stopped improving.

tests/keras/test_callbacks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,8 @@ def make_model():
550550
assert not tmpdir.listdir()
551551

552552

553-
def test_TensorBoard(tmpdir):
553+
@pytest.mark.parametrize('update_freq', ['batch', 'epoch', 9])
554+
def test_TensorBoard(tmpdir, update_freq):
554555
np.random.seed(np.random.randint(1, 1e7))
555556
filepath = str(tmpdir / 'logs')
556557

@@ -588,7 +589,8 @@ def callbacks_factory(histogram_freq, embeddings_freq=1):
588589
embeddings_freq=embeddings_freq,
589590
embeddings_layer_names=['dense_1'],
590591
embeddings_data=X_test,
591-
batch_size=5)]
592+
batch_size=5,
593+
update_freq=update_freq)]
592594

593595
# fit without validation data
594596
model.fit(X_train, y_train, batch_size=batch_size,

0 commit comments

Comments
 (0)