Skip to content

Commit 4e6145f

Browse files
authored
Add files via upload
1 parent 908e3d6 commit 4e6145f

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

DecisionTree_Project2/DecisionTree.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Tue Jul 17 21:52:06 2018
4+
数据的Labels依次是age、prescript、astigmatic、tearRate、class
5+
年龄、症状、是否散光、眼泪数量、分类标签
6+
@author: wzy
7+
"""
8+
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
9+
import pandas as pd
10+
import numpy as np
11+
import pydotplus
12+
from sklearn.externals.six import StringIO
13+
from sklearn import tree
14+
15+
if __name__ == '__main__':
16+
# 加载文件
17+
with open('lenses.txt') as fr:
18+
# 处理文件,去掉每行两头的空白符,以\t分隔每个数据
19+
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
20+
# 提取每组数据的类别,保存在列表里
21+
lenses_targt = []
22+
for each in lenses:
23+
# 存储Label到lenses_targt中
24+
lenses_targt.append([each[-1]])
25+
# 特征标签
26+
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
27+
# 保存lenses数据的临时列表
28+
lenses_list = []
29+
# 保存lenses数据的字典,用于生成pandas
30+
lenses_dict = {}
31+
# 提取信息,生成字典
32+
for each_label in lensesLabels:
33+
for each in lenses:
34+
# index方法用于从列表中找出某个值第一个匹配项的索引位置
35+
lenses_list.append(each[lensesLabels.index(each_label)])
36+
lenses_dict[each_label] = lenses_list
37+
lenses_list = []
38+
# 打印字典信息
39+
# print(lenses_dict)
40+
# 生成pandas.DataFrame用于对象的创建
41+
lenses_pd = pd.DataFrame(lenses_dict)
42+
# 打印数据
43+
# print(lenses_pd)
44+
# 创建LabelEncoder对象
45+
le = LabelEncoder()
46+
# 为每一列序列化
47+
for col in lenses_pd.columns:
48+
# fit_transform()干了两件事:fit找到数据转换规则,并将数据标准化
49+
# transform()直接把转换规则拿来用,需要先进行fit
50+
# transform函数是一定可以替换为fit_transform函数的,fit_transform函数不能替换为transform函数
51+
lenses_pd[col] = le.fit_transform(lenses_pd[col])
52+
# 打印归一化的结果
53+
# print(lenses_pd)
54+
# 创建DecisionTreeClassifier()类
55+
clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=4)
56+
# 使用数据构造决策树
57+
# fit(X,y):Build a decision tree classifier from the training set(X,y)
58+
# 所有的sklearn的API必须先fit
59+
clf = clf.fit(lenses_pd.values.tolist(), lenses_targt)
60+
dot_data = StringIO()
61+
# 绘制决策树
62+
tree.export_graphviz(clf, out_file=dot_data, feature_names=lenses_pd.keys(),
63+
class_names=clf.classes_, filled=True, rounded=True,
64+
special_characters=True)
65+
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
66+
# 保存绘制好的决策树,以PDF的形式存储。
67+
graph.write_pdf("tree.pdf")
68+
#预测
69+
print(clf.predict([[1,1,1,0]]))
70+

DecisionTree_Project2/lenses.txt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
young myope no reduced no lenses
2+
young myope no normal soft
3+
young myope yes reduced no lenses
4+
young myope yes normal hard
5+
young hyper no reduced no lenses
6+
young hyper no normal soft
7+
young hyper yes reduced no lenses
8+
young hyper yes normal hard
9+
pre myope no reduced no lenses
10+
pre myope no normal soft
11+
pre myope yes reduced no lenses
12+
pre myope yes normal hard
13+
pre hyper no reduced no lenses
14+
pre hyper no normal soft
15+
pre hyper yes reduced no lenses
16+
pre hyper yes normal no lenses
17+
presbyopic myope no reduced no lenses
18+
presbyopic myope no normal no lenses
19+
presbyopic myope yes reduced no lenses
20+
presbyopic myope yes normal hard
21+
presbyopic hyper no reduced no lenses
22+
presbyopic hyper no normal soft
23+
presbyopic hyper yes reduced no lenses
24+
presbyopic hyper yes normal no lenses

DecisionTree_Project2/tree.pdf

21.5 KB
Binary file not shown.

0 commit comments

Comments
 (0)