Skip to content

Commit 0cbc95b

Browse files
authored
Add files via upload
1 parent 1a15943 commit 0cbc95b

File tree

4 files changed

+1253
-0
lines changed

4 files changed

+1253
-0
lines changed

charpter7_decision_tree/CART.ipynb

Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"### CART"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 1,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"import numpy as np\n",
17+
"from sklearn.model_selection import train_test_split\n",
18+
"from sklearn.metrics import accuracy_score, mean_squared_error\n",
19+
"from utils import feature_split, calculate_gini"
20+
]
21+
},
22+
{
23+
"cell_type": "code",
24+
"execution_count": 2,
25+
"metadata": {},
26+
"outputs": [],
27+
"source": [
28+
"### 定义树结点\n",
29+
"class TreeNode():\n",
30+
" def __init__(self, feature_i=None, threshold=None,\n",
31+
" leaf_value=None, left_branch=None, right_branch=None):\n",
32+
" # 特征索引\n",
33+
" self.feature_i = feature_i \n",
34+
" # 特征划分阈值\n",
35+
" self.threshold = threshold \n",
36+
" # 叶子节点取值\n",
37+
" self.leaf_value = leaf_value \n",
38+
" # 左子树\n",
39+
" self.left_branch = left_branch \n",
40+
" # 右子树\n",
41+
" self.right_branch = right_branch "
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": 3,
47+
"metadata": {},
48+
"outputs": [],
49+
"source": [
50+
"### 定义二叉决策树\n",
51+
"class BinaryDecisionTree(object):\n",
52+
" ### 决策树初始参数\n",
53+
" def __init__(self, min_samples_split=2, min_gini_impurity=999,\n",
54+
" max_depth=float(\"inf\"), loss=None):\n",
55+
" # 根结点\n",
56+
" self.root = None \n",
57+
" # 节点最小分裂样本数\n",
58+
" self.min_samples_split = min_samples_split\n",
59+
" # 节点初始化基尼不纯度\n",
60+
" self.mini_gini_impurity = min_gini_impurity\n",
61+
" # 树最大深度\n",
62+
" self.max_depth = max_depth\n",
63+
" # 基尼不纯度计算函数\n",
64+
" self.gini_impurity_calculation = None\n",
65+
" # 叶子节点值预测函数\n",
66+
" self._leaf_value_calculation = None\n",
67+
" # 损失函数\n",
68+
" self.loss = loss\n",
69+
"\n",
70+
" ### 决策树拟合函数\n",
71+
" def fit(self, X, y, loss=None):\n",
72+
" # 递归构建决策树\n",
73+
" self.root = self._build_tree(X, y)\n",
74+
" self.loss=None\n",
75+
"\n",
76+
" ### 决策树构建函数\n",
77+
" def _build_tree(self, X, y, current_depth=0):\n",
78+
" # 初始化最小基尼不纯度\n",
79+
" init_gini_impurity = 999\n",
80+
" # 初始化最佳特征索引和阈值\n",
81+
" best_criteria = None \n",
82+
" # 初始化数据子集\n",
83+
" best_sets = None \n",
84+
"\n",
85+
" # 合并输入和标签\n",
86+
" Xy = np.concatenate((X, y), axis=1)\n",
87+
" # 获取样本数和特征数\n",
88+
" n_samples, n_features = X.shape\n",
89+
" # 设定决策树构建条件\n",
90+
" # 训练样本数量大于节点最小分裂样本数且当前树深度小于最大深度\n",
91+
" if n_samples >= self.min_samples_split and current_depth <= self.max_depth:\n",
92+
" # 遍历计算每个特征的基尼不纯度\n",
93+
" for feature_i in range(n_features):\n",
94+
" # 获取第i特征的所有取值\n",
95+
" feature_values = np.expand_dims(X[:, feature_i], axis=1)\n",
96+
" # 获取第i个特征的唯一取值\n",
97+
" unique_values = np.unique(feature_values)\n",
98+
"\n",
99+
" # 遍历取值并寻找最佳特征分裂阈值\n",
100+
" for threshold in unique_values:\n",
101+
" # 特征节点二叉分裂\n",
102+
" Xy1, Xy2 = feature_split(Xy, feature_i, threshold)\n",
103+
" # 如果分裂后的子集大小都不为0\n",
104+
" if len(Xy1) > 0 and len(Xy2) > 0:\n",
105+
" # 获取两个子集的标签值\n",
106+
" y1 = Xy1[:, n_features:]\n",
107+
" y2 = Xy2[:, n_features:]\n",
108+
"\n",
109+
" # 计算基尼不纯度\n",
110+
" impurity = self.impurity_calculation(y, y1, y2)\n",
111+
"\n",
112+
" # 获取最小基尼不纯度\n",
113+
" # 最佳特征索引和分裂阈值\n",
114+
" if impurity < init_gini_impurity:\n",
115+
" init_gini_impurity = impurity\n",
116+
" best_criteria = {\"feature_i\": feature_i, \"threshold\": threshold}\n",
117+
" best_sets = {\n",
118+
" \"leftX\": Xy1[:, :n_features], \n",
119+
" \"lefty\": Xy1[:, n_features:], \n",
120+
" \"rightX\": Xy2[:, :n_features], \n",
121+
" \"righty\": Xy2[:, n_features:] \n",
122+
" }\n",
123+
" \n",
124+
" # 如果计算的最小不纯度小于设定的最小不纯度\n",
125+
" if init_gini_impurity < self.mini_gini_impurity:\n",
126+
" # 分别构建左右子树\n",
127+
" left_branch = self._build_tree(best_sets[\"leftX\"], best_sets[\"lefty\"], current_depth + 1)\n",
128+
" right_branch = self._build_tree(best_sets[\"rightX\"], best_sets[\"righty\"], current_depth + 1)\n",
129+
" return TreeNode(feature_i=best_criteria[\"feature_i\"], threshold=best_criteria[\n",
130+
" \"threshold\"], left_branch=left_branch, right_branch=right_branch)\n",
131+
"\n",
132+
" # 计算叶子计算取值\n",
133+
" leaf_value = self._leaf_value_calculation(y)\n",
134+
"\n",
135+
" return TreeNode(leaf_value=leaf_value)\n",
136+
"\n",
137+
" ### 定义二叉树值预测函数\n",
138+
" def predict_value(self, x, tree=None):\n",
139+
" if tree is None:\n",
140+
" tree = self.root\n",
141+
"\n",
142+
" # 如果叶子节点已有值,则直接返回已有值\n",
143+
" if tree.leaf_value is not None:\n",
144+
" return tree.leaf_value\n",
145+
"\n",
146+
" # 选择特征并获取特征值\n",
147+
" feature_value = x[tree.feature_i]\n",
148+
"\n",
149+
" # 判断落入左子树还是右子树\n",
150+
" branch = tree.right_branch\n",
151+
" if isinstance(feature_value, int) or isinstance(feature_value, float):\n",
152+
" if feature_value >= tree.threshold:\n",
153+
" branch = tree.left_branch\n",
154+
" elif feature_value == tree.threshold:\n",
155+
" branch = tree.right_branch\n",
156+
"\n",
157+
" # 测试子集\n",
158+
" return self.predict_value(x, branch)\n",
159+
"\n",
160+
" ### 数据集预测函数\n",
161+
" def predict(self, X):\n",
162+
" y_pred = [self.predict_value(sample) for sample in X]\n",
163+
" return y_pred"
164+
]
165+
},
166+
{
167+
"cell_type": "code",
168+
"execution_count": 4,
169+
"metadata": {},
170+
"outputs": [],
171+
"source": [
172+
"### CART回归树\n",
173+
"class RegressionTree(BinaryDecisionTree):\n",
174+
" def _calculate_variance_reduction(self, y, y1, y2):\n",
175+
" var_tot = np.var(y, axis=0)\n",
176+
" var_y1 = np.var(y1, axis=0)\n",
177+
" var_y2 = np.var(y2, axis=0)\n",
178+
" frac_1 = len(y1) / len(y)\n",
179+
" frac_2 = len(y2) / len(y)\n",
180+
" # 计算方差减少量\n",
181+
" variance_reduction = var_tot - (frac_1 * var_y1 + frac_2 * var_y2)\n",
182+
" \n",
183+
" return sum(variance_reduction)\n",
184+
"\n",
185+
" # 节点值取平均\n",
186+
" def _mean_of_y(self, y):\n",
187+
" value = np.mean(y, axis=0)\n",
188+
" return value if len(value) > 1 else value[0]\n",
189+
"\n",
190+
" def fit(self, X, y):\n",
191+
" self.impurity_calculation = self._calculate_variance_reduction\n",
192+
" self._leaf_value_calculation = self._mean_of_y\n",
193+
" super(RegressionTree, self).fit(X, y)"
194+
]
195+
},
196+
{
197+
"cell_type": "code",
198+
"execution_count": 5,
199+
"metadata": {},
200+
"outputs": [],
201+
"source": [
202+
"### CART决策树\n",
203+
"class ClassificationTree(BinaryDecisionTree):\n",
204+
" ### 定义基尼不纯度计算过程\n",
205+
" def _calculate_gini_impurity(self, y, y1, y2):\n",
206+
" p = len(y1) / len(y)\n",
207+
" gini = calculate_gini(y)\n",
208+
" gini_impurity = p * calculate_gini(y1) + (1-p) * calculate_gini(y2)\n",
209+
" return gini_impurity\n",
210+
" \n",
211+
" ### 多数投票\n",
212+
" def _majority_vote(self, y):\n",
213+
" most_common = None\n",
214+
" max_count = 0\n",
215+
" for label in np.unique(y):\n",
216+
" # 统计多数\n",
217+
" count = len(y[y == label])\n",
218+
" if count > max_count:\n",
219+
" most_common = label\n",
220+
" max_count = count\n",
221+
" return most_common\n",
222+
" \n",
223+
" # 分类树拟合\n",
224+
" def fit(self, X, y):\n",
225+
" self.impurity_calculation = self._calculate_gini_impurity\n",
226+
" self._leaf_value_calculation = self._majority_vote\n",
227+
" super(ClassificationTree, self).fit(X, y)"
228+
]
229+
},
230+
{
231+
"cell_type": "code",
232+
"execution_count": 6,
233+
"metadata": {},
234+
"outputs": [
235+
{
236+
"name": "stdout",
237+
"output_type": "stream",
238+
"text": [
239+
"0.9777777777777777\n"
240+
]
241+
}
242+
],
243+
"source": [
244+
"from sklearn import datasets\n",
245+
"data = datasets.load_iris()\n",
246+
"X, y = data.data, data.target\n",
247+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n",
248+
"clf = ClassificationTree()\n",
249+
"clf.fit(X_train, y_train)\n",
250+
"y_pred = clf.predict(X_test)\n",
251+
"\n",
252+
"print(accuracy_score(y_test, y_pred))"
253+
]
254+
},
255+
{
256+
"cell_type": "code",
257+
"execution_count": 7,
258+
"metadata": {},
259+
"outputs": [
260+
{
261+
"name": "stdout",
262+
"output_type": "stream",
263+
"text": [
264+
"1.0\n"
265+
]
266+
}
267+
],
268+
"source": [
269+
"from sklearn.tree import DecisionTreeClassifier\n",
270+
"clf = DecisionTreeClassifier()\n",
271+
"clf.fit(X_train, y_train)\n",
272+
"y_pred = clf.predict(X_test)\n",
273+
"\n",
274+
"print(accuracy_score(y_test, y_pred))"
275+
]
276+
},
277+
{
278+
"cell_type": "code",
279+
"execution_count": 8,
280+
"metadata": {},
281+
"outputs": [
282+
{
283+
"name": "stdout",
284+
"output_type": "stream",
285+
"text": [
286+
"Mean Squared Error: 134.4803289473684\n"
287+
]
288+
}
289+
],
290+
"source": [
291+
"from sklearn.datasets import load_boston\n",
292+
"X, y = load_boston(return_X_y=True)\n",
293+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n",
294+
"model = RegressionTree()\n",
295+
"model.fit(X_train, y_train)\n",
296+
"y_pred = model.predict(X_test)\n",
297+
"mse = mean_squared_error(y_test, y_pred)\n",
298+
"\n",
299+
"print(\"Mean Squared Error:\", mse)"
300+
]
301+
},
302+
{
303+
"cell_type": "code",
304+
"execution_count": 9,
305+
"metadata": {},
306+
"outputs": [
307+
{
308+
"name": "stdout",
309+
"output_type": "stream",
310+
"text": [
311+
"Mean Squared Error: 28.75368421052632\n"
312+
]
313+
}
314+
],
315+
"source": [
316+
"from sklearn.tree import DecisionTreeRegressor\n",
317+
"reg = DecisionTreeRegressor()\n",
318+
"reg.fit(X_train, y_train)\n",
319+
"y_pred = reg.predict(X_test)\n",
320+
"mse = mean_squared_error(y_test, y_pred)\n",
321+
"\n",
322+
"print(\"Mean Squared Error:\", mse)"
323+
]
324+
},
325+
{
326+
"cell_type": "code",
327+
"execution_count": null,
328+
"metadata": {},
329+
"outputs": [],
330+
"source": []
331+
}
332+
],
333+
"metadata": {
334+
"kernelspec": {
335+
"display_name": "Python 3",
336+
"language": "python",
337+
"name": "python3"
338+
},
339+
"language_info": {
340+
"codemirror_mode": {
341+
"name": "ipython",
342+
"version": 3
343+
},
344+
"file_extension": ".py",
345+
"mimetype": "text/x-python",
346+
"name": "python",
347+
"nbconvert_exporter": "python",
348+
"pygments_lexer": "ipython3",
349+
"version": "3.7.3"
350+
},
351+
"toc": {
352+
"base_numbering": 1,
353+
"nav_menu": {},
354+
"number_sections": true,
355+
"sideBar": true,
356+
"skip_h1_title": false,
357+
"title_cell": "Table of Contents",
358+
"title_sidebar": "Contents",
359+
"toc_cell": false,
360+
"toc_position": {},
361+
"toc_section_display": true,
362+
"toc_window_display": false
363+
}
364+
},
365+
"nbformat": 4,
366+
"nbformat_minor": 2
367+
}

0 commit comments

Comments
 (0)