1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Jul 29 22:27:00 2018
4
+
5
+ @author: wzy
6
+ """
7
+ import numpy as np
8
+
9
def loadDataSet(filename):
    """Load a tab-delimited numeric data file into feature rows and targets.

    Each non-blank line is one sample: every column except the last is a
    feature and the last column is the target value. The number of feature
    columns is taken from the first non-blank line.

    Parameters:
        filename - path of the tab-separated text file to read

    Returns:
        xArr - list of feature rows, each a list of floats
        yArr - list of target values (floats), one per feature row

    Raises:
        ValueError - a field cannot be converted to float
        IndexError - a row has fewer columns than the first row
    """
    xArr = []
    yArr = []
    numFeat = None
    # One pass inside `with`, so the file is opened once and always closed.
    with open(filename) as fr:
        for line in fr:
            stripped = line.strip()
            # Skip blank lines (e.g. a trailing newline at end of file)
            # instead of failing on float('').
            if not stripped:
                continue
            curLine = stripped.split('\t')
            if numFeat is None:
                # The last column is the target, hence the minus one.
                numFeat = len(curLine) - 1
            # Index explicitly so that a short row still raises IndexError.
            xArr.append([float(curLine[i]) for i in range(numFeat)])
            yArr.append(float(curLine[-1]))
    return xArr, yArr
36
+
37
+
38
def lwlr(testPoint, xArr, yArr, k=1.0):
    """Predict the target at one point with locally weighted linear regression.

    Every training sample is weighted by a Gaussian kernel of its squared
    distance to testPoint, a weighted least-squares fit is solved, and the
    fitted coefficients are applied to testPoint.

    Parameters:
        testPoint - feature row of the query point
        xArr - training feature rows
        yArr - training target values
        k - Gaussian kernel width; smaller values weight nearby samples
            more heavily

    Returns:
        The prediction testPoint * ws as a 1x1 numpy matrix, or None (after
        printing a message) when the weighted normal matrix has a
        determinant of exactly zero.
    """
    # np.asmatrix is used instead of the np.mat alias, which NumPy 2.0 removed.
    xMat = np.asmatrix(xArr)
    yMat = np.asmatrix(yArr).T

    # Gaussian kernel weight of every training row, computed in one shot.
    diffs = np.asarray(testPoint) - np.asarray(xMat)
    sqDist = (diffs * diffs).sum(axis=1)
    kernel = np.exp(sqDist / (-2.0 * k ** 2))

    # Scaling rows by a column of weights is equivalent to multiplying by the
    # diagonal weight matrix, without allocating an m-by-m array.
    weightCol = np.asmatrix(kernel).T
    xTx = xMat.T * np.multiply(weightCol, xMat)

    # A zero determinant means the matrix cannot be inverted.
    if np.linalg.det(xTx) == 0.0:
        print("矩阵为奇异矩阵,不能求逆")
        return None

    # .I is the matrix inverse.
    ws = xTx.I * (xMat.T * np.multiply(weightCol, yMat))
    return np.asmatrix(testPoint) * ws
71
+
72
+
73
def lwlrTest(testArr, xArr, yArr, k=1.0):
    """Run locally weighted linear regression on every row of a test set.

    Parameters:
        testArr - test feature rows; one prediction is made per row
        xArr - training feature rows
        yArr - training target values
        k - Gaussian kernel width forwarded to lwlr

    Returns:
        yHat - 1-D numpy array with one predicted value per test row; an
               entry is NaN when lwlr returned None for that row
    """
    m = np.shape(testArr)[0]
    yHat = np.zeros(m)
    for i in range(m):
        prediction = lwlr(testArr[i], xArr, yArr, k)
        # Pull the scalar out explicitly: implicitly converting a 1x1 array
        # to a scalar on assignment is deprecated since NumPy 1.25.
        # A None result is stored as NaN, which is what assigning None into
        # a float array produces anyway.
        if prediction is None:
            yHat[i] = np.nan
        else:
            yHat[i] = np.asarray(prediction).item()
    return yHat
94
+
95
+
96
def standRegres(xArr, yArr):
    """Compute ordinary least-squares regression coefficients.

    Solves the normal equations ws = (X^T X)^-1 X^T y.

    Parameters:
        xArr - feature rows
        yArr - target values

    Returns:
        ws - column matrix of regression coefficients, or None (after
             printing a message) when X^T X has a determinant of exactly zero
    """
    # np.asmatrix is used instead of the np.mat alias, which NumPy 2.0 removed.
    xMat = np.asmatrix(xArr)
    yMat = np.asmatrix(yArr).T
    xTx = xMat.T * xMat

    # A zero determinant means the matrix cannot be inverted.
    if np.linalg.det(xTx) == 0.0:
        print("矩阵为奇异矩阵,不能求逆")
        return None

    # .I is the matrix inverse; X^T y is formed first so that the n-by-m
    # intermediate (X^T X)^-1 X^T is never materialized.
    return xTx.I * (xMat.T * yMat)
120
+
121
+
122
def rssError(yArr, yHatArr):
    """Return the residual sum of squares between true and predicted values.

    Parameters:
        yArr - true target values
        yHatArr - predicted values, broadcastable against yArr

    Returns:
        Sum of the squared element-wise differences (a numpy scalar)
    """
    # Coerce both sides to plain ndarrays: this lets two Python lists be
    # subtracted and keeps ** element-wise even for np.matrix inputs, where
    # ** would otherwise mean matrix power.
    residuals = np.asarray(yArr) - np.asarray(yHatArr)
    return (residuals ** 2).sum()
137
+
138
+
139
if __name__ == '__main__':
    # Demo on the abalone data: rows 0-98 are used for fitting, rows 100-198
    # as a held-out set.
    abX, abY = loadDataSet('abalone.txt')
    trainX, trainY = abX[0:99], abY[0:99]
    testX, testY = abX[100:199], abY[100:199]

    # Kernel widths to compare, paired with the label printed for each.
    kernels = (
        (0.1, 'k=0.1时,误差大小为:'),
        (1, 'k=1时,误差大小为:'),
        (10, 'k=10时,误差大小为:'),
    )

    # Fit and evaluate on the same rows: effect of k on the training error.
    print("训练集与测试集相同:局部加权线性回归,核k的大小对预测的影响:")
    # All predictions are computed before anything is printed, so messages
    # emitted by lwlr keep their position relative to the results.
    trainPredictions = [lwlrTest(trainX, trainX, trainY, k) for k, _ in kernels]
    for (_, label), yHatK in zip(kernels, trainPredictions):
        print(label, rssError(trainY, yHatK))
    print('')

    # Evaluate on held-out rows: is a smaller k always better?
    print("训练集与测试集不同:局部加权线性回归,核k的大小是越小越好吗?更换数据集,测试结果如下:")
    testPredictions = [lwlrTest(testX, trainX, trainY, k) for k, _ in kernels]
    testErrors = {}
    for (k, label), yHatK in zip(kernels, testPredictions):
        testErrors[k] = rssError(testY, yHatK)
        print(label, testErrors[k])
    print('')

    # Compare plain linear regression with the k=1 locally weighted fit.
    print("训练集与测试集不同:简单的线性回归与k=1时的局部加权线性回归对比:")
    print('k=1时,误差大小为:', testErrors[1])
    ws = standRegres(trainX, trainY)
    # np.asmatrix is used instead of the np.mat alias, which NumPy 2.0 removed.
    yHat = np.asmatrix(testX) * ws
    # .T.A turns the column matrix into a plain (1, n) ndarray for rssError.
    print('简单的线性回归误差大小:', rssError(testY, yHat.T.A))
162
+
0 commit comments