@@ -19,7 +19,7 @@ def __init__(self, is_training=True):
1919 self .graph = tf .Graph ()
2020 with self .graph .as_default ():
2121 if is_training :
22- self .X , self .labels = get_batch_data ()
22+ self .X , self .labels = get_batch_data (cfg . dataset , cfg . batch_size , cfg . num_threads )
2323 self .Y = tf .one_hot (self .labels , depth = 10 , axis = 1 , dtype = tf .float32 )
2424
2525 self .build_arch ()
@@ -30,14 +30,10 @@ def __init__(self, is_training=True):
3030 self .global_step = tf .Variable (0 , name = 'global_step' , trainable = False )
3131 self .optimizer = tf .train .AdamOptimizer ()
3232 self .train_op = self .optimizer .minimize (self .total_loss , global_step = self .global_step ) # var_list=t_vars)
33- elif cfg .mask_with_y :
34- self .X = tf .placeholder (tf .float32 ,
35- shape = (cfg .batch_size , 28 , 28 , 1 ))
36- self .Y = tf .placeholder (tf .float32 , shape = (cfg .batch_size , 10 , 1 ))
37- self .build_arch ()
3833 else :
39- self .X = tf .placeholder (tf .float32 ,
40- shape = (cfg .batch_size , 28 , 28 , 1 ))
34+ self .X = tf .placeholder (tf .float32 , shape = (cfg .batch_size , 28 , 28 , 1 ))
35+ self .labels = tf .placeholder (tf .int32 , shape = (cfg .batch_size , ))
36+ self .Y = tf .reshape (self .labels , shape = (cfg .batch_size , 10 , 1 ))
4137 self .build_arch ()
4238
4339 tf .logging .info ('Seting up the main structure' )
@@ -150,5 +146,4 @@ def _summary(self):
150146 self .train_summary = tf .summary .merge (train_summary )
151147
152148 correct_prediction = tf .equal (tf .to_int32 (self .labels ), self .argmax_idx )
153- self .batch_accuracy = tf .reduce_sum (tf .cast (correct_prediction , tf .float32 ))
154- self .test_acc = tf .placeholder_with_default (tf .constant (0. ), shape = [])
149+ self .accuracy = tf .reduce_sum (tf .cast (correct_prediction , tf .float32 ))
0 commit comments