1+ #!/usr/bin/env python 
2+ # -*- coding: utf-8 -*- 
3+ # @Time    : 2019/4/3 17:12 
4+ # @Author  : louwill 
5+ # @File    : Hard_Margin_Svm.py 
6+ 7+ 
import numpy as np
import matplotlib.pyplot as plt
class Hard_Margin_SVM:
    """Hard-margin linear SVM fitted by a coarse-to-fine grid search over (w, b).

    NOTE(review): the search steps w as a single scalar applied to both
    components (plus sign flips), so it only explores weight vectors with
    equal-magnitude components — adequate for the 2-D toy data this script
    ships with, not a general QP solver.
    """

    def __init__(self, visualization=True):
        # Whether to draw the data and hyperplanes with matplotlib.
        self.visualization = visualization
        # Marker colors keyed by class label: +1 -> red, -1 -> green.
        self.colors = {1: 'r', -1: 'g'}
        if self.visualization:
            self.fig = plt.figure()
            self.ax = self.fig.add_subplot(1, 1, 1)

    def train(self, data):
        """Fit the classifier.

        Parameters
        ----------
        data : dict
            Maps class label (+1 / -1) to an array of 2-D feature vectors.

        Sets ``self.w`` (ndarray, shape (2,)) and ``self.b`` (float).
        """
        self.data = data
        # Feasible candidates keyed by norm: { ||w||: [w, b] }.
        opt_dict = {}

        # Sign flips applied to w so all four quadrants are explored.
        transforms = [[1, 1],
                      [-1, 1],
                      [-1, -1],
                      [1, -1]]

        # Flatten every feature value to find the data's numeric range.
        all_data = []
        for yi in self.data:
            for featureset in self.data[yi]:
                for feature in featureset:
                    all_data.append(feature)

        self.max_feature_value = max(all_data)
        self.min_feature_value = min(all_data)
        all_data = None  # release the temporary list

        # Successively finer step sizes for the coarse-to-fine search.
        step_sizes = [self.max_feature_value * 0.1,
                      self.max_feature_value * 0.01,
                      self.max_feature_value * 0.001
                      ]

        # b is searched over a wider range, with a coarser step, than w.
        b_range_multiple = 2
        b_multiple = 5
        latest_optimum = self.max_feature_value * 10

        # One pass per step size, restarting near the previous optimum.
        for step in step_sizes:
            w = np.array([latest_optimum, latest_optimum])
            optimized = False
            while not optimized:
                for b in np.arange(-1 * (self.max_feature_value * b_range_multiple),
                                   self.max_feature_value * b_range_multiple,
                                   step * b_multiple):
                    for transformation in transforms:
                        w_t = w * transformation
                        found_option = True

                        # (w_t, b) is feasible only if EVERY sample meets
                        # the hard-margin constraint yi*(w.x + b) >= 1.
                        for i in self.data:
                            for xi in self.data[i]:
                                yi = i
                                if not yi * (np.dot(w_t, xi) + b) >= 1:
                                    found_option = False

                        if found_option:
                            opt_dict[np.linalg.norm(w_t)] = [w_t, b]

                if w[0] < 0:
                    # The scalar sweep on w has stepped past zero: this
                    # resolution is exhausted.
                    optimized = True
                    print('Optimized a step!')
                else:
                    w = w - step

            # Pick the feasible (w, b) with smallest ||w|| (widest margin).
            norms = sorted([n for n in opt_dict])
            opt_choice = opt_dict[norms[0]]
            self.w = opt_choice[0]
            self.b = opt_choice[1]
            # Seed the next, finer pass just above the current optimum.
            latest_optimum = opt_choice[0][0] + step * 2

        # Report the functional margin of every training sample.
        for i in self.data:
            for xi in self.data[i]:
                yi = i
                print(xi, ':', yi * (np.dot(self.w, xi) + self.b))

    def predict(self, features):
        """Return sign(features . w + b); optionally scatter the point."""
        classification = np.sign(np.dot(np.array(features), self.w) + self.b)
        if classification != 0 and self.visualization:
            self.ax.scatter(features[0], features[1], s=200, marker='^', c=self.colors[classification])
        return classification

    def visualize(self):
        """Plot the training data, both support hyperplanes, and the
        decision boundary."""
        # BUG FIX: plot the data this model was trained on (self.data),
        # not the module-level global `data_dict`.
        [[self.ax.scatter(x[0], x[1], s=100, color=self.colors[i]) for x in self.data[i]] for i in self.data]

        # Solve w.x + b = v for the second coordinate of x given the first:
        #   v = 1  -> positive support hyperplane
        #   v = -1 -> negative support hyperplane
        #   v = 0  -> decision boundary
        def hyperplane(x, w, b, v):
            return (-w[0] * x - b + v) / w[1]

        datarange = (self.min_feature_value * 0.9, self.max_feature_value * 1.1)
        hyp_x_min = datarange[0]
        hyp_x_max = datarange[1]

        # Positive support hyperplane: (w.x + b) = 1
        psv1 = hyperplane(hyp_x_min, self.w, self.b, 1)
        psv2 = hyperplane(hyp_x_max, self.w, self.b, 1)
        self.ax.plot([hyp_x_min, hyp_x_max], [psv1, psv2], 'k')

        # Negative support hyperplane: (w.x + b) = -1
        nsv1 = hyperplane(hyp_x_min, self.w, self.b, -1)
        nsv2 = hyperplane(hyp_x_max, self.w, self.b, -1)
        self.ax.plot([hyp_x_min, hyp_x_max], [nsv1, nsv2], 'k')

        # Decision boundary: (w.x + b) = 0
        db1 = hyperplane(hyp_x_min, self.w, self.b, 0)
        db2 = hyperplane(hyp_x_max, self.w, self.b, 0)
        self.ax.plot([hyp_x_min, hyp_x_max], [db1, db2], 'y--')

        plt.show()
139+ 
140+ 
# Toy 2-D training set: class label -> array of feature vectors.
data_dict = {
    -1: np.array([[1, 7], [2, 8], [3, 8]]),
    1: np.array([[5, 1], [6, -1], [7, 3]]),
}
148+ 
# Demo driver: fit the hard-margin SVM on the toy data, classify a few
# held-out points, then draw the margin and decision boundary.
svm = Hard_Margin_SVM()
svm.train(data=data_dict)

# Points to classify once training has finished.
predict_us = [
    [0, 10], [1, 3], [3, 4], [3, 5], [5, 5],
    [5, 6], [6, -5], [5, 8], [2, 5], [8, -3],
]

for point in predict_us:
    svm.predict(point)

svm.visualize()