Skip to content

Commit 09577a2

Browse files
committed
Fashion-MNIST support
1 parent b5ef85b commit 09577a2

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

capsNet.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, is_training=True):
1919
self.graph = tf.Graph()
2020
with self.graph.as_default():
2121
if is_training:
22-
self.X, self.labels = get_batch_data()
22+
self.X, self.labels = get_batch_data(cfg.dataset, cfg.batch_size, cfg.num_threads)
2323
self.Y = tf.one_hot(self.labels, depth=10, axis=1, dtype=tf.float32)
2424

2525
self.build_arch()
@@ -30,14 +30,10 @@ def __init__(self, is_training=True):
3030
self.global_step = tf.Variable(0, name='global_step', trainable=False)
3131
self.optimizer = tf.train.AdamOptimizer()
3232
self.train_op = self.optimizer.minimize(self.total_loss, global_step=self.global_step) # var_list=t_vars)
33-
elif cfg.mask_with_y:
34-
self.X = tf.placeholder(tf.float32,
35-
shape=(cfg.batch_size, 28, 28, 1))
36-
self.Y = tf.placeholder(tf.float32, shape=(cfg.batch_size, 10, 1))
37-
self.build_arch()
3833
else:
39-
self.X = tf.placeholder(tf.float32,
40-
shape=(cfg.batch_size, 28, 28, 1))
34+
self.X = tf.placeholder(tf.float32, shape=(cfg.batch_size, 28, 28, 1))
35+
self.labels = tf.placeholder(tf.int32, shape=(cfg.batch_size, ))
36+
self.Y = tf.reshape(self.labels, shape=(cfg.batch_size, 10, 1))
4137
self.build_arch()
4238

4339
tf.logging.info('Seting up the main structure')
@@ -150,5 +146,4 @@ def _summary(self):
150146
self.train_summary = tf.summary.merge(train_summary)
151147

152148
correct_prediction = tf.equal(tf.to_int32(self.labels), self.argmax_idx)
153-
self.batch_accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))
154-
self.test_acc = tf.placeholder_with_default(tf.constant(0.), shape=[])
149+
self.accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))

0 commit comments

Comments
 (0)