Skip to content

Commit f4385ba

Browse files
committed
Support scenario like regression
1 parent 3c87f7c commit f4385ba

File tree

4 files changed

+65
-29
lines changed

4 files changed

+65
-29
lines changed

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@ Following are the supported features.
2626
- [x] Validate acc/auc
2727
- [x] Inference online
2828
- [x] Inference offline
29-
- [x] Network Model
29+
- [x] Network Models
3030
- [x] Logistic regression
3131
- [x] Deep neural network
3232
- [x] Convolution neural network
3333
- [x] Wide and deep model
34+
- [x] Regression model
3435
- [x] Customized models
3536
- [x] Others
3637
- [x] Checkpoint
@@ -95,6 +96,12 @@ If you want to use CNN model, try this command.
9596
./dense_classifier.py --train_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --validate_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --feature_size 262144 --label_size 2 --batch_size 2 --validate_batch_size 2 --epoch_number -1 --model cnn
9697
```
9798

99+
For [boston housing](./data/boston_housing/) dataset.
100+
101+
```
102+
./dense_classifier.py --train_file ./data/boston_housing/train.csv.tfrecords --validate_file ./data/boston_housing/train.csv.tfrecords --feature_size 13 --label_size 1 --scenario regression --batch_size 1 --validate_batch_size 1
103+
```
104+
98105
### Export The Model
99106

100107
After training, it will export the model automatically. Or you can export manually.

data/boston_housing/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
## Data
3+
4+
The files are from https://inclass.kaggle.com/c/boston-housing .

data/boston_housing/generate_csv_tfrecords.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def generate_tfrecords(input_filename, output_filename):
1111
index = 0
1212
for line in open(input_filename, "r"):
1313
index += 1
14-
14+
1515
# Ignore the first line
1616
if index == 1:
1717
continue
@@ -20,17 +20,18 @@ def generate_tfrecords(input_filename, output_filename):
2020
label = float(data[14])
2121
features = [float(i) for i in data[1:14]]
2222

23-
example = tf.train.Example(features=tf.train.Features(feature={
24-
"label":
25-
tf.train.Feature(float_list=tf.train.FloatList(value=[label])),
26-
"features":
27-
tf.train.Feature(float_list=tf.train.FloatList(value=features)),
28-
}))
23+
example = tf.train.Example(features=tf.train.Features(
24+
feature={
25+
"label":
26+
tf.train.Feature(float_list=tf.train.FloatList(value=[label])),
27+
"features":
28+
tf.train.Feature(float_list=tf.train.FloatList(value=features)),
29+
}))
2930
writer.write(example.SerializeToString())
3031

3132
writer.close()
32-
print("Successfully convert {} to {}".format(input_filename,
33-
output_filename))
33+
print(
34+
"Successfully convert {} to {}".format(input_filename, output_filename))
3435

3536

3637
def main():

dense_classifier.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
"The path of checkpoint")
4343
flags.DEFINE_string("output_path", "./tensorboard/",
4444
"The path of tensorboard event files")
45+
flags.DEFINE_string("scenario", "classification",
46+
"Support classification and regression")
4547
flags.DEFINE_string("model", "dnn", "Support dnn, lr, wide_and_deep")
4648
flags.DEFINE_string("model_network", "128 32 8", "The neural network of model")
4749
flags.DEFINE_boolean("enable_bn", False, "Enable batch normalization or not")
@@ -86,6 +88,7 @@ def main():
8688
MIN_AFTER_DEQUEUE = FLAGS.min_after_dequeue
8789
BATCH_CAPACITY = BATCH_THREAD_NUMBER * FLAGS.batch_size + MIN_AFTER_DEQUEUE
8890
MODE = FLAGS.mode
91+
SCENARIO = FLAGS.scenario
8992
MODEL = FLAGS.model
9093
CHECKPOINT_PATH = FLAGS.checkpoint_path
9194
if not CHECKPOINT_PATH.startswith("fds://") and not os.path.exists(
@@ -311,10 +314,19 @@ def inference(inputs, is_train=True):
311314
logging.info("Use the model: {}, model network: {}".format(
312315
MODEL, FLAGS.model_network))
313316
logits = inference(batch_features, True)
314-
batch_labels = tf.to_int64(batch_labels)
315-
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
316-
logits=logits, labels=batch_labels)
317-
loss = tf.reduce_mean(cross_entropy, name="loss")
317+
318+
if SCENARIO == "classification":
319+
batch_labels = tf.to_int64(batch_labels)
320+
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
321+
logits=logits, labels=batch_labels)
322+
loss = tf.reduce_mean(cross_entropy, name="loss")
323+
elif SCENARIO == "regression":
324+
msl = tf.square(logits - batch_labels, name="msl")
325+
loss = tf.reduce_mean(msl, name="loss")
326+
else:
327+
logging.error("Unknow scenario: {}".format(SCENARIO))
328+
return
329+
318330
global_step = tf.Variable(0, name="global_step", trainable=False)
319331
if FLAGS.enable_lr_decay:
320332
logging.info(
@@ -332,6 +344,10 @@ def inference(inputs, is_train=True):
332344
train_op = optimizer.minimize(loss, global_step=global_step)
333345
tf.get_variable_scope().reuse_variables()
334346

347+
# Avoid error when not using acc and auc op
348+
if SCENARIO == "regression":
349+
batch_labels = tf.to_int64(batch_labels)
350+
335351
# Define accuracy op for train data
336352
train_accuracy_logits = inference(batch_features, False)
337353
train_softmax = tf.nn.softmax(train_accuracy_logits)
@@ -395,10 +411,11 @@ def inference(inputs, is_train=True):
395411
# Initialize saver and summary
396412
saver = tf.train.Saver()
397413
tf.summary.scalar("loss", loss)
398-
tf.summary.scalar("train_accuracy", train_accuracy)
399-
tf.summary.scalar("train_auc", train_auc)
400-
tf.summary.scalar("validate_accuracy", validate_accuracy)
401-
tf.summary.scalar("validate_auc", validate_auc)
414+
if SCENARIO == "classification":
415+
tf.summary.scalar("train_accuracy", train_accuracy)
416+
tf.summary.scalar("train_auc", train_auc)
417+
tf.summary.scalar("validate_accuracy", validate_accuracy)
418+
tf.summary.scalar("validate_auc", validate_auc)
402419
summary_op = tf.summary.merge_all()
403420
init_op = [
404421
tf.global_variables_initializer(),
@@ -427,17 +444,24 @@ def inference(inputs, is_train=True):
427444

428445
# Print state while training
429446
if step % FLAGS.steps_to_validate == 0:
430-
loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run(
431-
[
432-
loss, train_accuracy, train_auc, validate_accuracy,
433-
validate_auc, summary_op
434-
])
435-
end_time = datetime.datetime.now()
436-
logging.info(
437-
"[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
438-
format(end_time - start_time, step, loss_value,
439-
train_accuracy_value, train_auc_value,
440-
validate_accuracy_value, validate_auc_value))
447+
if SCENARIO == "classification":
448+
loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run(
449+
[
450+
loss, train_accuracy, train_auc, validate_accuracy,
451+
validate_auc, summary_op
452+
])
453+
end_time = datetime.datetime.now()
454+
logging.info(
455+
"[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
456+
format(end_time - start_time, step, loss_value,
457+
train_accuracy_value, train_auc_value,
458+
validate_accuracy_value, validate_auc_value))
459+
elif SCENARIO == "regression":
460+
loss_value, summary_value = sess.run([loss, summary_op])
461+
end_time = datetime.datetime.now()
462+
logging.info("[{}] Step: {}, loss: {}".format(
463+
end_time - start_time, step, loss_value))
464+
441465
writer.add_summary(summary_value, step)
442466
saver.save(sess, CHECKPOINT_FILE, global_step=step)
443467
start_time = end_time

0 commit comments

Comments
 (0)