Skip to content

Commit 86d7508

Browse files
authored
Add files via upload
1 parent 9a93430 commit 86d7508

File tree

2 files changed

+617
-0
lines changed

2 files changed

+617
-0
lines changed

charpter15_random_forest/cart.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import numpy as np
2+
from utils import feature_split, calculate_gini
3+
4+
### 定义树结点
5+
class TreeNode():
6+
def __init__(self, feature_i=None, threshold=None,
7+
leaf_value=None, left_branch=None, right_branch=None):
8+
# 特征索引
9+
self.feature_i = feature_i
10+
# 特征划分阈值
11+
self.threshold = threshold
12+
# 叶子节点取值
13+
self.leaf_value = leaf_value
14+
# 左子树
15+
self.left_branch = left_branch
16+
# 右子树
17+
self.right_branch = right_branch
18+
19+
20+
### 定义二叉决策树
21+
class BinaryDecisionTree(object):
22+
### 决策树初始参数
23+
def __init__(self, min_samples_split=2, min_gini_impurity=999,
24+
max_depth=float("inf"), loss=None):
25+
# 根结点
26+
self.root = None
27+
# 节点最小分裂样本数
28+
self.min_samples_split = min_samples_split
29+
# 节点初始化基尼不纯度
30+
self.min_gini_impurity = min_gini_impurity
31+
# 树最大深度
32+
self.max_depth = max_depth
33+
# 基尼不纯度计算函数
34+
self.gini_impurity_calculation = None
35+
# 叶子节点值预测函数
36+
self._leaf_value_calculation = None
37+
# 损失函数
38+
self.loss = loss
39+
40+
### 决策树拟合函数
41+
def fit(self, X, y, loss=None):
42+
# 递归构建决策树
43+
self.root = self._build_tree(X, y)
44+
self.loss=None
45+
46+
### 决策树构建函数
47+
def _build_tree(self, X, y, current_depth=0):
48+
# 初始化最小基尼不纯度
49+
init_gini_impurity = 999
50+
# 初始化最佳特征索引和阈值
51+
best_criteria = None
52+
# 初始化数据子集
53+
best_sets = None
54+
55+
if len(np.shape(y)) == 1:
56+
y = np.expand_dims(y, axis=1)
57+
58+
# 合并输入和标签
59+
Xy = np.concatenate((X, y), axis=1)
60+
# 获取样本数和特征数
61+
n_samples, n_features = X.shape
62+
# 设定决策树构建条件
63+
# 训练样本数量大于节点最小分裂样本数且当前树深度小于最大深度
64+
if n_samples >= self.min_samples_split and current_depth <= self.max_depth:
65+
# 遍历计算每个特征的基尼不纯度
66+
for feature_i in range(n_features):
67+
# 获取第i特征的所有取值
68+
feature_values = np.expand_dims(X[:, feature_i], axis=1)
69+
# 获取第i个特征的唯一取值
70+
unique_values = np.unique(feature_values)
71+
72+
# 遍历取值并寻找最佳特征分裂阈值
73+
for threshold in unique_values:
74+
# 特征节点二叉分裂
75+
Xy1, Xy2 = feature_split(Xy, feature_i, threshold)
76+
# 如果分裂后的子集大小都不为0
77+
if len(Xy1) > 0 and len(Xy2) > 0:
78+
# 获取两个子集的标签值
79+
y1 = Xy1[:, n_features:]
80+
y2 = Xy2[:, n_features:]
81+
82+
# 计算基尼不纯度
83+
impurity = self.impurity_calculation(y, y1, y2)
84+
85+
# 获取最小基尼不纯度
86+
# 最佳特征索引和分裂阈值
87+
if impurity < init_gini_impurity:
88+
init_gini_impurity = impurity
89+
best_criteria = {"feature_i": feature_i, "threshold": threshold}
90+
best_sets = {
91+
"leftX": Xy1[:, :n_features],
92+
"lefty": Xy1[:, n_features:],
93+
"rightX": Xy2[:, :n_features],
94+
"righty": Xy2[:, n_features:]
95+
}
96+
97+
# 如果计算的最小不纯度小于设定的最小不纯度
98+
if init_gini_impurity < self.min_gini_impurity:
99+
# 分别构建左右子树
100+
left_branch = self._build_tree(best_sets["leftX"], best_sets["lefty"], current_depth + 1)
101+
right_branch = self._build_tree(best_sets["rightX"], best_sets["righty"], current_depth + 1)
102+
return TreeNode(feature_i=best_criteria["feature_i"], threshold=best_criteria["threshold"], left_branch=left_branch, right_branch=right_branch)
103+
104+
# 计算叶子计算取值
105+
leaf_value = self._leaf_value_calculation(y)
106+
return TreeNode(leaf_value=leaf_value)
107+
108+
### 定义二叉树值预测函数
109+
def predict_value(self, x, tree=None):
110+
if tree is None:
111+
tree = self.root
112+
# 如果叶子节点已有值,则直接返回已有值
113+
if tree.leaf_value is not None:
114+
return tree.leaf_value
115+
# 选择特征并获取特征值
116+
feature_value = x[tree.feature_i]
117+
# 判断落入左子树还是右子树
118+
branch = tree.right_branch
119+
if isinstance(feature_value, int) or isinstance(feature_value, float):
120+
if feature_value >= tree.threshold:
121+
branch = tree.left_branch
122+
elif feature_value == tree.threshold:
123+
branch = tree.right_branch
124+
# 测试子集
125+
return self.predict_value(x, branch)
126+
127+
### 数据集预测函数
128+
def predict(self, X):
129+
y_pred = [self.predict_value(sample) for sample in X]
130+
return y_pred
131+
132+
133+
134+
class ClassificationTree(BinaryDecisionTree):
135+
### 定义基尼不纯度计算过程
136+
def _calculate_gini_impurity(self, y, y1, y2):
137+
p = len(y1) / len(y)
138+
gini = calculate_gini(y)
139+
gini_impurity = p * calculate_gini(y1) + (1-p) * calculate_gini(y2)
140+
return gini_impurity
141+
142+
### 多数投票
143+
def _majority_vote(self, y):
144+
most_common = None
145+
max_count = 0
146+
for label in np.unique(y):
147+
# 统计多数
148+
count = len(y[y == label])
149+
if count > max_count:
150+
most_common = label
151+
max_count = count
152+
return most_common
153+
154+
# 分类树拟合
155+
def fit(self, X, y):
156+
self.impurity_calculation = self._calculate_gini_impurity
157+
self._leaf_value_calculation = self._majority_vote
158+
super(ClassificationTree, self).fit(X, y)
159+
160+
161+
### CART回归树
162+
class RegressionTree(BinaryDecisionTree):
163+
# 计算方差减少量
164+
def _calculate_variance_reduction(self, y, y1, y2):
165+
var_tot = np.var(y, axis=0)
166+
var_y1 = np.var(y1, axis=0)
167+
var_y2 = np.var(y2, axis=0)
168+
frac_1 = len(y1) / len(y)
169+
frac_2 = len(y2) / len(y)
170+
# 计算方差减少量
171+
variance_reduction = var_tot - (frac_1 * var_y1 + frac_2 * var_y2)
172+
return sum(variance_reduction)
173+
174+
# 节点值取平均
175+
def _mean_of_y(self, y):
176+
value = np.mean(y, axis=0)
177+
return value if len(value) > 1 else value[0]
178+
179+
# 回归树拟合
180+
def fit(self, X, y):
181+
self.impurity_calculation = self._calculate_variance_reduction
182+
self._leaf_value_calculation = self._mean_of_y
183+
super(RegressionTree, self).fit(X, y)

0 commit comments

Comments
 (0)