Skip to content

Commit 143aa9d

Browse files
Create bp_test.py
1 parent 0614985 commit 143aa9d

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

Chapter_6 BP/bp_test.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# coding:UTF-8
2+
'''
3+
Date:20160831
4+
@author: zhaozhiyong
5+
'''
6+
import numpy as np
7+
from bp_train import get_predict
8+
9+
def load_data(file_name):
10+
'''导入数据
11+
input: file_name(string):文件的存储位置
12+
output: feature_data(mat):特征
13+
'''
14+
f = open(file_name) # 打开文件
15+
feature_data = []
16+
for line in f.readlines():
17+
feature_tmp = []
18+
lines = line.strip().split("\t")
19+
for i in xrange(len(lines)):
20+
feature_tmp.append(float(lines[i]))
21+
feature_data.append(feature_tmp)
22+
f.close() # 关闭文件
23+
return np.mat(feature_data)
24+
25+
def generate_data():
26+
'''在[-4.5,4.5]之间随机生成20000组点
27+
'''
28+
# 1、随机生成数据点
29+
data = np.mat(np.zeros((20000, 2)))
30+
m = np.shape(data)[0]
31+
x = np.mat(np.random.rand(20000, 2))
32+
for i in xrange(m):
33+
data[i, 0] = x[i, 0] * 9 - 4.5
34+
data[i, 1] = x[i, 1] * 9 - 4.5
35+
# 2、将数据点保存到文件“test_data”中
36+
f = open("test_data", "w")
37+
m,n = np.shape(dataTest)
38+
for i in xrange(m):
39+
tmp =[]
40+
for j in xrange(n):
41+
tmp.append(str(dataTest[i,j]))
42+
f.write("\t".join(tmp) + "\n")
43+
f.close()
44+
45+
def load_model(file_w0, file_w1, file_b0, file_b1):
46+
47+
def get_model(file_name):
48+
f = open(file_name)
49+
model = []
50+
for line in f.readlines():
51+
lines = line.strip().split("\t")
52+
model_tmp = []
53+
for x in lines:
54+
model_tmp.append(float(x.strip()))
55+
model.append(model_tmp)
56+
f.close()
57+
return np.mat(model)
58+
59+
# 1、导入输入层到隐含层之间的权重
60+
w0 = get_model(file_w0)
61+
62+
# 2、导入隐含层到输出层之间的权重
63+
w1 = get_model(file_w1)
64+
65+
# 3、导入输入层到隐含层之间的权重
66+
b0 = get_model(file_b0)
67+
68+
# 4、导入隐含层到输出层之间的权重
69+
b1 = get_model(file_b1)
70+
71+
return w0, w1, b0, b1
72+
73+
def save_predict(file_name, pre):
74+
'''保存最终的预测结果
75+
input: pre(mat):最终的预测结果
76+
output:
77+
'''
78+
f = open(file_name, "w")
79+
m = np.shape(pre)[0]
80+
result = []
81+
for i in xrange(m):
82+
result.append(str(pre[i, 0]))
83+
f.write("\n".join(result))
84+
f.close()
85+
86+
if __name__ == "__main__":
87+
generate_data()
88+
# 1、导入测试数据
89+
print "--------- 1.load data ------------"
90+
dataTest = load_data("test_data")
91+
# 2、导入BP神经网络模型
92+
print "--------- 2.load model ------------"
93+
w0, w1, b0, b1 = load_model("weight_w0", "weight_w1", "weight_b0", "weight_b1")
94+
# 3、得到最终的预测值
95+
print "--------- 3.get prediction ------------"
96+
result = get_predict(dataTest, w0, w1, b0, b1)
97+
# 4、保存最终的预测结果
98+
print "--------- 4.save result ------------"
99+
pre = np.argmax(result, axis=1)
100+
save_predict("result", pre)

0 commit comments

Comments
 (0)