|
| 1 | +import numpy as np |
| 2 | +from utils import feature_split, calculate_gini |
| 3 | + |
| 4 | +### 定义树结点 |
class TreeNode():
    """A single node of a binary decision tree.

    The constructor stores its five arguments verbatim. The tree code in this
    file treats a node whose ``leaf_value`` is not ``None`` as a leaf, and
    otherwise routes samples by comparing feature ``feature_i`` with
    ``threshold`` to pick ``left_branch`` or ``right_branch``.
    """

    def __init__(self, feature_i=None, threshold=None,
                 leaf_value=None, left_branch=None, right_branch=None):
        # Split description: which feature column is tested, and against what.
        self.feature_i, self.threshold = feature_i, threshold
        # Prediction stored at the node (None when not set).
        self.leaf_value = leaf_value
        # Child subtrees.
        self.left_branch, self.right_branch = left_branch, right_branch
| 18 | + |
| 19 | + |
| 20 | +### 定义二叉决策树 |
class BinaryDecisionTree(object):
    """Base class for a binary decision tree grown by exhaustive split search.

    Subclasses are expected to assign two callables before the tree is built:

    * ``impurity_calculation(y, y1, y2)`` -- score of splitting the labels
      ``y`` into ``y1`` / ``y2``; the candidate with the smallest score wins.
    * ``_leaf_value_calculation(y)`` -- prediction stored in a leaf holding
      the labels ``y``.
    """

    def __init__(self, min_samples_split=2, min_gini_impurity=999,
                 max_depth=float("inf"), loss=None):
        """Store the hyper-parameters; no tree is built until ``fit``.

        Args:
            min_samples_split: A node is only considered for splitting when
                it holds at least this many samples.
            min_gini_impurity: A node is split only if the best candidate
                score is strictly below this value.
            max_depth: Nodes whose depth exceeds this value are never split
                (the root has depth 0).
            loss: Stored on the instance; not used by this class itself.
        """
        # Root node of the fitted tree.
        self.root = None
        self.min_samples_split = min_samples_split
        self.min_gini_impurity = min_gini_impurity
        self.max_depth = max_depth
        # Split-scoring callable. This is the attribute that _build_tree
        # reads and that the subclasses assign in their fit().
        self.impurity_calculation = None
        # Old, never-read name of the slot above; kept only so existing
        # attribute access keeps working.
        self.gini_impurity_calculation = None
        # Leaf-prediction callable, assigned by subclasses.
        self._leaf_value_calculation = None
        self.loss = loss

    def fit(self, X, y, loss=None):
        """Build the tree from samples ``X`` and labels ``y``.

        The ``loss`` argument is accepted but not used, and ``self.loss`` is
        reset to ``None`` once the tree has been built.
        """
        self.root = self._build_tree(X, y)
        self.loss = None

    def _build_tree(self, X, y, current_depth=0):
        """Recursively grow the subtree for ``X`` / ``y`` and return its root.

        Every unique value of every feature column is tried as a threshold;
        the candidate with the lowest ``impurity_calculation`` score whose two
        subsets are both non-empty is kept.
        """
        # Best (lowest) score seen so far; 999 acts as the "nothing found"
        # sentinel, so candidates scoring 999 or more are never selected.
        best_impurity = 999
        # Feature index / threshold of the best split.
        best_criteria = None
        # Data subsets produced by the best split.
        best_sets = None

        # Labels are handled as a column so they can be concatenated with X.
        if len(np.shape(y)) == 1:
            y = np.expand_dims(y, axis=1)

        # Keep samples and labels side by side so one split partitions both.
        Xy = np.concatenate((X, y), axis=1)
        n_samples, n_features = X.shape

        # Only search for a split when the node is large enough and not
        # deeper than max_depth.
        if n_samples >= self.min_samples_split and current_depth <= self.max_depth:
            for feature_i in range(n_features):
                # Candidate thresholds are the distinct values of the column.
                unique_values = np.unique(X[:, feature_i])

                for threshold in unique_values:
                    Xy1, Xy2 = feature_split(Xy, feature_i, threshold)
                    # A split that leaves one side empty is useless.
                    if len(Xy1) > 0 and len(Xy2) > 0:
                        # Label columns are everything after the features.
                        y1 = Xy1[:, n_features:]
                        y2 = Xy2[:, n_features:]

                        impurity = self.impurity_calculation(y, y1, y2)

                        if impurity < best_impurity:
                            best_impurity = impurity
                            best_criteria = {"feature_i": feature_i, "threshold": threshold}
                            best_sets = {
                                "leftX": Xy1[:, :n_features],
                                "lefty": y1,
                                "rightX": Xy2[:, :n_features],
                                "righty": y2,
                            }

        # Split only if a candidate was actually found (best_sets would be
        # None otherwise, e.g. when min_gini_impurity is set above 999) and
        # its score beats the configured limit.
        if best_sets is not None and best_impurity < self.min_gini_impurity:
            left_branch = self._build_tree(best_sets["leftX"], best_sets["lefty"], current_depth + 1)
            right_branch = self._build_tree(best_sets["rightX"], best_sets["righty"], current_depth + 1)
            return TreeNode(feature_i=best_criteria["feature_i"],
                            threshold=best_criteria["threshold"],
                            left_branch=left_branch, right_branch=right_branch)

        # No acceptable split: this node becomes a leaf.
        leaf_value = self._leaf_value_calculation(y)
        return TreeNode(leaf_value=leaf_value)

    def predict_value(self, x, tree=None):
        """Return the prediction for one sample ``x``.

        Walks down from ``tree`` (the fitted root when ``tree`` is ``None``)
        until a node with a non-None ``leaf_value`` is reached.
        """
        if tree is None:
            tree = self.root
        # A node carrying a value is a leaf.
        if tree.leaf_value is not None:
            return tree.leaf_value

        feature_value = x[tree.feature_i]
        # Default to the right subtree; go left when the split test holds.
        branch = tree.right_branch
        if isinstance(feature_value, int) or isinstance(feature_value, float):
            # Numeric feature: left subtree holds values >= threshold.
            if feature_value >= tree.threshold:
                branch = tree.left_branch
        elif feature_value == tree.threshold:
            # Non-numeric feature: a value equal to the threshold goes left,
            # mirroring the numeric case where left is the "test holds" side.
            # NOTE(review): assumes feature_split puts equality matches in
            # its first subset -- confirm against utils.feature_split.
            branch = tree.left_branch
        return self.predict_value(x, branch)

    def predict(self, X):
        """Return a list with one prediction per sample in ``X``."""
        return [self.predict_value(sample) for sample in X]
| 131 | + |
| 132 | + |
| 133 | + |
class ClassificationTree(BinaryDecisionTree):
    """Classification tree: splits scored by weighted Gini, leaves by majority vote."""

    def _calculate_gini_impurity(self, y, y1, y2):
        """Return the size-weighted Gini impurity of the two label subsets.

        Args:
            y: Labels of the node being split (only its length is used).
            y1: Labels falling into the first subset.
            y2: Labels falling into the second subset.
        """
        # Share of the node's samples that land in the first subset.
        p = len(y1) / len(y)
        return p * calculate_gini(y1) + (1 - p) * calculate_gini(y2)

    def _majority_vote(self, y):
        """Return the most frequent label in ``y`` (``None`` if ``y`` is empty).

        Ties go to the smallest label, because ``np.unique`` returns labels
        sorted and ``np.argmax`` picks the first maximum.
        """
        labels, counts = np.unique(y, return_counts=True)
        if labels.size == 0:
            return None
        return labels[np.argmax(counts)]

    def fit(self, X, y):
        """Install the Gini / majority-vote callables and build the tree."""
        self.impurity_calculation = self._calculate_gini_impurity
        self._leaf_value_calculation = self._majority_vote
        super(ClassificationTree, self).fit(X, y)
| 159 | + |
| 160 | + |
| 161 | +### CART回归树 |
class RegressionTree(BinaryDecisionTree):
    """CART regression tree: splits scored by variance, leaves by label mean."""

    def _calculate_variance_reduction(self, y, y1, y2):
        """Return how much the split lowers the label variance.

        Computed as the variance of ``y`` minus the size-weighted variances of
        ``y1`` and ``y2``, taken per label column and summed over columns.
        A larger value means a better split.
        """
        var_tot = np.var(y, axis=0)
        var_y1 = np.var(y1, axis=0)
        var_y2 = np.var(y2, axis=0)
        frac_1 = len(y1) / len(y)
        frac_2 = len(y2) / len(y)
        variance_reduction = var_tot - (frac_1 * var_y1 + frac_2 * var_y2)
        return sum(variance_reduction)

    def _calculate_split_impurity(self, y, y1, y2):
        """Return the split score in the lower-is-better form the base class expects.

        The tree builder keeps the candidate with the smallest score, so the
        variance reduction (larger is better) is negated. Using the reduction
        directly would make the builder choose the split that reduces the
        variance the least. The negated value is never positive, so it also
        stays below the builder's default score limit whatever the scale of
        the targets.
        """
        return -self._calculate_variance_reduction(y, y1, y2)

    def _mean_of_y(self, y):
        """Return the column-wise mean of ``y``; a scalar when there is one column."""
        value = np.mean(y, axis=0)
        return value if len(value) > 1 else value[0]

    def fit(self, X, y):
        """Install the variance / mean callables and build the tree."""
        self.impurity_calculation = self._calculate_split_impurity
        self._leaf_value_calculation = self._mean_of_y
        super(RegressionTree, self).fit(X, y)
0 commit comments