
Commit 270956a: Support batch reading csv files
1 parent: 776749b

File tree: 3 files changed, +115 −13 lines

README.md

Lines changed: 4 additions & 2 deletions
@@ -83,13 +83,15 @@ Using different models or hyperparameters is easy with TensorFlow flags.
 If you use other dataset like [iris](./data/iris/), no need to modify the code. Just run with parameters to specify the TFRecords files.
 
 ```
-./dense_classifier.py --train_tfrecords_file ./data/iris/iris_train.csv.tfrecords --validate_tfrecords_file ./data/iris/iris_test.csv.tfrecords --feature_size 4 --label_size 3
+./dense_classifier.py --train_file ./data/iris/iris_train.csv.tfrecords --validate_file ./data/iris/iris_test.csv.tfrecords --feature_size 4 --label_size 3 --enable_colored_log
+
+./dense_classifier.py --train_file ./data/iris/iris_train.csv --validate_file ./data/iris/iris_test.csv --feature_size 4 --label_size 3 --input_file_format csv --enable_colored_log
 ```
 
 If you want to use CNN model, try this command.
 
 ```
-./dense_classifier.py --train_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --validate_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --feature_size 262144 --label_size 2 --batch_size 2 --validate_batch_size 2 --epoch_number -1 --model cnn
+./dense_classifier.py --train_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --validate_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --feature_size 262144 --label_size 2 --batch_size 2 --validate_batch_size 2 --epoch_number -1 --model cnn
 ```
 
 ### Export The Model
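
For reference, the CSV path added in this commit expects each row to hold an integer label followed by float feature columns, matching the record_defaults in dense_classifier.py below. A sketch of two iris-style rows, with illustrative values:

```
0,5.1,3.5,1.4,0.2
1,6.2,2.9,4.3,1.3
```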

dense_classifier.py

Lines changed: 37 additions & 11 deletions
@@ -16,11 +16,10 @@
 flags = tf.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_boolean("enable_colored_log", False, "Enable colored log")
-flags.DEFINE_string("train_tfrecords_file",
-                    "./data/cancer/cancer_train.csv.tfrecords",
+flags.DEFINE_string("input_file_format", "tfrecord", "Input file format")
+flags.DEFINE_string("train_file", "./data/cancer/cancer_train.csv.tfrecords",
                     "The glob pattern of train TFRecords files")
-flags.DEFINE_string("validate_tfrecords_file",
-                    "./data/cancer/cancer_test.csv.tfrecords",
+flags.DEFINE_string("validate_file", "./data/cancer/cancer_test.csv.tfrecords",
                     "The glob pattern of validate TFRecords files")
 flags.DEFINE_integer("feature_size", 9, "Number of feature size")
 flags.DEFINE_integer("label_size", 2, "Number of label size")
@@ -63,6 +62,10 @@ def main():
   import coloredlogs
   coloredlogs.install()
   logging.basicConfig(level=logging.INFO)
+  INPUT_FILE_FORMAT = FLAGS.input_file_format
+  if INPUT_FILE_FORMAT not in ["tfrecord", "csv"]:
+    logging.error("Unknown input file format: {}".format(INPUT_FILE_FORMAT))
+    exit(1)
   FEATURE_SIZE = FLAGS.feature_size
   LABEL_SIZE = FLAGS.label_size
   EPOCH_NUMBER = FLAGS.epoch_number
@@ -85,7 +88,7 @@ def main():
   pprint.PrettyPrinter().pprint(FLAGS.__flags)
 
   # Process TFRecords files
-  def read_and_decode(filename_queue):
+  def read_and_decode_tfrecord(filename_queue):
     reader = tf.TFRecordReader()
     _, serialized_example = reader.read(filename_queue)
     features = tf.parse_single_example(
@@ -98,11 +101,29 @@ def read_and_decode(filename_queue):
     features = features["features"]
     return label, features
 
+  def read_and_decode_csv(filename_queue):
+    # TODO: Not generic for all datasets
+    reader = tf.TextLineReader()
+    key, value = reader.read(filename_queue)
+
+    # Default values, in case of empty columns. Also specifies the type of
+    # the decoded result: one integer label, then four float features.
+    record_defaults = [[1], [1.0], [1.0], [1.0], [1.0]]
+    col1, col2, col3, col4, col5 = tf.decode_csv(
+        value, record_defaults=record_defaults)
+    label = col1
+    features = tf.stack([col2, col3, col4, col5])
+    return label, features
+
   # Read TFRecords files for training
   filename_queue = tf.train.string_input_producer(
-      tf.train.match_filenames_once(FLAGS.train_tfrecords_file),
+      tf.train.match_filenames_once(FLAGS.train_file),
       num_epochs=EPOCH_NUMBER)
-  label, features = read_and_decode(filename_queue)
+  if INPUT_FILE_FORMAT == "tfrecord":
+    label, features = read_and_decode_tfrecord(filename_queue)
+  elif INPUT_FILE_FORMAT == "csv":
+    label, features = read_and_decode_csv(filename_queue)
   batch_labels, batch_features = tf.train.shuffle_batch(
       [label, features],
       batch_size=FLAGS.batch_size,
@@ -112,9 +133,14 @@ def read_and_decode(filename_queue):
 
   # Read TFRecords file for validation
   validate_filename_queue = tf.train.string_input_producer(
-      tf.train.match_filenames_once(FLAGS.validate_tfrecords_file),
+      tf.train.match_filenames_once(FLAGS.validate_file),
       num_epochs=EPOCH_NUMBER)
-  validate_label, validate_features = read_and_decode(validate_filename_queue)
+  if INPUT_FILE_FORMAT == "tfrecord":
+    validate_label, validate_features = read_and_decode_tfrecord(
+        validate_filename_queue)
+  elif INPUT_FILE_FORMAT == "csv":
+    validate_label, validate_features = read_and_decode_csv(
+        validate_filename_queue)
   validate_batch_labels, validate_batch_features = tf.train.shuffle_batch(
       [validate_label, validate_features],
       batch_size=FLAGS.validate_batch_size,
@@ -292,8 +318,8 @@ def inference(inputs, is_train=True):
         MODEL, FLAGS.model_network))
   logits = inference(batch_features, True)
   batch_labels = tf.to_int64(batch_labels)
-  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
-                                                                 labels=batch_labels)
+  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+      logits=logits, labels=batch_labels)
   loss = tf.reduce_mean(cross_entropy, name="loss")
   global_step = tf.Variable(0, name="global_step", trainable=False)
   if FLAGS.enable_lr_decay:
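
To see the queue-based CSV input path in isolation, here is a minimal, self-contained sketch of the same pattern the commit adds. It assumes a TensorFlow 1.x runtime; the file name sample.csv, the two rows, and the batch size are illustrative, not part of the commit:

```
import tensorflow as tf

# Write a tiny CSV so the example runs end to end.
with open("sample.csv", "w") as f:
  f.write("0,5.1,3.5,1.4,0.2\n1,6.2,2.9,4.3,1.3\n")

# Queue of input files; num_epochs creates a local variable under the hood.
filename_queue = tf.train.string_input_producer(["sample.csv"], num_epochs=1)
reader = tf.TextLineReader()
_, value = reader.read(filename_queue)

# First column is the integer label, the rest are float features.
record_defaults = [[1], [1.0], [1.0], [1.0], [1.0]]
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)
label = col1
features = tf.stack([col2, col3, col4, col5])

# Batch single examples into [batch_size] labels and [batch_size, 4] features.
batch_labels, batch_features = tf.train.batch(
    [label, features], batch_size=2, allow_smaller_final_batch=True)

with tf.Session() as sess:
  sess.run([tf.global_variables_initializer(),
            tf.local_variables_initializer()])
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  print(sess.run([batch_labels, batch_features]))
  coord.request_stop()
  coord.join(threads)
```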

java_predict_client2/pom.xml

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+
+  <groupId>com.tobe</groupId>
+  <artifactId>predict</artifactId>
+  <version>1.0-SNAPSHOT</version>
+
+  <dependencies>
+
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>3.0.0</version>
+    </dependency>
+
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-netty</artifactId>
+      <version>1.0.0</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-protobuf</artifactId>
+      <version>1.0.0</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-stub</artifactId>
+      <version>1.0.0</version>
+    </dependency>
+
+  </dependencies>
+
+  <build>
+    <extensions>
+      <extension>
+        <groupId>kr.motd.maven</groupId>
+        <artifactId>os-maven-plugin</artifactId>
+        <version>1.4.1.Final</version>
+      </extension>
+    </extensions>
+    <plugins>
+      <plugin>
+        <groupId>org.xolstice.maven.plugins</groupId>
+        <artifactId>protobuf-maven-plugin</artifactId>
+        <version>0.5.0</version>
+        <configuration>
+          <!--
+            The version of protoc must match protobuf-java. If you don't depend on
+            protobuf-java directly, you will be transitively depending on the
+            protobuf-java version that grpc depends on.
+          -->
+          <protocArtifact>com.google.protobuf:protoc:3.0.0:exe:${os.detected.classifier}</protocArtifact>
+          <pluginId>grpc-java</pluginId>
+          <pluginArtifact>io.grpc:protoc-gen-grpc-java:1.0.0:exe:${os.detected.classifier}</pluginArtifact>
+        </configuration>
+        <executions>
+          <execution>
+            <goals>
+              <goal>compile</goal>
+              <goal>compile-custom</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
+
+</project>
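
With this configuration, the os-maven-plugin extension resolves ${os.detected.classifier} for the build platform, and the protobuf-maven-plugin's compile and compile-custom goals run during source generation, so a plain `mvn compile` should fetch the matching protoc and grpc-java binaries and generate the protobuf message classes and gRPC stubs before Java compilation.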
