Skip to content

Commit a9b9d60

Browse files
committed
added some more chapters
1 parent bddc46f commit a9b9d60

7 files changed

+1360
-0
lines changed

clustering.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
from __future__ import division
2+
from linear_algebra import squared_distance, vector_mean, distance
3+
import math, random
4+
import matplotlib.image as mpimg
5+
import matplotlib.pyplot as plt
6+
7+
class KMeans:
    """Performs k-means clustering.

    Usage: construct with the number of clusters, call train(inputs), then
    read `means` or call classify(point).
    """

    def __init__(self, k):
        self.k = k          # number of clusters
        self.means = None   # means of clusters; populated by train()

    def classify(self, input):
        """Return the index of the cluster closest to the input.

        "Closest" is the cluster mean with the smallest squared_distance to
        the input; ties go to the lowest index (min keeps the first minimum).
        train() must have been called first so that self.means is set.
        """
        return min(range(self.k),
                   key=lambda i: squared_distance(input, self.means[i]))

    def train(self, inputs):
        """Fit the k cluster means to the inputs.

        Starts from k randomly sampled inputs and alternates between
        assigning every input to its closest mean and recomputing each mean
        from its assigned points, until no assignment changes.

        random.sample raises ValueError when k exceeds len(inputs).
        """
        self.means = random.sample(inputs, self.k)
        assignments = None

        while True:
            # Find new assignments.  This must be a real list: on Python 3
            # map() returns a one-shot iterator, which never compares equal
            # to the previous round (so the loop would never terminate) and
            # is exhausted after the first cluster is processed below.
            new_assignments = [self.classify(point) for point in inputs]

            # If no assignments have changed, we're done.
            if assignments == new_assignments:
                return

            # Otherwise keep the new assignments,
            assignments = new_assignments

            # and recompute each mean from the points assigned to it.
            for i in range(self.k):
                i_points = [p for p, a in zip(inputs, assignments) if a == i]
                # avoid divide-by-zero if i_points is empty; an empty
                # cluster simply keeps its previous mean
                if i_points:
                    self.means[i] = vector_mean(i_points)
40+
41+
def squared_clustering_errors(inputs, k):
    """Return the total squared error from k-means clustering the inputs.

    A fresh KMeans(k) is trained on the inputs; the result is the sum, over
    every input, of its squared_distance to the mean of the cluster it is
    classified into.
    """
    model = KMeans(k)
    model.train(inputs)
    centers = model.means

    total = 0
    for point in inputs:
        nearest = centers[model.classify(point)]
        total += squared_distance(point, nearest)
    return total
50+
51+
def plot_squared_clustering_errors(plt, data=None):
    """Plot the total squared clustering error for every k from 1 to len(data).

    plt  -- object providing plot, xticks, xlabel, ylabel and show
            (presumably matplotlib.pyplot -- confirm against callers)
    data -- the points to cluster.  When omitted, the module-level `inputs`
            is used, which this file only defines when it is run as a
            script; pass `data` explicitly when calling from an import.
    """
    points = inputs if data is None else data

    # a concrete list so the same ks can be reused for both plot and ticks
    ks = list(range(1, len(points) + 1))
    errors = [squared_clustering_errors(points, k) for k in ks]

    plt.plot(ks, errors)
    plt.xticks(ks)
    plt.xlabel("k")
    plt.ylabel("total squared error")
    plt.show()
61+
62+
#
63+
# using clustering to recolor an image
64+
#
65+
66+
def recolor_image(input_file, k):
    """Display the image in input_file redrawn with only k colors.

    The image is read with mpimg.imread, its pixels are clustered with
    KMeans(k), every pixel is replaced by the mean of its cluster, and the
    result is shown with plt.imshow.

    input_file -- path passed straight to mpimg.imread
    k          -- number of colors (clusters) to reduce the image to
    """
    # use the caller's arguments: the path and the cluster count were
    # previously ignored in favour of an undefined name and a fixed 5
    img = mpimg.imread(input_file)
    pixels = [pixel for row in img for pixel in row]
    clusterer = KMeans(k)
    clusterer.train(pixels)  # this might take a while

    def recolor(pixel):
        cluster = clusterer.classify(pixel)  # index of the closest cluster
        return clusterer.means[cluster]      # mean of the closest cluster

    new_img = [[recolor(pixel) for pixel in row]
               for row in img]

    plt.imshow(new_img)
    plt.axis('off')
    plt.show()
83+
84+
#
85+
# hierarchical clustering
86+
#
87+
88+
def cluster_distance(cluster1, cluster2, distance_agg=min):
    """Aggregate the pairwise distances between two clusters.

    Every member of cluster1 is paired with every member of cluster2, the
    distance of each pair is computed, and the lazy stream of distances is
    handed to distance_agg (min by default), whose result is returned.
    """
    def pairwise_distances():
        for left in cluster1.members():
            for right in cluster2.members():
                yield distance(left, right)

    return distance_agg(pairwise_distances())
94+
95+
class LeafCluster:
    """A cluster holding exactly one input.

    Its depth is infinite so that it is never chosen as the next cluster
    to split."""

    def __init__(self, value):
        self.depth = float('inf')
        self.value = value

    def members(self):
        """Return the single stored input, wrapped in a list."""
        return [self.value]

    def __repr__(self):
        return str(self.value)
109+
110+
class MergedCluster:
    """A cluster produced by merging other clusters (its branches)."""

    def __init__(self, branches, depth):
        self.branches = branches
        self.depth = depth

    def __repr__(self):
        """Render as {(depth) child1, child2 }."""
        children = ", ".join(str(branch) for branch in self.branches)
        return "{{({0}) {1} }}".format(str(self.depth), children)

    def members(self):
        """Collect the members of every branch, in branch order."""
        collected = []
        for branch in self.branches:
            collected.extend(branch.members())
        return collected
127+
128+
129+
class BottomUpClusterer:
    """Bottom-up hierarchical clustering.

    train() repeatedly merges the two closest clusters until one remains;
    get_clusters() then unmerges from the top to recover a requested number
    of clusters.
    """

    def __init__(self, distance_agg=min):
        self.agg = distance_agg   # aggregate passed to cluster_distance
        self.clusters = None      # list of top-level clusters, set by train()

    def train(self, inputs):
        """Build the merge hierarchy for the inputs.

        Afterwards self.clusters holds a single root cluster (or is empty
        when inputs is empty).
        """
        # start with each input its own cluster
        self.clusters = [LeafCluster(input) for input in inputs]

        while len(self.clusters) > 1:

            # find the two closest clusters; the key takes the pair as one
            # argument because tuple-unpacking lambda parameters are not
            # valid syntax on Python 3
            c1, c2 = min([(cluster1, cluster2)
                          for cluster1 in self.clusters
                          for cluster2 in self.clusters
                          if cluster1 is not cluster2],
                         key=lambda pair: cluster_distance(pair[0], pair[1],
                                                           self.agg))

            # the depth records how many clusters existed at merge time, so
            # later merges get smaller depths
            merged_cluster = MergedCluster([c1, c2], len(self.clusters))

            self.clusters = [c for c in self.clusters
                             if c is not c1 and c is not c2]

            self.clusters.append(merged_cluster)

    def get_clusters(self, num_clusters):
        """Extract num_clusters clusters from the hierarchy.

        Returns a new list; self.clusters is left untouched.  If the
        hierarchy cannot be split into that many clusters (only leaves
        remain, or it is empty), the clusters found so far are returned.
        """
        clusters = self.clusters[:]  # create a copy so we can modify it
        while clusters and len(clusters) < num_clusters:
            # choose the least deep cluster
            next_cluster = min(clusters, key=lambda c: c.depth)
            branches = getattr(next_cluster, "branches", None)
            if branches is None:
                # only unsplittable leaves are left
                break
            # remove it from the list
            clusters = [c for c in clusters if c is not next_cluster]
            # and add its children
            clusters.extend(branches)

        return clusters
169+
170+
171+
172+
173+
174+
if __name__ == "__main__":

    # Demo script.  Every print below is a single-argument call, which
    # produces the same output under the Python 2 print statement and the
    # Python 3 print function.  `inputs` stays a module-level name because
    # plot_squared_clustering_errors reads it.
    inputs = [[-14, -5], [13, 13], [20, 23], [-19, -11], [-9, -16],
              [21, 27], [-49, 15], [26, 13], [-46, 5], [-34, -1],
              [11, 15], [-49, 0], [-22, -16], [19, 28], [-12, -8],
              [-13, -19], [-41, 8], [-11, -6], [-25, -9], [-18, -3]]

    random.seed(0)  # so you get the same results as me
    clusterer = KMeans(3)
    clusterer.train(inputs)
    print("3-means:")
    print(clusterer.means)
    print("")

    random.seed(0)
    clusterer = KMeans(2)
    clusterer.train(inputs)
    print("2-means:")
    print(clusterer.means)
    print("")

    print("errors as a function of k")

    for k in range(1, len(inputs) + 1):
        print("%s %s" % (k, squared_clustering_errors(inputs, k)))
    print("")

    print("bottom up hierarchical clustering")

    buc = BottomUpClusterer()  # or BottomUpClusterer(max) if you like
    buc.train(inputs)
    print(buc.clusters[0])

    print("")
    print("three clusters:")
    for cluster in buc.get_clusters(3):
        print(cluster)

0 commit comments

Comments
 (0)