Skip to content

Commit d8eae48

Browse files
committed
added dropout
1 parent 4c2bdf5 commit d8eae48

7 files changed

+19
-8
lines changed

model/eval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def evaluate(dataset, network, checkpoint_dir, eval_dir, batch_size=BATCH_SIZE,
3737
distort_inputs, zero_mean_inputs, num_epochs=1,
3838
shuffle=False)
3939

40-
logits = inference(data, network)
40+
logits = inference(data, network, drouput=1.0)
4141
top_k_op = tf.nn.in_top_k(logits, labels, 1)
4242

4343
variable_averages = tf.train.ExponentialMovingAverage(0.99999)

model/inference.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _bias_variable(name, shape, constant):
4343
# biases: { constant: 0.1 },
4444
# }
4545
# }
46-
def inference(data, network):
46+
def inference(data, network, dropout):
4747
output = data
4848
i = 1
4949

@@ -111,6 +111,9 @@ def inference(data, network):
111111
input_channels = output.get_shape()[1].value
112112
output_channels = layer['output_channels']
113113

114+
# Apply dropout.
115+
output = tf.nn.dropout(output, dropout)
116+
114117
with tf.variable_scope('softmax_linear') as scope:
115118

116119
weights = _weight_variable(

model/train.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
EPSILON = 1.0
1515
BETA_1 = 0.9
1616
BETA_2 = 0.999
17+
DROPOUT = 0.5
1718

1819
SCALE_INPUTS = 1
1920
DISTORT_INPUTS = True
@@ -26,9 +27,10 @@
2627

2728
def train(dataset, network, checkpoint_dir, batch_size=BATCH_SIZE,
2829
last_step=LAST_STEP, learning_rate=LEARNING_RATE, epsilon=EPSILON,
29-
beta1=BETA_1, beta2=BETA_2, scale_inputs=SCALE_INPUTS,
30-
distort_inputs=DISTORT_INPUTS, zero_mean_inputs=ZERO_MEAN_INPUTS,
31-
display_step=DISPLAY_STEP, save_checkpoint_secs=SAVE_CHECKPOINT_SECS,
30+
beta1=BETA_1, beta2=BETA_2, dropout=DROPOUT,
31+
scale_inputs=SCALE_INPUTS, distort_inputs=DISTORT_INPUTS,
32+
zero_mean_inputs=ZERO_MEAN_INPUTS, display_step=DISPLAY_STEP,
33+
save_checkpoint_secs=SAVE_CHECKPOINT_SECS,
3234
save_summaries_steps=SAVE_SUMMARIES_STEPS):
3335

3436
if not tf.gfile.Exists(checkpoint_dir):
@@ -37,18 +39,19 @@ def train(dataset, network, checkpoint_dir, batch_size=BATCH_SIZE,
3739
with tf.Graph().as_default():
3840

3941
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
42+
global_step_init = -1
4043
if ckpt and ckpt.model_checkpoint_path:
41-
global_step = int(
44+
global_step_init = int(
4245
ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
43-
global_step = tf.Variable(global_step, name='global_step',
46+
global_step = tf.Variable(global_step_init, name='global_step',
4447
dtype=tf.int64, trainable=False)
4548
else:
4649
global_step = tf.contrib.framework.get_or_create_global_step()
4750

4851
data, labels = inputs(dataset, False, batch_size, scale_inputs,
4952
distort_inputs, zero_mean_inputs, shuffle=True)
5053

51-
logits = inference(data, network)
54+
logits = inference(data, network, dropout)
5255
loss = cal_loss(logits, labels)
5356
acc = cal_accuracy(logits, labels)
5457

@@ -82,6 +85,7 @@ def train_from_config(dataset, config, display_step=DISPLAY_STEP,
8285
config.get('epsilon', EPSILON),
8386
config.get('beta1', BETA_1),
8487
config.get('beta2', BETA_2),
88+
config.get('dropout', DROPOUT),
8589
config.get('scale_inputs', SCALE_INPUTS),
8690
config.get('distort_inputs', DISTORT_INPUTS),
8791
config.get('zero_mean_inputs', ZERO_MEAN_INPUTS),

networks/cifar_10.json

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"epsilon": 1.0,
1212
"beta1": 0.9,
1313
"beta2": 0.999,
14+
"dropout": 0.5,
1415
"scale_inputs": 1,
1516
"distort_inputs": true,
1617
"zero_mean_inputs": true,

networks/patchy_san_slic_pascal_voc.json

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"epsilon": 1,
3737
"beta1": 0.9,
3838
"beta2": 0.999,
39+
"dropout": 0.5,
3940
"scale_inputs": 1,
4041
"distort_inputs": false,
4142
"zero_mean_inputs": true,

networks/vgg16_pascal_voc.json

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"epsilon": 1.0,
1212
"beta1": 0.9,
1313
"beta2": 0.999,
14+
"dropout": 0.5,
1415
"scale_inputs": 1,
1516
"distort_inputs": true,
1617
"zero_mean_inputs": true,

networks/vgg19_pascal_voc.json

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"epsilon": 1.0,
1212
"beta1": 0.9,
1313
"beta2": 0.999,
14+
"dropout": 0.5,
1415
"scale_inputs": 1,
1516
"distort_inputs": true,
1617
"zero_mean_inputs": true,

0 commit comments

Comments
 (0)