Skip to content

Commit 3710796

Browse files
authored
Update CART.py
1 parent f6c4340 commit 3710796

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

CART_Project3/CART.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def prune(tree, testData):
282282
if not isTree(tree['left']) and not isTree(tree['right']):
283283
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
284284
# 计算没有合并的误差
285-
errorNoMerge = np.sum(np.power(lSet[:, -1] - tree['left'], 2)) + np.sum(np.power(rSet[:, 1] - tree['right'], 2))
285+
errorNoMerge = np.sum(np.power(lSet[:, -1] - tree['left'], 2)) + np.sum(np.power(rSet[:, -1] - tree['right'], 2))
286286
# 计算合并的均值
287287
treeMean = (tree['left'] + tree['right']) / 2.0
288288
# 计算合并的误差
@@ -306,4 +306,4 @@ def prune(tree, testData):
306306
test_Data = loadDataSet(test_filename)
307307
test_Mat = np.mat(test_Data)
308308
print("\n剪枝后:", prune(tree, test_Mat))
309-
309+

0 commit comments

Comments
 (0)