16
16
# Command-line configuration (TensorFlow 1.x flag API).
flags = tf.app.flags
FLAGS = flags.FLAGS

# Logging behavior.
flags.DEFINE_boolean("enable_colored_log", False, "Enable colored log")

# Input pipeline: file format ("tfrecord" or "csv") and glob patterns for
# the train / validation file sets.
flags.DEFINE_string("input_file_format", "tfrecord", "Input file format")
flags.DEFINE_string("train_file", "./data/cancer/cancer_train.csv.tfrecords",
                    "The glob pattern of train TFRecords files")
flags.DEFINE_string("validate_file", "./data/cancer/cancer_test.csv.tfrecords",
                    "The glob pattern of validate TFRecords files")

# Model dimensions for the cancer dataset.
flags.DEFINE_integer("feature_size", 9, "Number of feature size")
flags.DEFINE_integer("label_size", 2, "Number of label size")
@@ -63,6 +62,10 @@ def main():
63
62
import coloredlogs
64
63
coloredlogs .install ()
65
64
logging .basicConfig (level = logging .INFO )
65
+ INPUT_FILE_FORMAT = FLAGS .input_file_format
66
+ if INPUT_FILE_FORMAT not in ["tfrecord" , "csv" ]:
67
+ logging .error ("Unknow input file format: {}" .format (INPUT_FILE_FORMAT ))
68
+ exit (1 )
66
69
FEATURE_SIZE = FLAGS .feature_size
67
70
LABEL_SIZE = FLAGS .label_size
68
71
EPOCH_NUMBER = FLAGS .epoch_number
@@ -85,7 +88,7 @@ def main():
85
88
pprint .PrettyPrinter ().pprint (FLAGS .__flags )
86
89
87
90
# Process TFRecoreds files
88
- def read_and_decode (filename_queue ):
91
+ def read_and_decode_tfrecord (filename_queue ):
89
92
reader = tf .TFRecordReader ()
90
93
_ , serialized_example = reader .read (filename_queue )
91
94
features = tf .parse_single_example (
@@ -98,11 +101,29 @@ def read_and_decode(filename_queue):
98
101
features = features ["features" ]
99
102
return label , features
100
103
104
def read_and_decode_csv(filename_queue):
  """Read one line from a CSV filename queue and decode it.

  Args:
    filename_queue: queue of CSV file names (from string_input_producer).

  Returns:
    (label, features): `label` is the int32 value of the first column;
    `features` is a rank-1 float32 tensor of the remaining four columns.
  """
  # TODO: Not generic for all datasets
  reader = tf.TextLineReader()
  key, value = reader.read(filename_queue)

  # Default values, in case of empty columns. Also specifies the type of the
  # decoded result: one int label followed by four float features.
  record_defaults = [[1], [1.0], [1.0], [1.0], [1.0]]
  col1, col2, col3, col4, col5 = tf.decode_csv(
      value, record_defaults=record_defaults)
  label = col1
  # Fix: the original stacked col4 twice and dropped col5, silently
  # corrupting the last feature of every example.
  features = tf.stack([col2, col3, col4, col5])
  return label, features
101
119
# Read TFRecords files for training
102
120
filename_queue = tf .train .string_input_producer (
103
- tf .train .match_filenames_once (FLAGS .train_tfrecords_file ),
121
+ tf .train .match_filenames_once (FLAGS .train_file ),
104
122
num_epochs = EPOCH_NUMBER )
105
- label , features = read_and_decode (filename_queue )
123
+ if INPUT_FILE_FORMAT == "tfrecord" :
124
+ label , features = read_and_decode_tfrecord (filename_queue )
125
+ elif INPUT_FILE_FORMAT == "csv" :
126
+ label , features = read_and_decode_csv (filename_queue )
106
127
batch_labels , batch_features = tf .train .shuffle_batch (
107
128
[label , features ],
108
129
batch_size = FLAGS .batch_size ,
@@ -112,9 +133,14 @@ def read_and_decode(filename_queue):
112
133
113
134
# Read TFRecords file for validatioin
114
135
validate_filename_queue = tf .train .string_input_producer (
115
- tf .train .match_filenames_once (FLAGS .validate_tfrecords_file ),
136
+ tf .train .match_filenames_once (FLAGS .validate_file ),
116
137
num_epochs = EPOCH_NUMBER )
117
- validate_label , validate_features = read_and_decode (validate_filename_queue )
138
+ if INPUT_FILE_FORMAT == "tfrecord" :
139
+ validate_label , validate_features = read_and_decode_tfrecord (
140
+ validate_filename_queue )
141
+ elif INPUT_FILE_FORMAT == "csv" :
142
+ validate_label , validate_features = read_and_decode_csv (
143
+ validate_filename_queue )
118
144
validate_batch_labels , validate_batch_features = tf .train .shuffle_batch (
119
145
[validate_label , validate_features ],
120
146
batch_size = FLAGS .validate_batch_size ,
@@ -292,8 +318,8 @@ def inference(inputs, is_train=True):
292
318
MODEL , FLAGS .model_network ))
293
319
logits = inference (batch_features , True )
294
320
batch_labels = tf .to_int64 (batch_labels )
295
- cross_entropy = tf .nn .sparse_softmax_cross_entropy_with_logits (logits = logits ,
296
- labels = batch_labels )
321
+ cross_entropy = tf .nn .sparse_softmax_cross_entropy_with_logits (
322
+ logits = logits , labels = batch_labels )
297
323
loss = tf .reduce_mean (cross_entropy , name = "loss" )
298
324
global_step = tf .Variable (0 , name = "global_step" , trainable = False )
299
325
if FLAGS .enable_lr_decay :
0 commit comments