                    "The path of checkpoint")
flags.DEFINE_string("output_path", "./tensorboard/",
                    "The path of tensorboard event files")
+ flags.DEFINE_string("scenario", "classification",
+                     "Support classification and regression")
flags.DEFINE_string("model", "dnn", "Support dnn, lr, wide_and_deep")
flags.DEFINE_string("model_network", "128 32 8", "The neural network of model")
flags.DEFINE_boolean("enable_bn", False, "Enable batch normalization or not")
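Since only two values of the new flag are meaningful, a fail-fast check at startup would surface typos earlier than the runtime branch in main(). A minimal sketch, assuming the TensorFlow 1.x flags API used in this file (the validation itself is an assumption, not part of this change):

```python
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string("scenario", "classification",
                    "Support classification and regression")
FLAGS = flags.FLAGS

# Hypothetical fail-fast validation, mirroring the runtime check in main()
if FLAGS.scenario not in ("classification", "regression"):
  raise ValueError("Unknown scenario: {}".format(FLAGS.scenario))
```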
@@ -86,6 +88,7 @@ def main():
  MIN_AFTER_DEQUEUE = FLAGS.min_after_dequeue
  BATCH_CAPACITY = BATCH_THREAD_NUMBER * FLAGS.batch_size + MIN_AFTER_DEQUEUE
  MODE = FLAGS.mode
+ SCENARIO = FLAGS.scenario
  MODEL = FLAGS.model
  CHECKPOINT_PATH = FLAGS.checkpoint_path
  if not CHECKPOINT_PATH.startswith("fds://") and not os.path.exists(
@@ -311,10 +314,19 @@ def inference(inputs, is_train=True):
  logging.info("Use the model: {}, model network: {}".format(
      MODEL, FLAGS.model_network))
  logits = inference(batch_features, True)
- batch_labels = tf.to_int64(batch_labels)
- cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-     logits=logits, labels=batch_labels)
- loss = tf.reduce_mean(cross_entropy, name="loss")
+
+ if SCENARIO == "classification":
+   batch_labels = tf.to_int64(batch_labels)
+   cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+       logits=logits, labels=batch_labels)
+   loss = tf.reduce_mean(cross_entropy, name="loss")
+ elif SCENARIO == "regression":
+   msl = tf.square(logits - batch_labels, name="msl")
+   loss = tf.reduce_mean(msl, name="loss")
+ else:
+   logging.error("Unknown scenario: {}".format(SCENARIO))
+   return
+
  global_step = tf.Variable(0, name="global_step", trainable=False)
  if FLAGS.enable_lr_decay:
    logging.info(
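For reference, a self-contained toy sketch of what the two branches compute: sparse softmax cross entropy over integer class labels for classification, versus elementwise squared error reduced to its mean (MSE) for regression. The values are made up; the API is the TensorFlow 1.x API used in this file:

```python
import tensorflow as tf  # TF 1.x, matching the API used in this file

# Hypothetical toy values, just to illustrate the two loss branches
logits = tf.constant([[2.0, 0.5], [0.1, 1.5]])
class_labels = tf.constant([0, 1], dtype=tf.int64)           # classification targets
regression_targets = tf.constant([[1.0, 0.0], [0.0, 2.0]])   # regression targets

# Classification branch: cross entropy over integer labels
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits, labels=class_labels)
classification_loss = tf.reduce_mean(cross_entropy)

# Regression branch: elementwise squared error, then the mean (MSE)
msl = tf.square(logits - regression_targets)
regression_loss = tf.reduce_mean(msl)

with tf.Session() as sess:
  print(sess.run([classification_loss, regression_loss]))
```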
@@ -332,6 +344,10 @@ def inference(inputs, is_train=True):
  train_op = optimizer.minimize(loss, global_step=global_step)
  tf.get_variable_scope().reuse_variables()

+ # Cast the labels to int64 so the accuracy and auc ops below can still
+ # be constructed when the scenario is regression
+ if SCENARIO == "regression":
+   batch_labels = tf.to_int64(batch_labels)
+
  # Define accuracy op for train data
  train_accuracy_logits = inference(batch_features, False)
  train_softmax = tf.nn.softmax(train_accuracy_logits)
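The cast matters because the accuracy ops built below compare int64 argmax predictions against the labels, and tf.equal requires matching dtypes; without it, graph construction would fail on float regression labels even though the ops are never run. A standalone toy illustration (made-up values, TF 1.x):

```python
import tensorflow as tf

# tf.equal needs both sides to share a dtype, which is why float
# regression labels are cast to int64 before the accuracy op is built
softmax = tf.constant([[0.9, 0.1], [0.2, 0.8]])
labels = tf.to_int64(tf.constant([0.0, 1.0]))       # float labels cast to int64
correct = tf.equal(tf.argmax(softmax, 1), labels)   # argmax yields int64
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

with tf.Session() as sess:
  print(sess.run(accuracy))  # 1.0
```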
@@ -395,10 +411,11 @@ def inference(inputs, is_train=True):
  # Initialize saver and summary
  saver = tf.train.Saver()
  tf.summary.scalar("loss", loss)
- tf.summary.scalar("train_accuracy", train_accuracy)
- tf.summary.scalar("train_auc", train_auc)
- tf.summary.scalar("validate_accuracy", validate_accuracy)
- tf.summary.scalar("validate_auc", validate_auc)
+ if SCENARIO == "classification":
+   tf.summary.scalar("train_accuracy", train_accuracy)
+   tf.summary.scalar("train_auc", train_auc)
+   tf.summary.scalar("validate_accuracy", validate_accuracy)
+   tf.summary.scalar("validate_auc", validate_auc)
  summary_op = tf.summary.merge_all()
  init_op = [
      tf.global_variables_initializer(),
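This works because tf.summary.merge_all() only bundles summaries that were actually registered, so in regression mode summary_op stays valid and simply carries the loss scalar alone. A minimal standalone sketch (TF 1.x):

```python
import tensorflow as tf

# merge_all() only picks up registered summaries, so skipping the
# accuracy/auc scalars still leaves a valid summary_op with just "loss"
loss = tf.constant(0.5)
tf.summary.scalar("loss", loss)
summary_op = tf.summary.merge_all()

with tf.Session() as sess:
  print(type(sess.run(summary_op)))  # serialized Summary protobuf
```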
@@ -427,17 +444,24 @@ def inference(inputs, is_train=True):
  # Print state while training
  if step % FLAGS.steps_to_validate == 0:
-   loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run(
-       [
-           loss, train_accuracy, train_auc, validate_accuracy,
-           validate_auc, summary_op
-       ])
-   end_time = datetime.datetime.now()
-   logging.info(
-       "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
-       format(end_time - start_time, step, loss_value,
-              train_accuracy_value, train_auc_value,
-              validate_accuracy_value, validate_auc_value))
+   if SCENARIO == "classification":
+     loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run(
+         [
+             loss, train_accuracy, train_auc, validate_accuracy,
+             validate_auc, summary_op
+         ])
+     end_time = datetime.datetime.now()
+     logging.info(
+         "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
+         format(end_time - start_time, step, loss_value,
+                train_accuracy_value, train_auc_value,
+                validate_accuracy_value, validate_auc_value))
+   elif SCENARIO == "regression":
+     loss_value, summary_value = sess.run([loss, summary_op])
+     end_time = datetime.datetime.now()
+     logging.info("[{}] Step: {}, loss: {}".format(
+         end_time - start_time, step, loss_value))
+
    writer.add_summary(summary_value, step)
    saver.save(sess, CHECKPOINT_FILE, global_step=step)
    start_time = end_time
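Taken together, the regression path introduced by this diff reduces to: an MSE loss, a train op, and loss-only logging. A self-contained toy sketch of that flow on made-up linear-regression data (TF 1.x; this is an illustration, not the actual trainer):

```python
import datetime
import logging

import tensorflow as tf

logging.basicConfig(level=logging.INFO)

# Hypothetical toy data: y = 2x, fit with a single weight
x = tf.constant([[1.0], [2.0], [3.0]])
y = tf.constant([[2.0], [4.0], [6.0]])
w = tf.Variable([[0.0]])
logits = tf.matmul(x, w)

# Regression branch: squared error reduced to a mean, as in the diff
msl = tf.square(logits - y, name="msl")
loss = tf.reduce_mean(msl, name="loss")
train_op = tf.train.GradientDescentOptimizer(0.05).minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  start_time = datetime.datetime.now()
  for step in range(100):
    sess.run(train_op)
    if step % 20 == 0:
      # Loss-only logging, mirroring the regression branch above
      loss_value = sess.run(loss)
      end_time = datetime.datetime.now()
      logging.info("[{}] Step: {}, loss: {}".format(
          end_time - start_time, step, loss_value))
      start_time = end_time
```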