@@ -136,23 +136,30 @@ def test_TensorBoard():
136
136
nb_class = nb_class )
137
137
y_test = np_utils .to_categorical (y_test )
138
138
y_train = np_utils .to_categorical (y_train )
139
- # case 1 Sequential wo accuracy
140
- with tf .Graph ().as_default ():
141
- session = tf .Session ('' )
142
- KTF ._set_session (session )
143
- model = Sequential ()
144
- model .add (Dense (nb_hidden , input_dim = input_dim , activation = 'relu' ))
145
- model .add (Dense (nb_class , activation = 'softmax' ))
146
- model .compile (loss = 'categorical_crossentropy' , optimizer = 'sgd' )
147
139
148
- tsb = callbacks .TensorBoard (log_dir = filepath , histogram_freq = 1 )
149
- cbks = [tsb ]
150
- model .fit (X_train , y_train , batch_size = batch_size , show_accuracy = True ,
151
- validation_data = (X_test , y_test ), callbacks = cbks , nb_epoch = 2 )
152
- assert os .path .exists (filepath )
153
- shutil .rmtree (filepath )
140
+ def data_generator (train ):
141
+ if train :
142
+ max_batch_index = len (X_train ) // batch_size
143
+ else :
144
+ max_batch_index = len (X_test ) // batch_size
145
+ i = 0
146
+ while 1 :
147
+ if train :
148
+ yield (X_train [i * batch_size : (i + 1 ) * batch_size ], y_train [i * batch_size : (i + 1 ) * batch_size ])
149
+ else :
150
+ yield (X_test [i * batch_size : (i + 1 ) * batch_size ], y_test [i * batch_size : (i + 1 ) * batch_size ])
151
+ i += 1
152
+ i = i % max_batch_index
153
+
154
+ def data_generator_graph (train ):
155
+ while 1 :
156
+ if train :
157
+ yield {'X_vars' : X_train , 'output' : y_train }
158
+ else :
159
+ yield {'X_vars' : X_test , 'output' : y_test }
160
+
161
+ # case 1 Sequential
154
162
155
- # case 2 Sequential w accuracy
156
163
with tf .Graph ().as_default ():
157
164
session = tf .Session ('' )
158
165
KTF ._set_session (session )
@@ -163,12 +170,42 @@ def test_TensorBoard():
163
170
164
171
tsb = callbacks .TensorBoard (log_dir = filepath , histogram_freq = 1 )
165
172
cbks = [tsb ]
173
+
174
+ # fit with validation data
175
+ model .fit (X_train , y_train , batch_size = batch_size , show_accuracy = False ,
176
+ validation_data = (X_test , y_test ), callbacks = cbks , nb_epoch = 2 )
177
+
178
+ # fit with validation data and accuracy
166
179
model .fit (X_train , y_train , batch_size = batch_size , show_accuracy = True ,
167
180
validation_data = (X_test , y_test ), callbacks = cbks , nb_epoch = 2 )
181
+
182
+ # fit generator with validation data
183
+ model .fit_generator (data_generator (True ), len (X_train ), nb_epoch = 2 ,
184
+ show_accuracy = False ,
185
+ validation_data = (X_test , y_test ),
186
+ callbacks = cbks )
187
+
188
+ # fit generator without validation data
189
+ model .fit_generator (data_generator (True ), len (X_train ), nb_epoch = 2 ,
190
+ show_accuracy = False ,
191
+ callbacks = cbks )
192
+
193
+ # fit generator with validation data and accuracy
194
+ model .fit_generator (data_generator (True ), len (X_train ), nb_epoch = 2 ,
195
+ show_accuracy = True ,
196
+ validation_data = (X_test , y_test ),
197
+ callbacks = cbks )
198
+
199
+ # fit generator without validation data and accuracy
200
+ model .fit_generator (data_generator (True ), len (X_train ), nb_epoch = 2 ,
201
+ show_accuracy = True ,
202
+ callbacks = cbks )
203
+
168
204
assert os .path .exists (filepath )
169
205
shutil .rmtree (filepath )
170
206
171
- # case 3 Graph
207
+ # case 2 Graph
208
+
172
209
with tf .Graph ().as_default ():
173
210
session = tf .Session ('' )
174
211
KTF ._set_session (session )
@@ -185,10 +222,27 @@ def test_TensorBoard():
185
222
186
223
tsb = callbacks .TensorBoard (log_dir = filepath , histogram_freq = 1 )
187
224
cbks = [tsb ]
225
+
226
+ # fit with validation
188
227
model .fit ({'X_vars' : X_train , 'output' : y_train },
189
228
batch_size = batch_size ,
190
229
validation_data = {'X_vars' : X_test , 'output' : y_test },
191
230
callbacks = cbks , nb_epoch = 2 )
231
+
232
+ # fit wo validation
233
+ model .fit ({'X_vars' : X_train , 'output' : y_train },
234
+ batch_size = batch_size ,
235
+ callbacks = cbks , nb_epoch = 2 )
236
+
237
+ # fit generator with validation
238
+ model .fit_generator (data_generator_graph (True ), 1000 , nb_epoch = 2 ,
239
+ validation_data = {'X_vars' : X_test , 'output' : y_test },
240
+ callbacks = cbks )
241
+
242
+ # fit generator wo validation
243
+ model .fit_generator (data_generator_graph (True ), 1000 , nb_epoch = 2 ,
244
+ callbacks = cbks )
245
+
192
246
assert os .path .exists (filepath )
193
247
shutil .rmtree (filepath )
194
248
0 commit comments