1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue Jul 17 21:52:06 2018
4
+ 数据的Labels依次是age、prescript、astigmatic、tearRate、class
5
+ 年龄、症状、是否散光、眼泪数量、分类标签
6
+ @author: wzy
7
+ """
8
+ from sklearn .preprocessing import LabelEncoder , OneHotEncoder
9
+ import pandas as pd
10
+ import numpy as np
11
+ import pydotplus
12
+ from sklearn .externals .six import StringIO
13
+ from sklearn import tree
14
+
15
+ if __name__ == '__main__' :
16
+ # 加载文件
17
+ with open ('lenses.txt' ) as fr :
18
+ # 处理文件,去掉每行两头的空白符,以\t分隔每个数据
19
+ lenses = [inst .strip ().split ('\t ' ) for inst in fr .readlines ()]
20
+ # 提取每组数据的类别,保存在列表里
21
+ lenses_targt = []
22
+ for each in lenses :
23
+ # 存储Label到lenses_targt中
24
+ lenses_targt .append ([each [- 1 ]])
25
+ # 特征标签
26
+ lensesLabels = ['age' , 'prescript' , 'astigmatic' , 'tearRate' ]
27
+ # 保存lenses数据的临时列表
28
+ lenses_list = []
29
+ # 保存lenses数据的字典,用于生成pandas
30
+ lenses_dict = {}
31
+ # 提取信息,生成字典
32
+ for each_label in lensesLabels :
33
+ for each in lenses :
34
+ # index方法用于从列表中找出某个值第一个匹配项的索引位置
35
+ lenses_list .append (each [lensesLabels .index (each_label )])
36
+ lenses_dict [each_label ] = lenses_list
37
+ lenses_list = []
38
+ # 打印字典信息
39
+ # print(lenses_dict)
40
+ # 生成pandas.DataFrame用于对象的创建
41
+ lenses_pd = pd .DataFrame (lenses_dict )
42
+ # 打印数据
43
+ # print(lenses_pd)
44
+ # 创建LabelEncoder对象
45
+ le = LabelEncoder ()
46
+ # 为每一列序列化
47
+ for col in lenses_pd .columns :
48
+ # fit_transform()干了两件事:fit找到数据转换规则,并将数据标准化
49
+ # transform()直接把转换规则拿来用,需要先进行fit
50
+ # transform函数是一定可以替换为fit_transform函数的,fit_transform函数不能替换为transform函数
51
+ lenses_pd [col ] = le .fit_transform (lenses_pd [col ])
52
+ # 打印归一化的结果
53
+ # print(lenses_pd)
54
+ # 创建DecisionTreeClassifier()类
55
+ clf = tree .DecisionTreeClassifier (criterion = 'entropy' , max_depth = 4 )
56
+ # 使用数据构造决策树
57
+ # fit(X,y):Build a decision tree classifier from the training set(X,y)
58
+ # 所有的sklearn的API必须先fit
59
+ clf = clf .fit (lenses_pd .values .tolist (), lenses_targt )
60
+ dot_data = StringIO ()
61
+ # 绘制决策树
62
+ tree .export_graphviz (clf , out_file = dot_data , feature_names = lenses_pd .keys (),
63
+ class_names = clf .classes_ , filled = True , rounded = True ,
64
+ special_characters = True )
65
+ graph = pydotplus .graph_from_dot_data (dot_data .getvalue ())
66
+ # 保存绘制好的决策树,以PDF的形式存储。
67
+ graph .write_pdf ("tree.pdf" )
68
+ #预测
69
+ print (clf .predict ([[1 ,1 ,1 ,0 ]]))
70
+
0 commit comments