1
+ from __future__ import division
2
+ from linear_algebra import squared_distance , vector_mean , distance
3
+ import math , random
4
+ import matplotlib .image as mpimg
5
+ import matplotlib .pyplot as plt
6
+
7
class KMeans:
    """Performs k-means clustering.

    After train() has run, `means` holds one mean per cluster and
    classify() maps a point to the index of its closest mean.
    """

    def __init__(self, k):
        self.k = k          # number of clusters
        self.means = None   # means of clusters; populated by train()

    def classify(self, input):
        """Return the index of the cluster closest to the input.

        Closeness is measured with squared_distance against each entry of
        self.means, so train() must have been called first.
        """
        return min(range(self.k),
                   key=lambda i: squared_distance(input, self.means[i]))

    def train(self, inputs):
        """Fit self.means to `inputs`.

        Starts from self.k points sampled at random from `inputs`, then
        alternates between assigning every point to its closest mean and
        recomputing each mean, until the assignments stop changing.
        `inputs` is iterated repeatedly, so it must be a sequence.
        """
        self.means = random.sample(inputs, self.k)
        assignments = None

        while True:
            # Find new assignments.  Build a real list: the convergence test
            # below compares by value and the update step re-reads the
            # assignments once per cluster, neither of which works with a
            # one-shot iterator such as the one Python 3's map() returns.
            new_assignments = [self.classify(point) for point in inputs]

            # If no assignments have changed, we're done.
            if assignments == new_assignments:
                return

            # Otherwise keep the new assignments,
            assignments = new_assignments

            # and recompute each mean from the points assigned to it.
            for i in range(self.k):
                i_points = [p for p, a in zip(inputs, assignments) if a == i]
                # avoid divide-by-zero if i_points is empty
                if i_points:
                    self.means[i] = vector_mean(i_points)
40
+
41
def squared_clustering_errors(inputs, k):
    """Return the total squared error of a k-means clustering of `inputs`.

    A fresh KMeans(k) model is trained on `inputs`; the result is the sum,
    over all points, of the squared distance from the point to the mean of
    the cluster it is classified into.
    """
    model = KMeans(k)
    model.train(inputs)
    cluster_means = model.means

    total = 0
    for point in inputs:
        closest = model.classify(point)
        total += squared_distance(point, cluster_means[closest])
    return total
50
+
51
def plot_squared_clustering_errors(plt):
    """Plot total squared clustering error against the number of clusters.

    `plt` is the plotting object to draw on; only its plot, xticks, xlabel,
    ylabel and show methods are used.

    NOTE(review): `inputs` is not a parameter -- it is looked up as a
    module-level name, which this file only defines in the __main__ script.
    """
    cluster_counts = range(1, len(inputs) + 1)

    # one clustering run (and one error total) per candidate k
    total_errors = []
    for count in cluster_counts:
        total_errors.append(squared_clustering_errors(inputs, count))

    plt.plot(cluster_counts, total_errors)
    plt.xticks(cluster_counts)
    plt.xlabel("k")
    plt.ylabel("total squared error")
    plt.show()
61
+
62
+ #
63
+ # using clustering to recolor an image
64
+ #
65
+
66
def recolor_image(input_file, k):
    """Display the image in `input_file` reduced to `k` colors.

    The image is read with matplotlib's imread, its pixels are clustered
    with KMeans(k), and every pixel is replaced by the mean of the cluster
    it falls into.  The recolored image is shown with pyplot; nothing is
    returned.
    """
    # read the file the caller asked for (the parameter, not a global name)
    img = mpimg.imread(input_file)

    # flatten the rows of pixels into one list of points to cluster
    pixels = [pixel for row in img for pixel in row]

    clusterer = KMeans(k)       # honor the requested number of colors
    clusterer.train(pixels)     # this might take a while

    def recolor(pixel):
        cluster = clusterer.classify(pixel)     # index of the closest cluster
        return clusterer.means[cluster]         # mean of the closest cluster

    new_img = [[recolor(pixel) for pixel in row]
               for row in img]

    plt.imshow(new_img)
    plt.axis('off')
    plt.show()
83
+
84
+ #
85
+ # hierarchical clustering
86
+ #
87
+
88
def cluster_distance(cluster1, cluster2, distance_agg=min):
    """Return the aggregate distance between two clusters.

    Every member of cluster1 is paired with every member of cluster2, the
    `distance` of each pair is computed, and `distance_agg` (min by default)
    reduces those distances to a single value.
    """
    pairwise_distances = (distance(left, right)
                          for left in cluster1.members()
                          for right in cluster2.members())
    return distance_agg(pairwise_distances)
94
+
95
class LeafCluster:
    """A cluster holding exactly one input.

    Its depth is infinite, so a search for the cluster with the smallest
    depth never prefers a leaf over a merged cluster.
    """

    def __init__(self, value):
        self.depth = float('inf')   # leaves are never chosen for splitting
        self.value = value          # the single input this leaf wraps

    def __repr__(self):
        """Show the leaf as the string form of its value."""
        text = str(self.value)
        return text

    def members(self):
        """Return a one-element list containing this leaf's value."""
        only_member = self.value
        return [only_member]
109
+
110
class MergedCluster:
    """A cluster produced by merging other clusters.

    `branches` is the list of child clusters and `depth` is the value
    supplied by whoever performed the merge.
    """

    def __init__(self, branches, depth):
        self.depth = depth
        self.branches = branches

    def __repr__(self):
        """Render as {(depth) child1, child2 }."""
        header = "{(" + str(self.depth) + ") "
        children = ", ".join(str(branch) for branch in self.branches)
        return header + children + " }"

    def members(self):
        """Collect the members of every branch, recursively, in branch order."""
        collected = []
        for branch in self.branches:
            collected.extend(branch.members())
        return collected
127
+
128
+
129
class BottomUpClusterer:
    """Bottom-up hierarchical clustering.

    train() repeatedly merges the two closest clusters until a single
    cluster remains; get_clusters() then undoes the most recent merges to
    recover a requested number of clusters.
    """

    def __init__(self, distance_agg=min):
        self.agg = distance_agg     # how cluster_distance reduces pair distances
        self.clusters = None        # set by train()

    def train(self, inputs):
        """Build the cluster hierarchy for `inputs`.

        Afterwards self.clusters holds the single root cluster (or is empty
        if `inputs` was empty).
        """
        # start with each input its own cluster
        self.clusters = [LeafCluster(value) for value in inputs]

        while len(self.clusters) > 1:

            # find the two closest clusters; the key takes the pair as one
            # argument because tuple-unpacking parameters are Python 2 only
            c1, c2 = min([(cluster1, cluster2)
                          for cluster1 in self.clusters
                          for cluster2 in self.clusters
                          if cluster1 != cluster2],
                         key=lambda pair: cluster_distance(pair[0], pair[1],
                                                           self.agg))

            # the depth is the number of clusters present before the merge,
            # so later merges get smaller depths
            merged_cluster = MergedCluster([c1, c2], len(self.clusters))

            self.clusters = [c for c in self.clusters
                             if c not in [c1, c2]]

            self.clusters.append(merged_cluster)

    def get_clusters(self, num_clusters):
        """Extract num_clusters clusters from the hierarchy.

        Repeatedly replaces the least deep cluster by its branches.  If the
        hierarchy cannot be split that far (only unsplittable clusters
        remain, or there are no clusters at all), the clusters obtained so
        far are returned, so the result may be shorter than num_clusters.
        train() must have been called first.
        """
        clusters = self.clusters[:]     # create a copy so we can modify it
        while clusters and len(clusters) < num_clusters:
            # choose the least deep cluster
            next_cluster = min(clusters, key=lambda c: c.depth)
            branches = getattr(next_cluster, "branches", None)
            if branches is None:
                # nothing left that can be split any further
                break
            # remove it from the list
            clusters = [c for c in clusters if c != next_cluster]
            # and add its children
            clusters.extend(branches)

        return clusters
169
+
170
+
171
+
172
+
173
+
174
if __name__ == "__main__":

    # Demo script.  Every print below is a call with a single argument, which
    # produces the same output under the Python 2 print statement and the
    # Python 3 print function.

    inputs = [[-14,-5],[13,13],[20,23],[-19,-11],[-9,-16],[21,27],[-49,15],[26,13],[-46,5],[-34,-1],[11,15],[-49,0],[-22,-16],[19,28],[-12,-8],[-13,-19],[-41,8],[-11,-6],[-25,-9],[-18,-3]]

    random.seed(0) # so you get the same results as me
    clusterer = KMeans(3)
    clusterer.train(inputs)
    print("3-means:")
    print(clusterer.means)
    print("")

    random.seed(0)
    clusterer = KMeans(2)
    clusterer.train(inputs)
    print("2-means:")
    print(clusterer.means)
    print("")

    print("errors as a function of k")

    for k in range(1, len(inputs) + 1):
        print("%s %s" % (k, squared_clustering_errors(inputs, k)))
    print("")


    print("bottom up hierarchical clustering")

    buc = BottomUpClusterer() # or BottomUpClusterer(max) if you like
    buc.train(inputs)
    print(buc.clusters[0])

    print("")
    print("three clusters:")
    for cluster in buc.get_clusters(3):
        print(cluster)
0 commit comments