@@ -169,58 +169,63 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1,
169
169
history ['val_acc' ] = []
170
170
171
171
index_array = np .arange (len (X ))
172
- for epoch in range (nb_epoch ):
173
- if verbose :
174
- print ('Epoch' , epoch )
175
- progbar = Progbar (target = len (X ), verbose = verbose )
176
- if shuffle :
177
- np .random .shuffle (index_array )
178
-
179
- av_loss = 0.
180
- av_acc = 0.
181
- seen = 0
182
172
183
- batches = make_batches (len (X ), batch_size )
184
- for batch_index , (batch_start , batch_end ) in enumerate (batches ):
173
+ # -- allow user to terminate training
174
+ try :
175
+ for epoch in range (nb_epoch ):
176
+ if verbose :
177
+ print ('Epoch' , epoch )
178
+ progbar = Progbar (target = len (X ), verbose = verbose )
185
179
if shuffle :
186
- batch_ids = index_array [batch_start :batch_end ]
187
- else :
188
- batch_ids = slice (batch_start , batch_end )
189
- seen += len (batch_ids )
190
- X_batch = X [batch_ids ]
191
- y_batch = y [batch_ids ]
180
+ np .random .shuffle (index_array )
192
181
193
- if show_accuracy :
194
- loss , acc = self ._train_with_acc (X_batch , y_batch )
195
- log_values = [('loss' , loss ), ('acc.' , acc )]
196
- av_loss += loss * len (batch_ids )
197
- av_acc += acc * len (batch_ids )
198
- else :
199
- loss = self ._train (X_batch , y_batch )
200
- log_values = [('loss' , loss )]
201
- av_loss += loss * len (batch_ids )
182
+ av_loss = 0.
183
+ av_acc = 0.
184
+ seen = 0
202
185
203
- # validation
204
- if do_validation and (batch_index == len (batches ) - 1 ):
205
- if show_accuracy :
206
- val_loss , val_acc = self .test (X_val , y_val , accuracy = True )
207
- log_values += [('val. loss' , val_loss ), ('val. acc.' , val_acc )]
186
+ batches = make_batches (len (X ), batch_size )
187
+ for batch_index , (batch_start , batch_end ) in enumerate (batches ):
188
+ if shuffle :
189
+ batch_ids = index_array [batch_start :batch_end ]
208
190
else :
209
- val_loss = self .test (X_val , y_val )
210
- log_values += [('val. loss' , val_loss )]
211
-
212
- # logging
213
- if verbose :
214
- progbar .update (batch_end , log_values )
191
+ batch_ids = slice (batch_start , batch_end )
192
+ seen += len (batch_ids )
193
+ X_batch = X [batch_ids ]
194
+ y_batch = y [batch_ids ]
215
195
216
- history ['epoch' ].append (epoch )
217
- history ['loss' ].append (av_loss / seen )
218
- if do_validation :
219
- history ['val_loss' ].append (float (val_loss ))
220
- if show_accuracy :
221
- history ['acc' ].append (av_acc / seen )
196
+ if show_accuracy :
197
+ loss , acc = self ._train_with_acc (X_batch , y_batch )
198
+ log_values = [('loss' , loss ), ('acc.' , acc )]
199
+ av_loss += loss * len (batch_ids )
200
+ av_acc += acc * len (batch_ids )
201
+ else :
202
+ loss = self ._train (X_batch , y_batch )
203
+ log_values = [('loss' , loss )]
204
+ av_loss += loss * len (batch_ids )
205
+
206
+ # validation
207
+ if do_validation and (batch_index == len (batches ) - 1 ):
208
+ if show_accuracy :
209
+ val_loss , val_acc = self .test (X_val , y_val , accuracy = True )
210
+ log_values += [('val. loss' , val_loss ), ('val. acc.' , val_acc )]
211
+ else :
212
+ val_loss = self .test (X_val , y_val )
213
+ log_values += [('val. loss' , val_loss )]
214
+
215
+ # logging
216
+ if verbose :
217
+ progbar .update (batch_end , log_values )
218
+
219
+ history ['epoch' ].append (epoch )
220
+ history ['loss' ].append (av_loss / seen )
222
221
if do_validation :
223
- history ['val_acc' ].append (float (val_acc ))
222
+ history ['val_loss' ].append (float (val_loss ))
223
+ if show_accuracy :
224
+ history ['acc' ].append (av_acc / seen )
225
+ if do_validation :
226
+ history ['val_acc' ].append (float (val_acc ))
227
+ except KeyboardInterrupt :
228
+ print ('Terminating on KeyboardInterrupt at Epoch' , epoch )
224
229
return history
225
230
226
231
def predict (self , X , batch_size = 128 , verbose = 1 ):
0 commit comments