Skip to content

Commit a98d2fd

Browse files
authored
Add files via upload
1 parent dcc2cad commit a98d2fd

File tree

4 files changed

+824
-0
lines changed

4 files changed

+824
-0
lines changed

CART_Project1/CART.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Wed Aug 1 15:34:57 2018
4+
5+
@author: wzy
6+
"""
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
10+
"""
11+
函数说明:加载数据
12+
13+
Parameters:
14+
fileName - 文件名
15+
16+
Returns:
17+
dataMat - 数据矩阵
18+
19+
Modify:
20+
2018-08-01
21+
"""
22+
def loadDataSet(fileName):
23+
dataMat = []
24+
fr = open(fileName)
25+
for line in fr.readlines():
26+
curLine = line.strip().split('\t')
27+
# 转换为float类型
28+
# map()是 Python 内置的高阶函数,它接收一个函数 f 和一个 list,并通过把函数 f 依次作用在 list 的每个元素上,得到一个新的 list 并返回。
29+
fltLine = list(map(float, curLine))
30+
dataMat.append(fltLine)
31+
return dataMat
32+
33+
34+
"""
35+
函数说明:根据特征切分数据集合
36+
37+
Parameters:
38+
dataSet - 数据集合
39+
feature - 带切分的特征
40+
value - 该特征的值
41+
42+
Returns:
43+
mat0 - 切分的数据集合0
44+
mat1 - 切分的数据集合1
45+
46+
Modify:
47+
2018-08-01
48+
"""
49+
def binSplitDataSet(dataSet, feature, value):
50+
mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
51+
mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
52+
return mat0, mat1
53+
54+
55+
"""
56+
函数说明:生成叶结点
57+
58+
Parameters:
59+
dataSet - 数据集合
60+
61+
Returns:
62+
目标变量均值
63+
64+
Modify:
65+
2018-08-01
66+
"""
67+
def regLeaf(dataSet):
68+
return np.mean(dataSet[:, -1])
69+
70+
71+
"""
72+
函数说明:误差估计函数
73+
74+
Parameters:
75+
dataSet - 数据集合
76+
77+
Returns:
78+
目标变量的总方差
79+
80+
Modify:
81+
2018-08-01
82+
"""
83+
def regErr(dataSet):
84+
# var表示方差,即各项-均值的平方求和后再除以N
85+
return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]
86+
87+
88+
"""
89+
函数说明:找到数据的最佳二元切分方式函数
90+
91+
Parameters:
92+
dataSet - 数据集合
93+
leafType - 生成叶结点的函数
94+
errType - 误差估计函数
95+
ops - 用户定义的参数构成的元组
96+
97+
Returns:
98+
bestIndex - 最佳切分特征
99+
bestValue - 最佳特征值
100+
101+
Modify:
102+
2018-08-01
103+
"""
104+
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
105+
# tolS:允许的误差下降值
106+
tolS = ops[0]
107+
# tolN:切分的最小样本数
108+
tolN = ops[1]
109+
# 如果当前所有值相等,则退出(根据set的特性只保留不重复的元素)
110+
if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
111+
return None, leafType(dataSet)
112+
# 统计数据集合的行m和列n
113+
m, n = np.shape(dataSet)
114+
# 默认最后一个特征为最佳切分特征,计算其误差估计
115+
S = errType(dataSet)
116+
# 分别为最佳误差,最佳特征切分的索引值,最佳特征值
117+
bestS = float('inf')
118+
bestIndex = 0
119+
bestValue = 0
120+
# 遍历所有特征
121+
for featIndex in range(n-1):
122+
# 遍历所有特征值
123+
for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]):
124+
# 根据特征和特征值切分数据集
125+
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
126+
# 如果数据少于tolN,则退出剪枝操作
127+
if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
128+
continue
129+
# 计算误差估计,寻找newS的最小值
130+
newS = errType(mat0) + errType(mat1)
131+
# 如果误差估计更小,则更新特征索引值和特征值
132+
if newS < bestS:
133+
# 特征索引
134+
bestIndex = featIndex
135+
# 分割标准
136+
bestValue = splitVal
137+
# 更新目标函数的最小值
138+
bestS = newS
139+
# 如果误差减少不大则退出
140+
if (S - bestS) < tolS:
141+
return None, leafType(dataSet)
142+
# 根据最佳的切分特征和特征值切分数据集合
143+
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
144+
# 如果切分出的数据集很小则退出
145+
if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
146+
return None, leafType(dataSet)
147+
# 返回最佳切分特征和特征值
148+
return bestIndex, bestValue
149+
150+
151+
"""
152+
函数说明:树构建函数
153+
154+
Parameters:
155+
dataSet - 数据集合
156+
leafType - 生成叶结点的函数
157+
errType - 误差估计函数
158+
ops - 用户定义的参数构成的元组
159+
160+
Returns:
161+
retTree - 构建的回归树
162+
163+
Modify:
164+
2018-08-01
165+
"""
166+
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
167+
# 选择最佳切分特征和特征值
168+
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
169+
# 如果没有特征,则返回特征值
170+
if feat == None:
171+
return val
172+
# 回归树
173+
retTree = {}
174+
# 分割特征索引
175+
retTree['spInd'] = feat
176+
# 分割标准
177+
retTree['spVal'] = val
178+
# 分成左数据集和右数据集
179+
lSet, rSet = binSplitDataSet(dataSet, feat, val)
180+
# 创建左子树和右子树 递归
181+
retTree['left'] = createTree(lSet, leafType, errType, ops)
182+
retTree['right'] = createTree(rSet, leafType, errType, ops)
183+
return retTree
184+
185+
186+
"""
187+
函数说明:绘制数据集
188+
189+
Parameters:
190+
fileName - 文件名
191+
192+
Returns:
193+
None
194+
195+
Modify:
196+
2018-08-01
197+
"""
198+
def plotDataSet(filename):
199+
dataMat = loadDataSet(filename)
200+
n = len(dataMat)
201+
xcord = []
202+
ycord = []
203+
# 样本点
204+
for i in range(n):
205+
xcord.append(dataMat[i][0])
206+
ycord.append(dataMat[i][1])
207+
fig = plt.figure()
208+
ax = fig.add_subplot(111)
209+
# 绘制样本点
210+
ax.scatter(xcord, ycord, s=20, c='blue', alpha=.5)
211+
plt.title('DataSet')
212+
plt.xlabel('X')
213+
plt.show()
214+
215+
216+
if __name__ == '__main__':
217+
# filename = 'ex00.txt'
218+
# plotDataSet(filename)
219+
myData = loadDataSet('ex2.txt')
220+
myMat = np.mat(myData)
221+
print(createTree(myMat))
222+
# print(feat)
223+
# print(val)
224+

0 commit comments

Comments
 (0)