Skip to content

Commit c21801a

Browse files
Create random_forests_test.py
1 parent 3c8c5a5 commit c21801a

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#coding:UTF-8
2+
'''
3+
Date:20161030
4+
@author: zhaozhiyong
5+
'''
6+
7+
import cPickle as pickle
8+
from random_forests_train import get_predict
9+
10+
def load_data(file_name):
11+
'''导入待分类的数据集
12+
input: file_name(string):待分类数据存储的位置
13+
output: test_data(list)
14+
'''
15+
f = open(file_name)
16+
test_data = []
17+
for line in f.readlines():
18+
lines = line.strip().split("\t")
19+
tmp = []
20+
for x in lines:
21+
tmp.append(float(x))
22+
tmp.append(0) # 保存初始的label
23+
test_data.append(tmp)
24+
f.close()
25+
return test_data
26+
27+
def load_model(result_file, feature_file):
28+
'''导入随机森林模型和每一个分类树中选择的特征
29+
input: result_file(string):随机森林模型存储的文件
30+
feature_file(string):分类树选择的特征存储的文件
31+
output: trees_result(list):随机森林模型
32+
trees_fiture(list):每一棵分类树选择的特征
33+
'''
34+
# 1、导入选择的特征
35+
trees_fiture = []
36+
f_fea = open(feature_file)
37+
for line in f_fea.readlines():
38+
lines = line.strip().split("\t")
39+
tmp = []
40+
for x in lines:
41+
tmp.append(int(x))
42+
trees_fiture.append(tmp)
43+
f_fea.close()
44+
45+
# 2、导入随机森林模型
46+
with open(result_file, 'r') as f:
47+
trees_result = pickle.load(f)
48+
49+
return trees_result, trees_fiture
50+
51+
def save_result(data_test, prediction, result_file):
52+
'''保存最终的预测结果
53+
input: data_test(list):待预测的数据
54+
prediction(list):预测的结果
55+
result_file(string):存储最终预测结果的文件名
56+
'''
57+
m = len(prediction)
58+
n = len(data_test[0])
59+
60+
f_result = open(result_file, "w")
61+
for i in xrange(m):
62+
tmp = []
63+
for j in xrange(n -1):
64+
tmp.append(str(data_test[i][j]))
65+
tmp.append(str(prediction[i]))
66+
f_result.writelines("\t".join(tmp) + "\n")
67+
f_result.close()
68+
69+
if __name__ == "__main__":
70+
# 1、导入测试数据集
71+
print "--------- 1、load test data --------"
72+
data_test = load_data("test_data.txt")
73+
# 2、导入随机森林模型
74+
print "--------- 2、load random forest model ----------"
75+
trees_result, trees_feature = load_model("result_file", "feature_file")
76+
# 3、预测
77+
print "--------- 3、get prediction -----------"
78+
prediction = get_predict(trees_result, trees_feature, data_test)
79+
# 4、保存最终的预测结果
80+
print "--------- 4、save result -----------"
81+
save_result(data_test, prediction, "final_result")
82+
83+
84+

0 commit comments

Comments
 (0)