Skip to content

Commit 604500b

Browse files
Create mean_shift.py
1 parent ce30452 commit 604500b

File tree

1 file changed

+179
-0
lines changed

1 file changed

+179
-0
lines changed

Chapter_11 MeanShift/mean_shift.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# coding:UTF-8
2+
'''
3+
Date:20160426
4+
@author: zhaozhiyong
5+
'''
6+
import math
7+
import numpy as np
8+
9+
# Convergence threshold used by train_mean_shift: a point is no longer shifted
# once a single update moves it by less than this distance, and the training
# loop ends when the largest move in a pass does not exceed it.
MIN_DISTANCE = 0.000001
10+
11+
def load_data(path, feature_num=2):
    '''Load tab-separated numeric samples from a text file.

    input:  path(string): location of the data file
            feature_num(int): number of features expected on every line
    output: data(list): one list of floats per accepted line
    raises: ValueError when a field on an accepted line is not a number
    '''
    data = []
    # The context manager closes the file even when float() raises part-way.
    with open(path) as f:
        for line in f:
            fields = line.strip().split("\t")
            # Lines with an unexpected number of fields are silently skipped.
            if len(fields) != feature_num:
                continue
            data.append([float(field) for field in fields])
    return data
29+
30+
def gaussian_kernel(distance, bandwidth):
    '''Evaluate the Gaussian kernel for a column of distances.

    input:  distance(mat): m x 1 matrix of Euclidean distances (one per row)
            bandwidth(int): bandwidth h of the kernel
    output: gaussian_val(mat): m x 1 matrix holding
            exp(-0.5 * d^2 / h^2) / (h * sqrt(2 * pi)) for every row
    '''
    dist = np.asarray(distance, dtype=float)
    if dist.ndim == 1:
        dist = dist[:, np.newaxis]
    m = dist.shape[0]  # number of samples

    # Squared norm of each row; for an m x 1 input this is simply d^2.
    squared = np.sum(dist * dist, axis=1)

    left = 1.0 / (bandwidth * math.sqrt(2 * math.pi))
    right = np.exp(-0.5 * squared / (bandwidth * bandwidth))

    # np.asmatrix instead of np.mat: the np.mat alias no longer exists in
    # NumPy 2.0, while asmatrix is available in old and new releases alike.
    gaussian_val = np.asmatrix((left * right).reshape(m, 1))
    return gaussian_val
45+
46+
def shift_point(point, points, kernel_bandwidth):
    '''Compute the mean-shift update of a single point.

    The result is the average of all samples, each weighted by the Gaussian
    kernel of its distance to `point`.

    input:  point(mat): 1 x n point to be shifted
            points(array): all m samples, one per row
            kernel_bandwidth(int): bandwidth of the kernel
    output: point_shifted(mat): 1 x n shifted point
    '''
    samples = np.atleast_2d(np.asarray(points, dtype=float))
    center = np.asarray(point, dtype=float)
    m = samples.shape[0]  # number of samples

    # Euclidean distance from `point` to every sample, computed in one pass.
    diff = samples - center
    point_distances = np.asmatrix(
        np.sqrt(np.sum(diff * diff, axis=1)).reshape(m, 1))

    # Kernel weight of every sample (m x 1 matrix flattened to length m).
    point_weights = np.asarray(
        gaussian_kernel(point_distances, kernel_bandwidth)).ravel()

    # Denominator of the weighted mean.
    all_sum = point_weights.sum()
    if all_sum == 0:
        # Every weight underflowed to zero (point far from all samples for
        # this bandwidth); dividing would give NaN, so leave the point as is.
        return np.asmatrix(np.array(center, dtype=float))

    # Weighted mean of the samples; asmatrix turns the length-n vector
    # into the 1 x n matrix callers expect.
    point_shifted = np.asmatrix(np.dot(point_weights, samples) / all_sum)
    return point_shifted
71+
72+
def euclidean_dist(pointA, pointB):
    '''Return the Euclidean distance between two points.

    input:  pointA(mat): coordinates of point A
            pointB(mat): coordinates of point B
    output: (float): Euclidean distance between A and B
    '''
    # Work on plain arrays so that matrices, ndarrays and lists all behave
    # the same, and compute the difference only once.
    diff = np.asarray(pointA, dtype=float) - np.asarray(pointB, dtype=float)
    # Reduce to a real scalar before math.sqrt instead of relying on the
    # implicit conversion of a 1 x 1 matrix to float.
    total = float(np.sum(diff * diff))
    return math.sqrt(total)
81+
82+
def group_points(mean_shift_points):
    '''Assign a cluster index to every shifted point.

    Points whose coordinates agree after rounding to two decimals share a
    cluster. Indices are handed out in order of first appearance, from 0.

    input:  mean_shift_points(mat): m x n matrix of shifted points
    output: group_assignment(list): cluster index of each of the m points
    raises: ValueError when the input is not two-dimensional
    '''
    shifted = np.asarray(mean_shift_points, dtype=float)
    if shifted.ndim != 2:
        raise ValueError("mean_shift_points must be a 2-D (m x n) array")

    zero_text = "%5.2f" % 0.0
    index_dict = {}
    group_assignment = []
    for row in shifted:
        item = []
        for value in row:
            text = "%5.2f" % value
            # Tiny negative values format as "-0.00"; fold them onto " 0.00"
            # so that the same centre is not split into two clusters.
            if float(text) == 0.0:
                text = zero_text
            item.append(text)
        key = "_".join(item)
        # A key seen for the first time gets the next free index; the
        # lookup and the assignment happen in the same single pass.
        group_assignment.append(index_dict.setdefault(key, len(index_dict)))

    return group_assignment
110+
111+
def train_mean_shift(points, kenel_bandwidth=2):
    '''Run mean shift on the samples until every point stops moving.

    input:  points(array): feature data, one sample per row
            kenel_bandwidth(int): bandwidth of the kernel
    output: points(mat): the original samples as a matrix
            mean_shift_points(mat): position of every sample after shifting
            group(list): cluster index of every sample
    '''
    # Work on a float copy: without the copy an ndarray/matrix argument would
    # be overwritten while shift_point still reads it as the data set, and
    # an integer dtype would truncate the shifted coordinates.
    mean_shift_points = np.asmatrix(np.array(points, dtype=float))
    max_min_dist = 1
    iteration = 0  # number of passes over the data
    m = np.shape(mean_shift_points)[0]  # number of samples
    need_shift = [True] * m  # whether each point still has to be shifted

    # Keep shifting until the largest move of a pass is below the threshold.
    while max_min_dist > MIN_DISTANCE:
        max_min_dist = 0
        iteration += 1
        # Function-call form prints the same text on Python 2 and Python 3.
        print("\titeration : " + str(iteration))
        for i in range(m):
            # Points that already converged are left alone.
            if not need_shift[i]:
                continue
            p_new_start = mean_shift_points[i]
            # Shift the point towards the weighted mean of the samples.
            p_new = shift_point(p_new_start, points, kenel_bandwidth)
            # How far this update moved the point.
            dist = euclidean_dist(p_new, p_new_start)

            if dist > max_min_dist:
                max_min_dist = dist
            if dist < MIN_DISTANCE:  # converged, no further shifting needed
                need_shift[i] = False

            mean_shift_points[i] = p_new

    # Derive the final cluster index of every point.
    group = group_points(mean_shift_points)

    return np.asmatrix(points), mean_shift_points, group
150+
151+
def save_result(file_name, data):
    '''Write a 2-D matrix to a text file, one tab-separated row per line.

    input:  file_name(string): name of the output file
            data(mat): m x n matrix to be saved
    raises: ValueError when data is not two-dimensional
    '''
    rows = np.asarray(data)
    # Validate before opening so a bad argument cannot truncate an
    # existing file.
    if rows.ndim != 2:
        raise ValueError("data must be a 2-D (m x n) array")

    # The context manager closes the file even when a write fails.
    with open(file_name, "w") as f:
        for row in rows:
            f.write("\t".join([str(value) for value in row]) + "\n")
164+
165+
166+
if __name__ == "__main__":
    # Load the data set (two features per line).
    # print(...) with a single string behaves the same on Python 2 and 3.
    print("----------1.load data ------------")
    data = load_data("data", 2)
    # Train with bandwidth h=2.
    print("----------2.training ------------")
    points, shift_points, cluster = train_mean_shift(data, 2)
    # Save the cluster index of every sample.
    print("----------3.1.save sub ------------")
    # np.asmatrix replaces np.mat, which no longer exists in NumPy 2.0.
    save_result("sub_1", np.asmatrix(cluster))
    print("----------3.2.save center ------------")
    # Save the shifted points (cluster centres).
    save_result("center_1", shift_points)
179+

0 commit comments

Comments
 (0)