1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Wed Aug 1 15:34:57 2018
4
+
5
+ @author: wzy
6
+ """
7
+ import matplotlib .pyplot as plt
8
+ import numpy as np
9
+
10
+ """
11
+ 函数说明:加载数据
12
+
13
+ Parameters:
14
+ fileName - 文件名
15
+
16
+ Returns:
17
+ dataMat - 数据矩阵
18
+
19
+ Modify:
20
+ 2018-08-01
21
+ """
22
+ def loadDataSet (fileName ):
23
+ dataMat = []
24
+ fr = open (fileName )
25
+ for line in fr .readlines ():
26
+ curLine = line .strip ().split ('\t ' )
27
+ # 转换为float类型
28
+ # map()是 Python 内置的高阶函数,它接收一个函数 f 和一个 list,并通过把函数 f 依次作用在 list 的每个元素上,得到一个新的 list 并返回。
29
+ fltLine = list (map (float , curLine ))
30
+ dataMat .append (fltLine )
31
+ return dataMat
32
+
33
+
34
+ """
35
+ 函数说明:根据特征切分数据集合
36
+
37
+ Parameters:
38
+ dataSet - 数据集合
39
+ feature - 带切分的特征
40
+ value - 该特征的值
41
+
42
+ Returns:
43
+ mat0 - 切分的数据集合0
44
+ mat1 - 切分的数据集合1
45
+
46
+ Modify:
47
+ 2018-08-01
48
+ """
49
+ def binSplitDataSet (dataSet , feature , value ):
50
+ mat0 = dataSet [np .nonzero (dataSet [:, feature ] > value )[0 ], :]
51
+ mat1 = dataSet [np .nonzero (dataSet [:, feature ] <= value )[0 ], :]
52
+ return mat0 , mat1
53
+
54
+
55
+ """
56
+ 函数说明:生成叶结点
57
+
58
+ Parameters:
59
+ dataSet - 数据集合
60
+
61
+ Returns:
62
+ 目标变量均值
63
+
64
+ Modify:
65
+ 2018-08-01
66
+ """
67
+ def regLeaf (dataSet ):
68
+ return np .mean (dataSet [:, - 1 ])
69
+
70
+
71
+ """
72
+ 函数说明:误差估计函数
73
+
74
+ Parameters:
75
+ dataSet - 数据集合
76
+
77
+ Returns:
78
+ 目标变量的总方差
79
+
80
+ Modify:
81
+ 2018-08-01
82
+ """
83
+ def regErr (dataSet ):
84
+ # var表示方差,即各项-均值的平方求和后再除以N
85
+ return np .var (dataSet [:, - 1 ]) * np .shape (dataSet )[0 ]
86
+
87
+
88
+ """
89
+ 函数说明:找到数据的最佳二元切分方式函数
90
+
91
+ Parameters:
92
+ dataSet - 数据集合
93
+ leafType - 生成叶结点的函数
94
+ errType - 误差估计函数
95
+ ops - 用户定义的参数构成的元组
96
+
97
+ Returns:
98
+ bestIndex - 最佳切分特征
99
+ bestValue - 最佳特征值
100
+
101
+ Modify:
102
+ 2018-08-01
103
+ """
104
+ def chooseBestSplit (dataSet , leafType = regLeaf , errType = regErr , ops = (1 , 4 )):
105
+ # tolS:允许的误差下降值
106
+ tolS = ops [0 ]
107
+ # tolN:切分的最小样本数
108
+ tolN = ops [1 ]
109
+ # 如果当前所有值相等,则退出(根据set的特性只保留不重复的元素)
110
+ if len (set (dataSet [:, - 1 ].T .tolist ()[0 ])) == 1 :
111
+ return None , leafType (dataSet )
112
+ # 统计数据集合的行m和列n
113
+ m , n = np .shape (dataSet )
114
+ # 默认最后一个特征为最佳切分特征,计算其误差估计
115
+ S = errType (dataSet )
116
+ # 分别为最佳误差,最佳特征切分的索引值,最佳特征值
117
+ bestS = float ('inf' )
118
+ bestIndex = 0
119
+ bestValue = 0
120
+ # 遍历所有特征
121
+ for featIndex in range (n - 1 ):
122
+ # 遍历所有特征值
123
+ for splitVal in set (dataSet [:, featIndex ].T .A .tolist ()[0 ]):
124
+ # 根据特征和特征值切分数据集
125
+ mat0 , mat1 = binSplitDataSet (dataSet , featIndex , splitVal )
126
+ # 如果数据少于tolN,则退出剪枝操作
127
+ if (np .shape (mat0 )[0 ] < tolN ) or (np .shape (mat1 )[0 ] < tolN ):
128
+ continue
129
+ # 计算误差估计,寻找newS的最小值
130
+ newS = errType (mat0 ) + errType (mat1 )
131
+ # 如果误差估计更小,则更新特征索引值和特征值
132
+ if newS < bestS :
133
+ # 特征索引
134
+ bestIndex = featIndex
135
+ # 分割标准
136
+ bestValue = splitVal
137
+ # 更新目标函数的最小值
138
+ bestS = newS
139
+ # 如果误差减少不大则退出
140
+ if (S - bestS ) < tolS :
141
+ return None , leafType (dataSet )
142
+ # 根据最佳的切分特征和特征值切分数据集合
143
+ mat0 , mat1 = binSplitDataSet (dataSet , bestIndex , bestValue )
144
+ # 如果切分出的数据集很小则退出
145
+ if (np .shape (mat0 )[0 ] < tolN ) or (np .shape (mat1 )[0 ] < tolN ):
146
+ return None , leafType (dataSet )
147
+ # 返回最佳切分特征和特征值
148
+ return bestIndex , bestValue
149
+
150
+
151
+ """
152
+ 函数说明:树构建函数
153
+
154
+ Parameters:
155
+ dataSet - 数据集合
156
+ leafType - 生成叶结点的函数
157
+ errType - 误差估计函数
158
+ ops - 用户定义的参数构成的元组
159
+
160
+ Returns:
161
+ retTree - 构建的回归树
162
+
163
+ Modify:
164
+ 2018-08-01
165
+ """
166
+ def createTree (dataSet , leafType = regLeaf , errType = regErr , ops = (1 , 4 )):
167
+ # 选择最佳切分特征和特征值
168
+ feat , val = chooseBestSplit (dataSet , leafType , errType , ops )
169
+ # 如果没有特征,则返回特征值
170
+ if feat == None :
171
+ return val
172
+ # 回归树
173
+ retTree = {}
174
+ # 分割特征索引
175
+ retTree ['spInd' ] = feat
176
+ # 分割标准
177
+ retTree ['spVal' ] = val
178
+ # 分成左数据集和右数据集
179
+ lSet , rSet = binSplitDataSet (dataSet , feat , val )
180
+ # 创建左子树和右子树 递归
181
+ retTree ['left' ] = createTree (lSet , leafType , errType , ops )
182
+ retTree ['right' ] = createTree (rSet , leafType , errType , ops )
183
+ return retTree
184
+
185
+
186
+ """
187
+ 函数说明:绘制数据集
188
+
189
+ Parameters:
190
+ fileName - 文件名
191
+
192
+ Returns:
193
+ None
194
+
195
+ Modify:
196
+ 2018-08-01
197
+ """
198
+ def plotDataSet (filename ):
199
+ dataMat = loadDataSet (filename )
200
+ n = len (dataMat )
201
+ xcord = []
202
+ ycord = []
203
+ # 样本点
204
+ for i in range (n ):
205
+ xcord .append (dataMat [i ][0 ])
206
+ ycord .append (dataMat [i ][1 ])
207
+ fig = plt .figure ()
208
+ ax = fig .add_subplot (111 )
209
+ # 绘制样本点
210
+ ax .scatter (xcord , ycord , s = 20 , c = 'blue' , alpha = .5 )
211
+ plt .title ('DataSet' )
212
+ plt .xlabel ('X' )
213
+ plt .show ()
214
+
215
+
216
+ if __name__ == '__main__' :
217
+ # filename = 'ex00.txt'
218
+ # plotDataSet(filename)
219
+ myData = loadDataSet ('ex2.txt' )
220
+ myMat = np .mat (myData )
221
+ print (createTree (myMat ))
222
+ # print(feat)
223
+ # print(val)
224
+
0 commit comments