Skip to content

Commit 2ec8e8e

Browse files
authored
Add files via upload
1 parent f21036b commit 2ec8e8e

File tree

3 files changed

+709
-0
lines changed

3 files changed

+709
-0
lines changed

CART_Project3/CART.py

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Wed Aug 1 21:21:14 2018
4+
5+
@author: wzy
6+
"""
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
import types
10+
11+
"""
12+
函数说明:加载数据
13+
14+
Parameters:
15+
fileName - 文件名
16+
17+
Returns:
18+
dataMat - 数据矩阵
19+
20+
Modify:
21+
2018-08-01
22+
"""
23+
def loadDataSet(fileName):
24+
dataMat = []
25+
fr = open(fileName)
26+
for line in fr.readlines():
27+
curLine = line.strip().split('\t')
28+
# 转换为float类型
29+
# map()是 Python 内置的高阶函数,它接收一个函数 f 和一个 list,并通过把函数 f 依次作用在 list 的每个元素上,得到一个新的 list 并返回。
30+
fltLine = list(map(float, curLine))
31+
dataMat.append(fltLine)
32+
return dataMat
33+
34+
35+
"""
36+
函数说明:绘制数据集
37+
38+
Parameters:
39+
fileName - 文件名
40+
41+
Returns:
42+
None
43+
44+
Modify:
45+
2018-08-01
46+
"""
47+
def plotDataSet(filename):
48+
dataMat = loadDataSet(filename)
49+
n = len(dataMat)
50+
xcord = []
51+
ycord = []
52+
# 样本点
53+
for i in range(n):
54+
xcord.append(dataMat[i][0])
55+
ycord.append(dataMat[i][1])
56+
fig = plt.figure()
57+
ax = fig.add_subplot(111)
58+
# 绘制样本点
59+
ax.scatter(xcord, ycord, s=20, c='blue', alpha=.5)
60+
plt.title('DataSet')
61+
plt.xlabel('X')
62+
plt.show()
63+
64+
65+
"""
66+
函数说明:根据特征切分数据集合
67+
68+
Parameters:
69+
dataSet - 数据集合
70+
feature - 带切分的特征
71+
value - 该特征的值
72+
73+
Returns:
74+
mat0 - 切分的数据集合0
75+
mat1 - 切分的数据集合1
76+
77+
Modify:
78+
2018-08-01
79+
"""
80+
def binSplitDataSet(dataSet, feature, value):
81+
mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
82+
mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
83+
return mat0, mat1
84+
85+
86+
"""
87+
函数说明:生成叶结点
88+
89+
Parameters:
90+
dataSet - 数据集合
91+
92+
Returns:
93+
目标变量均值
94+
95+
Modify:
96+
2018-08-01
97+
"""
98+
def regLeaf(dataSet):
99+
return np.mean(dataSet[:, -1])
100+
101+
102+
"""
103+
函数说明:误差估计函数
104+
105+
Parameters:
106+
dataSet - 数据集合
107+
108+
Returns:
109+
目标变量的总方差
110+
111+
Modify:
112+
2018-08-01
113+
"""
114+
def regErr(dataSet):
115+
# var表示方差,即各项-均值的平方求和后再除以N
116+
return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]
117+
118+
119+
"""
120+
函数说明:找到数据的最佳二元切分方式函数
121+
预剪枝
122+
123+
Parameters:
124+
dataSet - 数据集合
125+
leafType - 生成叶结点的函数
126+
errType - 误差估计函数
127+
ops - 用户定义的参数构成的元组
128+
129+
Returns:
130+
bestIndex - 最佳切分特征
131+
bestValue - 最佳特征值
132+
133+
Modify:
134+
2018-08-01
135+
"""
136+
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
137+
# tolS:允许的误差下降值
138+
tolS = ops[0]
139+
# tolN:切分的最小样本数
140+
tolN = ops[1]
141+
# 如果当前所有值相等,则退出(根据set的特性只保留不重复的元素)
142+
if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
143+
return None, leafType(dataSet)
144+
# 统计数据集合的行m和列n
145+
m, n = np.shape(dataSet)
146+
# 默认最后一个特征为最佳切分特征,计算其误差估计
147+
S = errType(dataSet)
148+
# 分别为最佳误差,最佳特征切分的索引值,最佳特征值
149+
bestS = float('inf')
150+
bestIndex = 0
151+
bestValue = 0
152+
# 遍历所有特征
153+
for featIndex in range(n-1):
154+
# 遍历所有特征值
155+
for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]):
156+
# 根据特征和特征值切分数据集
157+
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
158+
# 如果数据少于tolN,则退出剪枝操作
159+
if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
160+
continue
161+
# 计算误差估计,寻找newS的最小值
162+
newS = errType(mat0) + errType(mat1)
163+
# 如果误差估计更小,则更新特征索引值和特征值
164+
if newS < bestS:
165+
# 特征索引
166+
bestIndex = featIndex
167+
# 分割标准
168+
bestValue = splitVal
169+
# 更新目标函数的最小值
170+
bestS = newS
171+
# 如果误差减少不大则退出
172+
if (S - bestS) < tolS:
173+
return None, leafType(dataSet)
174+
# 根据最佳的切分特征和特征值切分数据集合
175+
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
176+
# 如果切分出的数据集很小则退出
177+
if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
178+
return None, leafType(dataSet)
179+
# 返回最佳切分特征和特征值
180+
return bestIndex, bestValue
181+
182+
183+
"""
184+
函数说明:树构建函数
185+
186+
Parameters:
187+
dataSet - 数据集合
188+
leafType - 生成叶结点的函数
189+
errType - 误差估计函数
190+
ops - 用户定义的参数构成的元组
191+
192+
Returns:
193+
retTree - 构建的回归树
194+
195+
Modify:
196+
2018-08-01
197+
"""
198+
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
199+
# 选择最佳切分特征和特征值
200+
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
201+
# 如果没有特征,则返回特征值
202+
if feat == None:
203+
return val
204+
# 回归树
205+
retTree = {}
206+
# 分割特征索引
207+
retTree['spInd'] = feat
208+
# 分割标准
209+
retTree['spVal'] = val
210+
# 分成左数据集和右数据集
211+
lSet, rSet = binSplitDataSet(dataSet, feat, val)
212+
# 创建左子树和右子树 递归
213+
retTree['left'] = createTree(lSet, leafType, errType, ops)
214+
retTree['right'] = createTree(rSet, leafType, errType, ops)
215+
return retTree
216+
217+
218+
"""
219+
函数说明:判断测试输入变量是否是一颗树
220+
树是通过字典存储的
221+
222+
Parameters:
223+
obj - 测试对象
224+
225+
Returns:
226+
是否是一颗树
227+
228+
Modify:
229+
2018-08-01
230+
"""
231+
def isTree(obj):
232+
return (type(obj).__name__ == 'dict')
233+
234+
235+
"""
236+
函数说明:对树进行塌陷处理(即返回树平均值)
237+
238+
Parameters:
239+
tree - 树
240+
241+
Returns:
242+
树的平均值
243+
244+
Modify:
245+
2018-08-01
246+
"""
247+
def getMean(tree):
248+
if isTree(tree['right']):
249+
tree['right'] = getMean(tree['right'])
250+
if isTree(tree['left']):
251+
tree['left'] = getMean(tree['left'])
252+
return (tree['left'] + tree['right']) / 2.0
253+
254+
255+
"""
256+
函数说明:后剪枝
257+
258+
Parameters:
259+
tree - 树
260+
testData - 测试集
261+
262+
Returns:
263+
264+
265+
Modify:
266+
2018-08-01
267+
"""
268+
def prune(tree, testData):
269+
# 如果测试集为空,则对树进行塌陷处理
270+
if np.shape(testData)[0] == 0:
271+
return getMean(tree)
272+
# 如果有左子树或者右子树,则切分数据集
273+
if (isTree(tree['right']) or isTree(tree['left'])):
274+
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
275+
# 处理左子树(剪枝)
276+
if isTree(tree['left']):
277+
tree['left'] = prune(tree['left'], lSet)
278+
# 处理右子树(剪枝)
279+
if isTree(tree['right']):
280+
tree['right'] = prune(tree['right'], rSet)
281+
# 如果当前节点的左右结点为叶结点
282+
if not isTree(tree['left']) and not isTree(tree['right']):
283+
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
284+
# 计算没有合并的误差
285+
errorNoMerge = np.sum(np.power(lSet[:, -1] - tree['left'], 2)) + np.sum(np.power(rSet[:, 1] - tree['right'], 2))
286+
# 计算合并的均值
287+
treeMean = (tree['left'] + tree['right']) / 2.0
288+
# 计算合并的误差
289+
errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
290+
# 如果合并的误差小于没有合并的误差,则合并
291+
if errorMerge < errorNoMerge:
292+
return treeMean
293+
else:
294+
return tree
295+
else:
296+
return tree
297+
298+
299+
if __name__ == '__main__':
300+
train_filename = 'ex2.txt'
301+
train_Data = loadDataSet(train_filename)
302+
train_Mat = np.mat(train_Data)
303+
tree = createTree(train_Mat)
304+
print("剪枝前:", tree)
305+
test_filename = 'ex2test.txt'
306+
test_Data = loadDataSet(test_filename)
307+
test_Mat = np.mat(test_Data)
308+
print("\n剪枝后:", prune(tree, test_Mat))
309+

0 commit comments

Comments
 (0)