
Commit 1797e16

Author: louwill

hard margin svm

1 parent f21728e commit 1797e16

File tree

2 files changed: +691 -0 lines changed


hard margin svm/Hard_Margin_SVM.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time   : 2019/4/3 17:12
# @Author : louwill
# @File   : Hard_Margin_SVM.py

import matplotlib.pyplot as plt
import numpy as np


class Hard_Margin_SVM:
    def __init__(self, visualization=True):
        self.visualization = visualization
        self.colors = {1: 'r', -1: 'g'}
        if self.visualization:
            self.fig = plt.figure()
            self.ax = self.fig.add_subplot(1, 1, 1)

    # training function
    def train(self, data):
        self.data = data
        # parameter dictionary { ||w||: [w, b] }
        opt_dict = {}

        # sign transformations: search w in a single quadrant and
        # recover the other quadrants by flipping component signs
        transforms = [[1, 1],
                      [-1, 1],
                      [-1, -1],
                      [1, -1]]

        # flatten every feature value out of the data dictionary
        all_data = []
        for yi in self.data:
            for featureset in self.data[yi]:
                for feature in featureset:
                    all_data.append(feature)

        # overall feature value range
        self.max_feature_value = max(all_data)
        self.min_feature_value = min(all_data)
        all_data = None

        # list of learning rates (step sizes), coarse to fine
        step_sizes = [self.max_feature_value * 0.1,
                      self.max_feature_value * 0.01,
                      self.max_feature_value * 0.001]

        # search range settings for the bias b
        b_range_multiple = 2
        b_multiple = 5
        latest_optimum = self.max_feature_value * 10

        # optimize with progressively smaller step sizes
        for step in step_sizes:
            w = np.array([latest_optimum, latest_optimum])
            # the problem is convex, so stepping w downhill until it
            # crosses zero is sufficient at each resolution
            optimized = False
            while not optimized:
                for b in np.arange(-1 * (self.max_feature_value * b_range_multiple),
                                   self.max_feature_value * b_range_multiple,
                                   step * b_multiple):
                    for transformation in transforms:
                        w_t = w * transformation
                        found_option = True

                        # hard-margin constraint: yi * (w . xi + b) >= 1 for every sample
                        for i in self.data:
                            for xi in self.data[i]:
                                yi = i
                                if not yi * (np.dot(w_t, xi) + b) >= 1:
                                    found_option = False
                                    # print(xi, ':', yi * (np.dot(w_t, xi) + b))

                        if found_option:
                            opt_dict[np.linalg.norm(w_t)] = [w_t, b]

                if w[0] < 0:
                    optimized = True
                    print('Optimized a step!')
                else:
                    w = w - step

            # keep the feasible (w, b) with the smallest ||w||
            norms = sorted([n for n in opt_dict])
            # ||w|| : [w, b]
            opt_choice = opt_dict[norms[0]]
            self.w = opt_choice[0]
            self.b = opt_choice[1]
            latest_optimum = opt_choice[0][0] + step * 2

        for i in self.data:
            for xi in self.data[i]:
                yi = i
                print(xi, ':', yi * (np.dot(self.w, xi) + self.b))
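    # Note (added commentary, not part of the original commit): train() is a
    # brute-force feasibility search, not a QP solver. It records every (w, b)
    # satisfying the hard-margin constraint
    #     yi * (np.dot(w, xi) + b) >= 1   for all samples,
    # then picks the feasible pair with the smallest ||w||, which maximizes
    # the margin 2 / ||w||. A standalone sketch of that feasibility test
    # (hypothetical helper, assuming the same {label: points} layout as
    # self.data):
    #
    #     def satisfies_hard_margin(w, b, data):
    #         return all(y * (np.dot(w, x) + b) >= 1
    #                    for y, points in data.items() for x in points)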
    # prediction function
    def predict(self, features):
        # sign(x . w + b)
        classification = np.sign(np.dot(np.array(features), self.w) + self.b)
        if classification != 0 and self.visualization:
            self.ax.scatter(features[0], features[1], s=200, marker='^',
                            c=self.colors[classification])
        return classification
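    # Worked example (made-up values, not trained ones): predict() returns the
    # sign of the decision function w . x + b. With w = [0.5, -0.5] and
    # b = -1.0, the point [5, 1] gives
    #     np.sign(0.5 * 5 - 0.5 * 1 - 1.0) = np.sign(1.0) = 1.0,
    # so it would be assigned to the +1 class.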
    # result plotting function
    def visualize(self):
        # scatter the training data, colored by class
        # (note: reads the module-level data_dict rather than self.data)
        [[self.ax.scatter(x[0], x[1], s=100, color=self.colors[i]) for x in data_dict[i]] for i in data_dict]

        # hyperplane: x . w + b = v
        # psv: v = 1   (positive support vector line)
        # nsv: v = -1  (negative support vector line)
        # dec: v = 0   (decision boundary)
        # solving w[0]*x + w[1]*y + b = v for y gives the line to draw
        def hyperplane(x, w, b, v):
            return (-w[0] * x - b + v) / w[1]

        datarange = (self.min_feature_value * 0.9, self.max_feature_value * 1.1)
        hyp_x_min = datarange[0]
        hyp_x_max = datarange[1]

        # (w . x + b) = 1
        # positive support vector line
        psv1 = hyperplane(hyp_x_min, self.w, self.b, 1)
        psv2 = hyperplane(hyp_x_max, self.w, self.b, 1)
        self.ax.plot([hyp_x_min, hyp_x_max], [psv1, psv2], 'k')

        # (w . x + b) = -1
        # negative support vector line
        nsv1 = hyperplane(hyp_x_min, self.w, self.b, -1)
        nsv2 = hyperplane(hyp_x_max, self.w, self.b, -1)
        self.ax.plot([hyp_x_min, hyp_x_max], [nsv1, nsv2], 'k')

        # (w . x + b) = 0
        # linear separating hyperplane
        db1 = hyperplane(hyp_x_min, self.w, self.b, 0)
        db2 = hyperplane(hyp_x_max, self.w, self.b, 0)
        self.ax.plot([hyp_x_min, hyp_x_max], [db1, db2], 'y--')

        plt.show()


data_dict = {-1: np.array([[1, 7],
                           [2, 8],
                           [3, 8]]),

             1: np.array([[5, 1],
                          [6, -1],
                          [7, 3]])}

svm = Hard_Margin_SVM()
svm.train(data=data_dict)
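# Sanity check (added for illustration; uses only numpy, imported above):
# the gap between the v = +1 and v = -1 support lines is 2 / ||w||.
print('margin width:', 2 / np.linalg.norm(svm.w))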
predict_us = [[0, 10],
              [1, 3],
              [3, 4],
              [3, 5],
              [5, 5],
              [5, 6],
              [6, -5],
              [5, 8],
              [2, 5],
              [8, -3]]

for p in predict_us:
    svm.predict(p)

svm.visualize()
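# For comparison (optional sketch, assuming scikit-learn is installed -- not
# part of this commit): a linear SVC with a very large penalty C approximates
# the hard margin on linearly separable data.
from sklearn.svm import SVC

X = np.vstack([data_dict[-1], data_dict[1]])
y = np.array([-1] * len(data_dict[-1]) + [1] * len(data_dict[1]))

clf = SVC(kernel='linear', C=1e10)  # huge C ~ forbid margin violations
clf.fit(X, y)
print('sklearn w:', clf.coef_[0], 'b:', clf.intercept_[0])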
