1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Wed Aug 1 21:21:14 2018
4
+
5
+ @author: wzy
6
+ """
7
+ import matplotlib .pyplot as plt
8
+ import numpy as np
9
+ import types
10
+
11
+ """
12
+ 函数说明:加载数据
13
+
14
+ Parameters:
15
+ fileName - 文件名
16
+
17
+ Returns:
18
+ dataMat - 数据矩阵
19
+
20
+ Modify:
21
+ 2018-08-01
22
+ """
23
+ def loadDataSet (fileName ):
24
+ dataMat = []
25
+ fr = open (fileName )
26
+ for line in fr .readlines ():
27
+ curLine = line .strip ().split ('\t ' )
28
+ # 转换为float类型
29
+ # map()是 Python 内置的高阶函数,它接收一个函数 f 和一个 list,并通过把函数 f 依次作用在 list 的每个元素上,得到一个新的 list 并返回。
30
+ fltLine = list (map (float , curLine ))
31
+ dataMat .append (fltLine )
32
+ return dataMat
33
+
34
+
35
+ """
36
+ 函数说明:绘制数据集
37
+
38
+ Parameters:
39
+ fileName - 文件名
40
+
41
+ Returns:
42
+ None
43
+
44
+ Modify:
45
+ 2018-08-01
46
+ """
47
+ def plotDataSet (filename ):
48
+ dataMat = loadDataSet (filename )
49
+ n = len (dataMat )
50
+ xcord = []
51
+ ycord = []
52
+ # 样本点
53
+ for i in range (n ):
54
+ xcord .append (dataMat [i ][0 ])
55
+ ycord .append (dataMat [i ][1 ])
56
+ fig = plt .figure ()
57
+ ax = fig .add_subplot (111 )
58
+ # 绘制样本点
59
+ ax .scatter (xcord , ycord , s = 20 , c = 'blue' , alpha = .5 )
60
+ plt .title ('DataSet' )
61
+ plt .xlabel ('X' )
62
+ # plt.show()
63
+
64
+
65
+ """
66
+ 函数说明:根据特征切分数据集合
67
+
68
+ Parameters:
69
+ dataSet - 数据集合
70
+ feature - 带切分的特征
71
+ value - 该特征的值
72
+
73
+ Returns:
74
+ mat0 - 切分的数据集合0
75
+ mat1 - 切分的数据集合1
76
+
77
+ Modify:
78
+ 2018-08-01
79
+ """
80
+ def binSplitDataSet (dataSet , feature , value ):
81
+ mat0 = dataSet [np .nonzero (dataSet [:, feature ] > value )[0 ], :]
82
+ mat1 = dataSet [np .nonzero (dataSet [:, feature ] <= value )[0 ], :]
83
+ return mat0 , mat1
84
+
85
+
86
+ """
87
+ 函数说明:生成叶结点
88
+
89
+ Parameters:
90
+ dataSet - 数据集合
91
+
92
+ Returns:
93
+ 目标变量均值
94
+
95
+ Modify:
96
+ 2018-08-01
97
+ """
98
+ def regLeaf (dataSet ):
99
+ return np .mean (dataSet [:, - 1 ])
100
+
101
+
102
+ """
103
+ 函数说明:误差估计函数
104
+
105
+ Parameters:
106
+ dataSet - 数据集合
107
+
108
+ Returns:
109
+ 目标变量的总方差
110
+
111
+ Modify:
112
+ 2018-08-01
113
+ """
114
+ def regErr (dataSet ):
115
+ # var表示方差,即各项-均值的平方求和后再除以N
116
+ return np .var (dataSet [:, - 1 ]) * np .shape (dataSet )[0 ]
117
+
118
+
119
+ """
120
+ 函数说明:找到数据的最佳二元切分方式函数
121
+ 预剪枝
122
+
123
+ Parameters:
124
+ dataSet - 数据集合
125
+ leafType - 生成叶结点的函数
126
+ errType - 误差估计函数
127
+ ops - 用户定义的参数构成的元组
128
+
129
+ Returns:
130
+ bestIndex - 最佳切分特征
131
+ bestValue - 最佳特征值
132
+
133
+ Modify:
134
+ 2018-08-01
135
+ """
136
+ def chooseBestSplit (dataSet , leafType = regLeaf , errType = regErr , ops = (1 , 4 )):
137
+ # tolS:允许的误差下降值
138
+ tolS = ops [0 ]
139
+ # tolN:切分的最小样本数
140
+ tolN = ops [1 ]
141
+ # 如果当前所有值相等,则退出(根据set的特性只保留不重复的元素)
142
+ if len (set (dataSet [:, - 1 ].T .tolist ()[0 ])) == 1 :
143
+ return None , leafType (dataSet )
144
+ # 统计数据集合的行m和列n
145
+ m , n = np .shape (dataSet )
146
+ # 默认最后一个特征为最佳切分特征,计算其误差估计
147
+ S = errType (dataSet )
148
+ # 分别为最佳误差,最佳特征切分的索引值,最佳特征值
149
+ bestS = float ('inf' )
150
+ bestIndex = 0
151
+ bestValue = 0
152
+ # 遍历所有特征
153
+ for featIndex in range (n - 1 ):
154
+ # 遍历所有特征值
155
+ for splitVal in set (dataSet [:, featIndex ].T .A .tolist ()[0 ]):
156
+ # 根据特征和特征值切分数据集
157
+ mat0 , mat1 = binSplitDataSet (dataSet , featIndex , splitVal )
158
+ # 如果数据少于tolN,则退出剪枝操作
159
+ if (np .shape (mat0 )[0 ] < tolN ) or (np .shape (mat1 )[0 ] < tolN ):
160
+ continue
161
+ # 计算误差估计,寻找newS的最小值
162
+ newS = errType (mat0 ) + errType (mat1 )
163
+ # 如果误差估计更小,则更新特征索引值和特征值
164
+ if newS < bestS :
165
+ # 特征索引
166
+ bestIndex = featIndex
167
+ # 分割标准
168
+ bestValue = splitVal
169
+ # 更新目标函数的最小值
170
+ bestS = newS
171
+ # 如果误差减少不大则退出
172
+ if (S - bestS ) < tolS :
173
+ return None , leafType (dataSet )
174
+ # 根据最佳的切分特征和特征值切分数据集合
175
+ mat0 , mat1 = binSplitDataSet (dataSet , bestIndex , bestValue )
176
+ # 如果切分出的数据集很小则退出
177
+ if (np .shape (mat0 )[0 ] < tolN ) or (np .shape (mat1 )[0 ] < tolN ):
178
+ return None , leafType (dataSet )
179
+ # 返回最佳切分特征和特征值
180
+ return bestIndex , bestValue
181
+
182
+
183
+ """
184
+ 函数说明:树构建函数
185
+
186
+ Parameters:
187
+ dataSet - 数据集合
188
+ leafType - 生成叶结点的函数
189
+ errType - 误差估计函数
190
+ ops - 用户定义的参数构成的元组
191
+
192
+ Returns:
193
+ retTree - 构建的回归树
194
+
195
+ Modify:
196
+ 2018-08-01
197
+ """
198
+ def createTree (dataSet , leafType = regLeaf , errType = regErr , ops = (1 , 4 )):
199
+ # 选择最佳切分特征和特征值
200
+ feat , val = chooseBestSplit (dataSet , leafType , errType , ops )
201
+ # 如果没有特征,则返回特征值
202
+ if feat == None :
203
+ return val
204
+ # 回归树
205
+ retTree = {}
206
+ # 分割特征索引
207
+ retTree ['spInd' ] = feat
208
+ # 分割标准
209
+ retTree ['spVal' ] = val
210
+ # 分成左数据集和右数据集
211
+ lSet , rSet = binSplitDataSet (dataSet , feat , val )
212
+ # 创建左子树和右子树 递归
213
+ retTree ['left' ] = createTree (lSet , leafType , errType , ops )
214
+ retTree ['right' ] = createTree (rSet , leafType , errType , ops )
215
+ return retTree
216
+
217
+
218
+ """
219
+ 函数说明:判断测试输入变量是否是一颗树
220
+ 树是通过字典存储的
221
+
222
+ Parameters:
223
+ obj - 测试对象
224
+
225
+ Returns:
226
+ 是否是一颗树
227
+
228
+ Modify:
229
+ 2018-08-01
230
+ """
231
+ def isTree (obj ):
232
+ return (type (obj ).__name__ == 'dict' )
233
+
234
+
235
+ """
236
+ 函数说明:对树进行塌陷处理(即返回树平均值)
237
+
238
+ Parameters:
239
+ tree - 树
240
+
241
+ Returns:
242
+ 树的平均值
243
+
244
+ Modify:
245
+ 2018-08-01
246
+ """
247
+ def getMean (tree ):
248
+ if isTree (tree ['right' ]):
249
+ tree ['right' ] = getMean (tree ['right' ])
250
+ if isTree (tree ['left' ]):
251
+ tree ['left' ] = getMean (tree ['left' ])
252
+ return (tree ['left' ] + tree ['right' ]) / 2.0
253
+
254
+
255
+ """
256
+ 函数说明:后剪枝
257
+
258
+ Parameters:
259
+ tree - 树
260
+ testData - 测试集
261
+
262
+ Returns:
263
+ 树
264
+
265
+ Modify:
266
+ 2018-08-01
267
+ """
268
+ def prune (tree , testData ):
269
+ # 如果测试集为空,则对树进行塌陷处理
270
+ if np .shape (testData )[0 ] == 0 :
271
+ return getMean (tree )
272
+ # 如果有左子树或者右子树,则切分数据集
273
+ if (isTree (tree ['right' ]) or isTree (tree ['left' ])):
274
+ lSet , rSet = binSplitDataSet (testData , tree ['spInd' ], tree ['spVal' ])
275
+ # 处理左子树(剪枝)
276
+ if isTree (tree ['left' ]):
277
+ tree ['left' ] = prune (tree ['left' ], lSet )
278
+ # 处理右子树(剪枝)
279
+ if isTree (tree ['right' ]):
280
+ tree ['right' ] = prune (tree ['right' ], rSet )
281
+ # 如果当前节点的左右结点为叶结点
282
+ if not isTree (tree ['left' ]) and not isTree (tree ['right' ]):
283
+ lSet , rSet = binSplitDataSet (testData , tree ['spInd' ], tree ['spVal' ])
284
+ # 计算没有合并的误差
285
+ errorNoMerge = np .sum (np .power (lSet [:, - 1 ] - tree ['left' ], 2 )) + np .sum (np .power (rSet [:, 1 ] - tree ['right' ], 2 ))
286
+ # 计算合并的均值
287
+ treeMean = (tree ['left' ] + tree ['right' ]) / 2.0
288
+ # 计算合并的误差
289
+ errorMerge = np .sum (np .power (testData [:, - 1 ] - treeMean , 2 ))
290
+ # 如果合并的误差小于没有合并的误差,则合并
291
+ if errorMerge < errorNoMerge :
292
+ return treeMean
293
+ else :
294
+ return tree
295
+ else :
296
+ return tree
297
+
298
+
299
+ """
300
+ 函数说明:简单线性回归
301
+
302
+ Parameters:
303
+ dataSet - 数据集
304
+
305
+ Returns:
306
+ ws - 最佳回归系数
307
+ X - 特征矩阵
308
+ Y - label列向量
309
+
310
+ Modify:
311
+ 2018-08-02
312
+ """
313
+ def linearSolve (dataSet ):
314
+ m , n = np .shape (dataSet )
315
+ X = np .mat (np .ones ((m , n )))
316
+ Y = np .mat (np .ones ((m , 1 )))
317
+ # 保存特征矩阵X的第一列全为1
318
+ X [:, 1 :n ] = dataSet [:, 0 :n - 1 ]
319
+ # 保存label列向量
320
+ Y = dataSet [:, - 1 ]
321
+ # 简单线性回归
322
+ xTx = X .T * X
323
+ # 奇异矩阵不可以求逆
324
+ if np .linalg .det (xTx ) == 0.0 :
325
+ raise NameError ('This matrix is singular, cannont do inverse,\n \
326
+ try increasing the second value of ops' )
327
+ # 求解回归系数
328
+ ws = xTx .I * (X .T * Y )
329
+ return ws , X , Y
330
+
331
+
332
+ """
333
+ 函数说明:返回数据集的回归系数
334
+
335
+ Parameters:
336
+ dataSet - 数据集
337
+
338
+ Returns:
339
+ ws - 最佳回归系数
340
+
341
+ Modify:
342
+ 2018-08-02
343
+ """
344
+ def modelLeaf (dataSet ):
345
+ ws , X , Y = linearSolve (dataSet )
346
+ return ws
347
+
348
+
349
+ """
350
+ 函数说明:计算误差
351
+
352
+ Parameters:
353
+ dataSet - 数据集
354
+
355
+ Returns:
356
+ 误差值
357
+
358
+ Modify:
359
+ 2018-08-02
360
+ """
361
+ def modelErr (dataSet ):
362
+ ws , X , Y = linearSolve (dataSet )
363
+ yHat = X * ws
364
+ # 求差值的平方和
365
+ return sum (np .power (Y - yHat , 2 ))
366
+
367
+
368
+ if __name__ == '__main__' :
369
+ train_filename = 'exp2.txt'
370
+ train_Data = loadDataSet (train_filename )
371
+ dataMat = np .mat (train_Data )
372
+ Tree = createTree (dataMat , modelLeaf , modelErr , (1 , 10 ))
373
+ print (Tree )
374
+ plotDataSet (train_filename )
375
+ # 绘制分段回归直线
376
+ x1 = np .linspace (0 , Tree ['spVal' ])
377
+ plt .plot (x1 , float (Tree ['right' ][1 ])* x1 + float (Tree ['right' ][0 ]), 'r--' )
378
+ x2 = np .linspace (Tree ['spVal' ], 1 )
379
+ plt .plot (x2 , float (Tree ['left' ][1 ])* x2 + float (Tree ['left' ][0 ]), 'r--' )
380
+ # 显示
381
+ plt .show ()
382
+
383
+
0 commit comments