11"""
22
3- Object clustering with k-mean algorithm
3+ Object clustering with k-means algorithm
44
55author: Atsushi Sakai (@Atsushi_twi)
66
@@ -25,9 +25,23 @@ def __init__(self, x, y, nlabel):
2525 self .cy = [0.0 for _ in range (nlabel )]
2626
2727
28- def init_clusters (rx , ry , nc ):
28+ def kmeans_clustering (rx , ry , nc ):
2929
3030 clusters = Clusters (rx , ry , nc )
31+ clusters = calc_centroid (clusters )
32+
33+ MAX_LOOP = 10
34+ DCOST_TH = 0.1
35+ pcost = 100.0
36+ for loop in range (MAX_LOOP ):
37+ # print("Loop:", loop)
38+ clusters , cost = update_clusters (clusters )
39+ clusters = calc_centroid (clusters )
40+
41+ dcost = abs (cost - pcost )
42+ if dcost < DCOST_TH :
43+ break
44+ pcost = cost
3145
3246 return clusters
3347
@@ -62,36 +76,23 @@ def update_clusters(clusters):
6276 return clusters , cost
6377
6478
65- def kmean_clustering (rx , ry , nc ):
66-
67- clusters = init_clusters (rx , ry , nc )
68- clusters = calc_centroid (clusters )
79+ def calc_labeled_points (ic , clusters ):
6980
70- MAX_LOOP = 10
71- DCOST_TH = 1.0
72- pcost = 100.0
73- for loop in range (MAX_LOOP ):
74- print ("Loop:" , loop )
75- clusters , cost = update_clusters (clusters )
76- clusters = calc_centroid (clusters )
81+ inds = np .array ([i for i in range (clusters .ndata )
82+ if clusters .labels [i ] == ic ])
83+ tx = np .array (clusters .x )
84+ ty = np .array (clusters .y )
7785
78- dcost = abs (cost - pcost )
79- if dcost < DCOST_TH :
80- break
81- pcost = cost
86+ x = tx [inds ]
87+ y = ty [inds ]
8288
83- return clusters
89+ return x , y
8490
8591
86- def calc_raw_data ():
92+ def calc_raw_data (cx , cy , npoints , rand_d ):
8793
8894 rx , ry = [], []
8995
90- cx = [0.0 , 5.0 ]
91- cy = [0.0 , 5.0 ]
92- npoints = 30
93- rand_d = 3.0
94-
9596 for (icx , icy ) in zip (cx , cy ):
9697 for _ in range (npoints ):
9798 rx .append (icx + rand_d * (random .random () - 0.5 ))
@@ -100,32 +101,75 @@ def calc_raw_data():
100101 return rx , ry
101102
102103
103- def calc_labeled_points ( ic , clusters ):
104+ def update_positions ( cx , cy ):
104105
105- inds = np .array ([i for i in range (clusters .ndata )
106- if clusters .labels [i ] == ic ])
107- tx = np .array (clusters .x )
108- ty = np .array (clusters .y )
106+ DX1 = 0.4
107+ DY1 = 0.5
109108
110- x = tx [ inds ]
111- y = ty [ inds ]
109+ cx [ 0 ] += DX1
110+ cy [ 0 ] += DY1
112111
113- return x , y
112+ DX2 = - 0.3
113+ DY2 = - 0.5
114+
115+ cx [1 ] += DX2
116+ cy [1 ] += DY2
117+
118+ return cx , cy
119+
120+
121+ def calc_association (cx , cy , clusters ):
122+
123+ inds = []
124+
125+ for ic in range (len (cx )):
126+ tcx = cx [ic ]
127+ tcy = cy [ic ]
128+
129+ dx = [icx - tcx for icx in clusters .cx ]
130+ dy = [icy - tcy for icy in clusters .cy ]
131+
132+ dlist = [math .sqrt (idx ** 2 + idy ** 2 ) for (idx , idy ) in zip (dx , dy )]
133+ min_id = dlist .index (min (dlist ))
134+ inds .append (min_id )
135+
136+ return inds
114137
115138
116139def main ():
117140 print (__file__ + " start!!" )
118141
119- rx , ry = calc_raw_data ()
120-
142+ cx = [0.0 , 8.0 ]
143+ cy = [0.0 , 8.0 ]
144+ npoints = 10
145+ rand_d = 3.0
121146 ncluster = 2
122- clusters = kmean_clustering (rx , ry , ncluster )
123-
124- for ic in range (clusters .nlabel ):
125- x , y = calc_labeled_points (ic , clusters )
126- plt .plot (x , y , "x" )
127- plt .plot (clusters .cx , clusters .cy , "o" )
128- plt .show ()
147+ sim_time = 15.0
148+ dt = 1.0
149+ time = 0.0
150+
151+ while time <= sim_time :
152+ print ("Time:" , time )
153+ time += dt
154+
155+ # simulate objects
156+ cx , cy = update_positions (cx , cy )
157+ rx , ry = calc_raw_data (cx , cy , npoints , rand_d )
158+
159+ clusters = kmeans_clustering (rx , ry , ncluster )
160+
161+ # for animation
162+ plt .cla ()
163+ inds = calc_association (cx , cy , clusters )
164+ for ic in inds :
165+ x , y = calc_labeled_points (ic , clusters )
166+ plt .plot (x , y , "x" )
167+ plt .plot (cx , cy , "o" )
168+ plt .xlim (- 2.0 , 10.0 )
169+ plt .ylim (- 2.0 , 10.0 )
170+ plt .pause (dt )
171+
172+ print ("Done" )
129173
130174
131175if __name__ == '__main__' :
0 commit comments