Skip to content

Commit 4039d9b

Browse files
committed
modify an error
1 parent 85a0e72 commit 4039d9b

File tree

4 files changed

+3
-3
lines changed

4 files changed

+3
-3
lines changed

.DS_Store

2 KB
Binary file not shown.

code/.DS_Store

0 Bytes
Binary file not shown.

code/triplet-loss/.DS_Store

0 Bytes
Binary file not shown.

code/triplet-loss/train_with_triplet_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def main(argv):
107107
params = {
108108
"learning_rate": 1e-3,
109109
"batch_size": 64,
110-
"num_epochs": 10,
110+
"num_epochs": 20,
111111

112112
"num_channels": 32,
113113
"use_batch_norm": False,
@@ -124,13 +124,13 @@ def main(argv):
124124

125125
"num_parallel_calls": 4
126126
}
127-
config = tf.estimator.RunConfig(model_dir=args.model_dir, tf_random_seed=100)
127+
config = tf.estimator.RunConfig(model_dir=args.model_dir, tf_random_seed=230)
128128
cls = tf.estimator.Estimator(model_fn=my_model, config=config, params=params)
129129
tf.logging.info("开始训练模型,共{} epochs....".format(params['num_epochs']))
130130
cls.train(input_fn = lambda: train_input_fn(args.data_dir, params))
131131

132132
tf.logging.info("测试集评价模型....")
133-
res = cls.evaluate(input_fn = lambda: test_input_fn(args.data_dit, params))
133+
res = cls.evaluate(input_fn = lambda: test_input_fn(args.data_dir, params))
134134
for key in res:
135135
print("{} : {}".format(key, res[key]))
136136

0 commit comments

Comments
 (0)