Skip to content

Commit 5bb765f

Browse files
committed
first release k means simulation
1 parent d696c71 commit 5bb765f

File tree

1 file changed

+85
-41
lines changed

1 file changed

+85
-41
lines changed

Mapping/kmean_clustering/kmean_clustering.py

Lines changed: 85 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
3-
Object clustering with k-mean algorithm
3+
Object clustering with k-means algorithm
44
55
author: Atsushi Sakai (@Atsushi_twi)
66
@@ -25,9 +25,23 @@ def __init__(self, x, y, nlabel):
2525
self.cy = [0.0 for _ in range(nlabel)]
2626

2727

28-
def init_clusters(rx, ry, nc):
28+
def kmeans_clustering(rx, ry, nc):
    """Cluster 2D points into ``nc`` groups with the k-means algorithm.

    A ``Clusters`` object is built from the raw points, then point labels
    and centroids are refined alternately until the cost returned by
    ``update_clusters`` changes by less than ``DCOST_TH`` between two
    consecutive iterations, or ``MAX_LOOP`` iterations have run.

    Args:
        rx: x coordinates of the raw points.
        ry: y coordinates of the raw points.
        nc: Number of clusters.

    Returns:
        The ``Clusters`` object after the final refinement step.
    """
    clusters = Clusters(rx, ry, nc)
    clusters = calc_centroid(clusters)

    MAX_LOOP = 10
    DCOST_TH = 0.1
    # Seed the previous cost with infinity so the first iteration can never
    # satisfy the convergence test by accident (a finite seed such as 100.0
    # stops the loop early whenever the first cost lands near that value).
    pcost = float("inf")
    for _ in range(MAX_LOOP):
        clusters, cost = update_clusters(clusters)
        clusters = calc_centroid(clusters)

        dcost = abs(cost - pcost)
        if dcost < DCOST_TH:
            break
        pcost = cost

    return clusters
3347

@@ -62,36 +76,23 @@ def update_clusters(clusters):
6276
return clusters, cost
6377

6478

65-
def kmean_clustering(rx, ry, nc):
66-
67-
clusters = init_clusters(rx, ry, nc)
68-
clusters = calc_centroid(clusters)
79+
def calc_labeled_points(ic, clusters):
    """Return the coordinates of the points assigned to cluster ``ic``.

    Args:
        ic: Cluster label to select.
        clusters: Object exposing ``x``, ``y``, ``labels`` and ``ndata``.

    Returns:
        Tuple ``(x, y)`` of NumPy arrays holding the coordinates of every
        point whose label equals ``ic``. Both arrays are empty when no point
        carries that label.
    """
    # dtype=int matters: with no matching point, np.array([]) would default
    # to float64, and NumPy rejects float arrays used as indices.
    inds = np.array([i for i in range(clusters.ndata)
                     if clusters.labels[i] == ic], dtype=int)
    tx = np.array(clusters.x)
    ty = np.array(clusters.y)

    x = tx[inds]
    y = ty[inds]

    return x, y
8490

8591

86-
def calc_raw_data():
92+
def calc_raw_data(cx, cy, npoints, rand_d):
8793

8894
rx, ry = [], []
8995

90-
cx = [0.0, 5.0]
91-
cy = [0.0, 5.0]
92-
npoints = 30
93-
rand_d = 3.0
94-
9596
for (icx, icy) in zip(cx, cy):
9697
for _ in range(npoints):
9798
rx.append(icx + rand_d * (random.random() - 0.5))
@@ -100,32 +101,75 @@ def calc_raw_data():
100101
return rx, ry
101102

102103

103-
def calc_labeled_points(ic, clusters):
104+
def update_positions(cx, cy):
    """Advance the two simulated objects by one fixed step.

    The first object moves by (+0.4, +0.5) and the second by (-0.3, -0.5).
    Both lists are modified in place and also returned.

    Args:
        cx: x positions of the objects (at least two entries).
        cy: y positions of the objects (at least two entries).

    Returns:
        The same ``cx`` and ``cy`` lists after the update.
    """
    # (object index, x step, y step) for each simulated object.
    steps = (
        (0, 0.4, 0.5),
        (1, -0.3, -0.5),
    )
    for idx, step_x, step_y in steps:
        cx[idx] += step_x
        cy[idx] += step_y

    return cx, cy
119+
120+
121+
def calc_association(cx, cy, clusters):
    """Match each object position to its nearest cluster centroid.

    Args:
        cx: x positions of the objects.
        cy: y positions of the objects.
        clusters: Object exposing the centroid coordinates as ``cx``/``cy``.

    Returns:
        List with one centroid index per object: the index of the centroid
        at the smallest Euclidean distance (the first one on ties).
    """
    matched = []

    for k, obj_x in enumerate(cx):
        obj_y = cy[k]

        # Euclidean distance from this object to every centroid.
        distances = [
            math.sqrt((cen_x - obj_x)**2 + (cen_y - obj_y)**2)
            for cen_x, cen_y in zip(clusters.cx, clusters.cy)
        ]
        matched.append(distances.index(min(distances)))

    return matched
114137

115138

116139
def main():
    """Simulate two moving objects and animate their k-means clustering.

    At every time step the object centres are moved, noisy points are
    sampled around them, the points are clustered, and the labelled points
    are drawn together with the true object centres.
    """
    print(__file__ + " start!!")

    # Simulation setup: two objects, each observed through noisy points.
    cx, cy = [0.0, 8.0], [0.0, 8.0]
    npoints = 10
    rand_d = 3.0
    ncluster = 2
    sim_time = 15.0
    dt = 1.0

    time = 0.0
    while time <= sim_time:
        print("Time:", time)
        time += dt

        # simulate objects
        cx, cy = update_positions(cx, cy)
        rx, ry = calc_raw_data(cx, cy, npoints, rand_d)

        clusters = kmeans_clustering(rx, ry, ncluster)

        # for animation
        plt.cla()
        for label in calc_association(cx, cy, clusters):
            px, py = calc_labeled_points(label, clusters)
            plt.plot(px, py, "x")
        plt.plot(cx, cy, "o")
        plt.xlim(-2.0, 10.0)
        plt.ylim(-2.0, 10.0)
        plt.pause(dt)

    print("Done")
129173

130174

131175
if __name__ == '__main__':

0 commit comments

Comments
 (0)