Skip to content

Commit 2b8eed4

Browse files
Create ridge_regression_test.py
1 parent 1e42129 commit 2b8eed4

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#coding:UTF-8
2+
3+
import numpy as np
4+
5+
def load_data(file_path):
6+
'''导入测试数据
7+
input: file_path(string):训练数据
8+
output: feature(mat):特征
9+
'''
10+
f = open(file_path)
11+
feature = []
12+
for line in f.readlines():
13+
feature_tmp = []
14+
lines = line.strip().split("\t")
15+
feature_tmp.append(1) # x0
16+
for i in xrange(len(lines)):
17+
feature_tmp.append(float(lines[i]))
18+
feature.append(feature_tmp)
19+
f.close()
20+
return np.mat(feature)
21+
22+
def load_model(model_file):
23+
'''导入模型
24+
input: model_file(string):线性回归模型
25+
output: w(mat):权重值
26+
'''
27+
w = []
28+
f = open(model_file)
29+
for line in f.readlines():
30+
w.append(float(line.strip()))
31+
f.close()
32+
return np.mat(w).T
33+
34+
def get_prediction(data, w):
35+
'''对新数据进行预测
36+
input: data(mat):测试数据
37+
w(mat):权重值
38+
output: 最终的预测
39+
'''
40+
return data * w
41+
42+
def save_result(file_name, predict):
43+
'''保存最终的结果
44+
input: file_name(string):需要保存的文件
45+
predict(mat):预测结果
46+
'''
47+
m = np.shape(predict)[0]
48+
result = []
49+
for i in xrange(m):
50+
result.append(str(predict[i,0]))
51+
f = open(file_name, "w")
52+
f.write("\n".join(result))
53+
f.close()
54+
55+
if __name__ == "__main__":
56+
# 1、导入测试数据
57+
print "----------1.load data ------------"
58+
testData = load_data("data.txt")
59+
# 2、导入线性回归模型
60+
print "----------2.load model ------------"
61+
w = load_model("weights")
62+
# 3、得到预测结果
63+
print "----------3.get prediction ------------"
64+
predict = get_prediction(testData, w)
65+
# 4、保存最终的结果
66+
print "----------4.save prediction ------------"
67+
save_result("predict_result", predict)

0 commit comments

Comments
 (0)