14
14
EPSILON = 1.0
15
15
BETA_1 = 0.9
16
16
BETA_2 = 0.999
17
+ DROPOUT = 0.5
17
18
18
19
SCALE_INPUTS = 1
19
20
DISTORT_INPUTS = True
26
27
27
28
def train (dataset , network , checkpoint_dir , batch_size = BATCH_SIZE ,
28
29
last_step = LAST_STEP , learning_rate = LEARNING_RATE , epsilon = EPSILON ,
29
- beta1 = BETA_1 , beta2 = BETA_2 , scale_inputs = SCALE_INPUTS ,
30
- distort_inputs = DISTORT_INPUTS , zero_mean_inputs = ZERO_MEAN_INPUTS ,
31
- display_step = DISPLAY_STEP , save_checkpoint_secs = SAVE_CHECKPOINT_SECS ,
30
+ beta1 = BETA_1 , beta2 = BETA_2 , dropout = DROPOUT ,
31
+ scale_inputs = SCALE_INPUTS , distort_inputs = DISTORT_INPUTS ,
32
+ zero_mean_inputs = ZERO_MEAN_INPUTS , display_step = DISPLAY_STEP ,
33
+ save_checkpoint_secs = SAVE_CHECKPOINT_SECS ,
32
34
save_summaries_steps = SAVE_SUMMARIES_STEPS ):
33
35
34
36
if not tf .gfile .Exists (checkpoint_dir ):
@@ -37,18 +39,19 @@ def train(dataset, network, checkpoint_dir, batch_size=BATCH_SIZE,
37
39
with tf .Graph ().as_default ():
38
40
39
41
ckpt = tf .train .get_checkpoint_state (checkpoint_dir )
42
+ global_step_init = - 1
40
43
if ckpt and ckpt .model_checkpoint_path :
41
- global_step = int (
44
+ global_step_init = int (
42
45
ckpt .model_checkpoint_path .split ('/' )[- 1 ].split ('-' )[- 1 ])
43
- global_step = tf .Variable (global_step , name = 'global_step' ,
46
+ global_step = tf .Variable (global_step_init , name = 'global_step' ,
44
47
dtype = tf .int64 , trainable = False )
45
48
else :
46
49
global_step = tf .contrib .framework .get_or_create_global_step ()
47
50
48
51
data , labels = inputs (dataset , False , batch_size , scale_inputs ,
49
52
distort_inputs , zero_mean_inputs , shuffle = True )
50
53
51
- logits = inference (data , network )
54
+ logits = inference (data , network , dropout )
52
55
loss = cal_loss (logits , labels )
53
56
acc = cal_accuracy (logits , labels )
54
57
@@ -82,6 +85,7 @@ def train_from_config(dataset, config, display_step=DISPLAY_STEP,
82
85
config .get ('epsilon' , EPSILON ),
83
86
config .get ('beta1' , BETA_1 ),
84
87
config .get ('beta2' , BETA_2 ),
88
+ config .get ('dropout' , DROPOUT ),
85
89
config .get ('scale_inputs' , SCALE_INPUTS ),
86
90
config .get ('distort_inputs' , DISTORT_INPUTS ),
87
91
config .get ('zero_mean_inputs' , ZERO_MEAN_INPUTS ),
0 commit comments