Skip to content

Commit 3b7b31c

Browse files
Create test_cart.py
1 parent 9e94039 commit 3b7b31c

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

Chapter_9 CART/test_cart.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# coding:UTF-8
2+
'''
3+
Date:20161030
4+
@author: zhaozhiyong
5+
'''
6+
7+
import random as rd
8+
import cPickle as pickle
9+
from train_cart import predict,node
10+
11+
def load_data():
12+
'''导入测试数据集
13+
'''
14+
data_test = []
15+
for i in xrange(400):
16+
tmp = []
17+
tmp.append(rd.random()) # 随机生成[0,1]之间的样本
18+
data_test.append(tmp)
19+
return data_test
20+
21+
def load_model(tree_file):
22+
'''导入训练好的CART回归树模型
23+
input: tree_file(list):保存CART回归树模型的文件
24+
output: regression_tree:CART回归树
25+
'''
26+
with open(tree_file, 'r') as f:
27+
regression_tree = pickle.load(f)
28+
return regression_tree
29+
30+
def get_prediction(data_test, regression_tree):
31+
'''对测试样本进行预测
32+
input: data_test(list):需要预测的样本
33+
regression_tree(regression_tree):训练好的回归树模型
34+
output: result(list):
35+
'''
36+
result = []
37+
for x in data_test:
38+
result.append(predict(x, regression_tree))
39+
return result
40+
41+
def save_result(data_test, result, prediction_file):
42+
'''保存最终的预测结果
43+
input: data_test(list):需要预测的数据集
44+
result(list):预测的结果
45+
prediction_file(string):保存结果的文件
46+
'''
47+
f = open(prediction_file, "w")
48+
for i in xrange(len(result)):
49+
a = str(data_test[i][0]) + "\t" + str(result[i]) + "\n"
50+
f.write(a)
51+
f.close()
52+
53+
if __name__ == "__main__":
54+
# 1、导入待计算的数据
55+
print "--------- 1、load data ----------"
56+
data_test = load_data()
57+
# 2、导入回归树模型
58+
print "--------- 2、load regression tree ---------"
59+
regression_tree = load_model("regression_tree")
60+
# 3、进行预测
61+
print "--------- 3、get prediction -----------"
62+
prediction = get_prediction(data_test, regression_tree)
63+
# 4、保存预测的结果
64+
print "--------- 4、save result ----------"
65+
save_result(data_test, prediction, "prediction")

0 commit comments

Comments
 (0)