Skip to content

Commit b45c0a8

Browse files
committed
Capture model hyperparams and multiple metrics, in preparation for tracking in mlflow
1 parent 24dbc77 commit b45c0a8

File tree

3 files changed

+33
-33
lines changed

3 files changed

+33
-33
lines changed

src/decision_tree.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
from enum import Enum
22
import numpy as np
33
import pandas as pd
4-
import sys
5-
import os
4+
import sys, os, json
65
from sklearn.preprocessing import LabelEncoder
76
from sklearn.externals import joblib
87
sys.path.append(os.path.join('..', 'src'))
98
sys.path.append(os.path.join('src'))
10-
from sklearn import tree
11-
from sklearn import ensemble
9+
from sklearn import tree, ensemble, metrics
1210
import evaluation
1311

14-
1512
class Model(Enum):
1613
DECISION_TREE = 0
1714
RANDOM_FOREST = 1
@@ -61,22 +58,26 @@ def encode(train, validate):
6158
return train, validate
6259

6360

64-
def make_model(train, model=Model.DECISION_TREE, seed=None):
65-
print("Creating decision tree model")
61+
def train_model(train, model=Model.DECISION_TREE, seed=None):
62+
print("Training model using regressor: {}".format(model.name))
6663
train_dropped = train.drop('unit_sales', axis=1)
6764
target = train['unit_sales']
6865

6966
if model == Model.RANDOM_FOREST:
70-
clf = ensemble.RandomForestRegressor(random_state=seed)
67+
params = {'n_estimators': 10}
68+
clf = ensemble.RandomForestRegressor(random_state=seed, **params)
7169
elif model == Model.ADABOOST:
72-
clf = ensemble.AdaBoostRegressor(random_state=seed)
70+
params = {'n_estimators': 50, 'learning_rate': 1.0, 'loss':'linear'}
71+
clf = ensemble.AdaBoostRegressor(random_state=seed, **params)
7372
elif model == Model.GRADIENT_BOOST:
74-
clf = ensemble.GradientBoostingRegressor(max_depth=4, n_estimators=200, random_state=seed)
73+
params = {'n_estimators': 200, 'max_depth': 4}
74+
clf = ensemble.GradientBoostingRegressor(random_state=seed, **params)
7575
else:
76+
params = {'criterion': 'mse'}
7677
clf = tree.DecisionTreeRegressor(random_state=seed)
7778

78-
clf = clf.fit(train_dropped, target)
79-
return clf
79+
model = clf.fit(train_dropped, target)
80+
return (model,params)
8081

8182

8283
def overwrite_unseen_prediction_with_zero(preds, train, validate):
@@ -90,46 +91,45 @@ def overwrite_unseen_prediction_with_zero(preds, train, validate):
9091
return preds
9192

9293

93-
def make_predictions(clf, validate):
94+
def make_predictions(model, validate):
9495
print("Making prediction on validation data")
9596
validate_dropped = validate.drop('unit_sales', axis=1).fillna(-1)
96-
validate_preds = clf.predict(validate_dropped)
97+
validate_preds = model.predict(validate_dropped)
9798
return validate_preds
9899

99100

100-
def write_predictions_and_score(validation_score, model, columns_used):
101+
def write_predictions_and_score(evaluation_metrics, model, columns_used):
101102
key = "decision_tree"
102103
if not os.path.exists('data/{}'.format(key)):
103104
os.makedirs('data/{}'.format(key))
104105
filename = 'data/{}/model.pkl'.format(key)
105106
print("Writing to {}".format(filename))
106107
joblib.dump(model, filename)
107108

108-
filename = 'results/score.txt'
109+
filename = 'results/metrics.json'
109110
print("Writing to {}".format(filename))
110111
if not os.path.exists('results'):
111112
os.makedirs('results')
112113
with open(filename, 'w+') as score_file:
113-
score_file.write(str(validation_score))
114-
# score = pd.DataFrame({'estimate': [validation_score]})
115-
# score.to_csv(filename, index=False)
116-
117-
print("Done deciding with trees")
114+
json.dump(evaluation_metrics, score_file)
118115

119116

120117
def main(model=Model.DECISION_TREE, seed=None):
121118
original_train, original_validate = load_data()
122119
train, validate = encode(original_train, original_validate)
123-
model = make_model(train, model, seed)
120+
model, params = train_model(train, model, seed)
124121
validation_predictions = make_predictions(model, validate)
125122

126-
print("Calculating estimated error")
127-
validation_score = evaluation.nwrmsle(validation_predictions, validate['unit_sales'].values, validate['perishable'].values)
123+
print("Calculating metrics")
124+
evaluation_metrics = {
125+
'nwrmsle': evaluation.nwrmsle(validation_predictions, validate['unit_sales'].values, validate['perishable'].values),
126+
'r2_score': metrics.r2_score(y_true=validate['unit_sales'].values, y_pred=validation_predictions)
127+
}
128128

129-
write_predictions_and_score(validation_score, model, original_train.columns)
129+
write_predictions_and_score(evaluation_metrics, model, original_train.columns)
130130

131-
print("Decision tree analysis done with a validation score (error rate) of {}.".format(validation_score))
131+
print("Evaluation done with metrics {}.".format(json.dumps(evaluation_metrics)))
132132

133133

134134
if __name__ == "__main__":
135-
main(seed=8675309)
135+
main(model=Model.DECISION_TREE, seed=8675309)

test/app_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import requests
22

33
def test_endpoint():
4-
query_params = '?day_off=false&perishable=false&date=2017-06-14&item_nbr=99197&family=GROCERY%20I&class=1067&transactions=4170';
4+
query_params = '?date=2017-06-14&item_nbr=99197';
55
resp = requests.get('http://localhost:5005/prediction' + query_params);
66

77
assert resp.status_code == 200

test/test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import unittest
2-
2+
import json
33

44
class TestAccuracy(unittest.TestCase):
5-
METRICS_FILE = "results/score.txt"
5+
METRICS_FILE = "results/metrics.json"
66

77
def test_80percent_error_score(self):
88
with open(self.METRICS_FILE, 'r') as file:
9-
error_score = float(file.read())
10-
11-
self.assertLessEqual(error_score, 0.80)
9+
metrics = json.load(file)
10+
self.assertLessEqual(metrics['nwrmsle'], 0.80)
11+
self.assertGreater(metrics['r2_score'], 0.0)
1212

1313

1414
if __name__ == "__main__":

0 commit comments

Comments
 (0)