Skip to content

Commit 1a3f984

Browse files
committed
adding ability for users to stop training early and still return the training history
1 parent 3613ce4 commit 1a3f984

File tree

1 file changed

+50
-45
lines changed

1 file changed

+50
-45
lines changed

keras/models.py

Lines changed: 50 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -169,58 +169,63 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1,
169169
history['val_acc'] = []
170170

171171
index_array = np.arange(len(X))
172-
for epoch in range(nb_epoch):
173-
if verbose:
174-
print('Epoch', epoch)
175-
progbar = Progbar(target=len(X), verbose=verbose)
176-
if shuffle:
177-
np.random.shuffle(index_array)
178-
179-
av_loss = 0.
180-
av_acc = 0.
181-
seen = 0
182172

183-
batches = make_batches(len(X), batch_size)
184-
for batch_index, (batch_start, batch_end) in enumerate(batches):
173+
# -- allow user to terminate training
174+
try:
175+
for epoch in range(nb_epoch):
176+
if verbose:
177+
print('Epoch', epoch)
178+
progbar = Progbar(target=len(X), verbose=verbose)
185179
if shuffle:
186-
batch_ids = index_array[batch_start:batch_end]
187-
else:
188-
batch_ids = slice(batch_start, batch_end)
189-
seen += len(batch_ids)
190-
X_batch = X[batch_ids]
191-
y_batch = y[batch_ids]
180+
np.random.shuffle(index_array)
192181

193-
if show_accuracy:
194-
loss, acc = self._train_with_acc(X_batch, y_batch)
195-
log_values = [('loss', loss), ('acc.', acc)]
196-
av_loss += loss * len(batch_ids)
197-
av_acc += acc * len(batch_ids)
198-
else:
199-
loss = self._train(X_batch, y_batch)
200-
log_values = [('loss', loss)]
201-
av_loss += loss * len(batch_ids)
182+
av_loss = 0.
183+
av_acc = 0.
184+
seen = 0
202185

203-
# validation
204-
if do_validation and (batch_index == len(batches) - 1):
205-
if show_accuracy:
206-
val_loss, val_acc = self.test(X_val, y_val, accuracy=True)
207-
log_values += [('val. loss', val_loss), ('val. acc.', val_acc)]
186+
batches = make_batches(len(X), batch_size)
187+
for batch_index, (batch_start, batch_end) in enumerate(batches):
188+
if shuffle:
189+
batch_ids = index_array[batch_start:batch_end]
208190
else:
209-
val_loss = self.test(X_val, y_val)
210-
log_values += [('val. loss', val_loss)]
211-
212-
# logging
213-
if verbose:
214-
progbar.update(batch_end, log_values)
191+
batch_ids = slice(batch_start, batch_end)
192+
seen += len(batch_ids)
193+
X_batch = X[batch_ids]
194+
y_batch = y[batch_ids]
215195

216-
history['epoch'].append(epoch)
217-
history['loss'].append(av_loss/seen)
218-
if do_validation:
219-
history['val_loss'].append(float(val_loss))
220-
if show_accuracy:
221-
history['acc'].append(av_acc/seen)
196+
if show_accuracy:
197+
loss, acc = self._train_with_acc(X_batch, y_batch)
198+
log_values = [('loss', loss), ('acc.', acc)]
199+
av_loss += loss * len(batch_ids)
200+
av_acc += acc * len(batch_ids)
201+
else:
202+
loss = self._train(X_batch, y_batch)
203+
log_values = [('loss', loss)]
204+
av_loss += loss * len(batch_ids)
205+
206+
# validation
207+
if do_validation and (batch_index == len(batches) - 1):
208+
if show_accuracy:
209+
val_loss, val_acc = self.test(X_val, y_val, accuracy=True)
210+
log_values += [('val. loss', val_loss), ('val. acc.', val_acc)]
211+
else:
212+
val_loss = self.test(X_val, y_val)
213+
log_values += [('val. loss', val_loss)]
214+
215+
# logging
216+
if verbose:
217+
progbar.update(batch_end, log_values)
218+
219+
history['epoch'].append(epoch)
220+
history['loss'].append(av_loss/seen)
222221
if do_validation:
223-
history['val_acc'].append(float(val_acc))
222+
history['val_loss'].append(float(val_loss))
223+
if show_accuracy:
224+
history['acc'].append(av_acc/seen)
225+
if do_validation:
226+
history['val_acc'].append(float(val_acc))
227+
except KeyboardInterrupt:
228+
print('Terminating on KeyboardInterrupt at Epoch', epoch)
224229
return history
225230

226231
def predict(self, X, batch_size=128, verbose=1):

0 commit comments

Comments
 (0)