Skip to content

Commit 9e94039

Browse files
Create train_cart.py
1 parent 0f16720 commit 9e94039

File tree

1 file changed

+175
-0
lines changed

1 file changed

+175
-0
lines changed

Chapter_9 CART/train_cart.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# coding:UTF-8
2+
'''
3+
Date:20161030
4+
@author: zhaozhiyong
5+
'''
6+
import numpy as np
7+
import cPickle as pickle
8+
9+
class node:
    '''A single node of the CART regression tree.

    An internal node describes a split (fea, value) and points to its two
    children; a leaf node only carries a prediction in results.
    '''
    def __init__(self, fea=-1, value=None, results=None, right=None, left=None):
        # Split definition: column index of the splitting feature and the
        # threshold that feature is compared against.
        self.fea, self.value = fea, value
        # Value stored at a leaf node.
        self.results = results
        # Child subtrees (right first, then left, mirroring the signature).
        self.right, self.left = right, left
18+
19+
def load_data(data_file):
    '''Load tab-separated numeric training data from a text file.

    input:  data_file(string): path of the file holding the training data;
                               every non-blank line is one sample made of
                               tab-separated numbers
    output: data(list): one list of floats per sample
    raises: ValueError if a field cannot be converted to float
    '''
    data = []
    # The with-block guarantees the file is closed even when float() raises
    # on a malformed field.
    with open(data_file) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. an empty line at the end of the
                # file) instead of failing on float("").
                continue
            data.append([float(x) for x in line.split("\t")])
    return data
35+
36+
def split_tree(data, fea, value):
    '''Partition the samples in data on feature column fea at threshold value.

    input:  data(list): training samples
            fea(int): column index of the feature to split on
            value(float): threshold of the split
    output: (set_1, set_2)(tuple): samples whose feature is >= value,
                                   followed by all remaining samples
    '''
    # buckets[0] collects the ">= value" side (right subtree),
    # buckets[1] collects everything else (left subtree).
    buckets = ([], [])
    for row in data:
        side = 0 if row[fea] >= value else 1
        buckets[side].append(row)
    return buckets
51+
52+
def leaf(dataSet):
    '''Compute the value stored in a leaf node.

    input:  dataSet(list): training samples, target value in the last column
    output: np.mean(labels)(float): mean of the target values
    '''
    # np.asarray replaces np.mat: the np.mat alias was removed from the
    # NumPy namespace in NumPy 2.0, and a plain array is all that is needed
    # to slice out the last column.
    labels = np.asarray(dataSet)[:, -1]
    return np.mean(labels)
59+
60+
def err_cnt(dataSet):
    '''Split criterion of the regression tree.

    input:  dataSet(list): training samples, target value in the last column
    output: m*s^2(float): variance of the target values multiplied by the
                          number of samples (total squared deviation)
    '''
    # np.asarray replaces np.mat: the np.mat alias was removed from the
    # NumPy namespace in NumPy 2.0.
    labels = np.asarray(dataSet)[:, -1]
    return np.var(labels) * labels.shape[0]
67+
68+
69+
def build_tree(data, min_sample, min_err):
    '''Recursively build a CART regression tree.

    input:  data(list): training samples, target value in the last column
            min_sample(int): a node holding at most this many samples
                             becomes a leaf
            min_err(float): error threshold; the node is split only while
                            the error of its best split is above this value
    output: node: root node of the (sub)tree
    '''
    # Too few samples to split any further: emit a leaf.
    if len(data) <= min_sample:
        return node(results=leaf(data))

    # 1. Initialise with the error of the unsplit node, so that only splits
    #    which strictly reduce the error are accepted.
    best_err = err_cnt(data)
    bestCriteria = None  # (feature index, threshold) of the best split
    bestSets = None  # the two sample sets produced by the best split

    # 2. Search every feature and every distinct value of it for the best split.
    feature_num = len(data[0]) - 1
    for fea in range(feature_num):
        # Distinct values of this feature, used as candidate thresholds.
        feature_values = dict.fromkeys(sample[fea] for sample in data)

        for value in feature_values:
            # 2.1 Try the split.
            set_1, set_2 = split_tree(data, fea, value)
            # Both sides need at least two samples to be admissible.
            if len(set_1) < 2 or len(set_2) < 2:
                continue
            # 2.2 Error after the split.
            now_err = err_cnt(set_1) + err_cnt(set_2)
            # 2.3 Keep the best split seen so far.
            if now_err < best_err:
                best_err = now_err
                bestCriteria = (fea, value)
                bestSets = (set_1, set_2)

    # 3. Decide whether to stop.
    # bestSets is None when no admissible split improved the error (e.g. all
    # feature values are identical); the node must then become a leaf
    # instead of indexing into None.
    if bestSets is None or best_err <= min_err:
        return node(results=leaf(data))

    right = build_tree(bestSets[0], min_sample, min_err)
    left = build_tree(bestSets[1], min_sample, min_err)
    return node(fea=bestCriteria[0], value=bestCriteria[1],
                right=right, left=left)
113+
114+
def predict(sample, tree):
    '''Predict the target value of one sample.

    input:  sample(list): feature values of the sample
            tree: trained CART regression tree (its root node)
    output: results: value stored in the leaf the sample falls into
    '''
    current = tree
    # Walk down iteratively so that the depth of the tree is not limited by
    # the interpreter's recursion limit.
    # "is None" rather than "!= None": an identity test also works when a
    # leaf stores an array-like value, where != compares element-wise.
    while current.results is None:
        if sample[current.fea] >= current.value:
            # Feature value at or above the threshold: right subtree.
            current = current.right
        else:
            # Feature value below the threshold: left subtree.
            current = current.left
    return current.results
134+
135+
def cal_error(data, tree):
    '''Evaluate a trained CART regression tree on labelled samples.

    input:  data(list): samples, target value in the last column
            tree: trained CART regression tree
    output: err/m(float): mean squared error of the predictions
    '''
    sample_cnt = len(data)
    # The number of features is taken from the first sample.
    feature_cnt = len(data[0]) - 1
    squared_sum = 0.0
    for row in data:
        # Copy the feature columns only; the target is left out.
        features = [row[col] for col in range(feature_cnt)]
        residual = row[-1] - predict(features, tree)
        squared_sum += residual * residual
    return squared_sum / sample_cnt
152+
153+
def save_model(regression_tree, result_file):
    '''Save the trained CART regression tree to a local file with pickle.

    input:  regression_tree: regression tree model
            result_file(string): name of the file to write
    '''
    # Pickle data is a byte stream, so the file is opened in binary mode:
    # text mode ('w') raises TypeError on Python 3 and may translate newline
    # bytes on Windows.
    with open(result_file, 'wb') as f:
        pickle.dump(regression_tree, f)
160+
161+
if __name__ == "__main__":
162+
# 1、导入训练数据
163+
print "----------- 1、load data -------------"
164+
data = load_data("sine.txt")
165+
# 2、构建CART树
166+
print "----------- 2、build CART ------------"
167+
regression_tree = build_tree(data, 30, 0.3)
168+
# 3、评估CART树
169+
print "----------- 3、cal err -------------"
170+
err = cal_error(data, regression_tree)
171+
print "\t--------- err : ", err
172+
# 4、保存最终的CART模型
173+
print "----------- 4、save result -----------"
174+
save_model(regression_tree, "regression_tree")
175+

0 commit comments

Comments
 (0)