|
| 1 | +# coding:UTF-8 |
| 2 | +''' |
| 3 | +Date:20160923 |
| 4 | +@author: zhaozhiyong |
| 5 | +''' |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from random import random |
| 9 | +from KMeans import load_data, kmeans, distance, save_result |
| 10 | + |
| 11 | +FLOAT_MAX = 1e100 # 设置一个较大的值作为初始化的最小的距离 |
| 12 | + |
| 13 | +def nearest(point, cluster_centers): |
| 14 | + '''计算point和cluster_centers之间的最小距离 |
| 15 | + input: point(mat):当前的样本点 |
| 16 | + cluster_centers(mat):当前已经初始化的聚类中心 |
| 17 | + output: min_dist(float):点point和当前的聚类中心之间的最短距离 |
| 18 | + ''' |
| 19 | + min_dist = FLOAT_MAX |
| 20 | + m = np.shape(cluster_centers)[0] # 当前已经初始化的聚类中心的个数 |
| 21 | + for i in xrange(m): |
| 22 | + # 计算point与每个聚类中心之间的距离 |
| 23 | + d = distance(point, cluster_centers[i, ]) |
| 24 | + # 选择最短距离 |
| 25 | + if min_dist > d: |
| 26 | + min_dist = d |
| 27 | + return min_dist |
| 28 | + |
| 29 | +def get_centroids(points, k): |
| 30 | + '''KMeans++的初始化聚类中心的方法 |
| 31 | + input: points(mat):样本 |
| 32 | + k(int):聚类中心的个数 |
| 33 | + output: cluster_centers(mat):初始化后的聚类中心 |
| 34 | + ''' |
| 35 | + m, n = np.shape(points) |
| 36 | + cluster_centers = np.mat(np.zeros((k , n))) |
| 37 | + # 1、随机选择一个样本点为第一个聚类中心 |
| 38 | + index = np.random.randint(0, m) |
| 39 | + cluster_centers[0, ] = np.copy(points[index, ]) |
| 40 | + # 2、初始化一个距离的序列 |
| 41 | + d = [0.0 for _ in xrange(m)] |
| 42 | + |
| 43 | + for i in xrange(1, k): |
| 44 | + sum_all = 0 |
| 45 | + for j in xrange(m): |
| 46 | + # 3、对每一个样本找到最近的聚类中心点 |
| 47 | + d[j] = nearest(points[j, ], cluster_centers[0:i, ]) |
| 48 | + # 4、将所有的最短距离相加 |
| 49 | + sum_all += d[j] |
| 50 | + # 5、取得sum_all之间的随机值 |
| 51 | + sum_all *= random() |
| 52 | + # 6、获得距离最远的样本点作为聚类中心点 |
| 53 | + for j, di in enumerate(d): |
| 54 | + sum_all -= di |
| 55 | + if sum_all > 0: |
| 56 | + continue |
| 57 | + cluster_centers[i] = np.copy(points[j, ]) |
| 58 | + break |
| 59 | + return cluster_centers |
| 60 | + |
| 61 | +if __name__ == "__main__": |
| 62 | + k = 4#聚类中心的个数 |
| 63 | + file_path = "data.txt" |
| 64 | + # 1、导入数据 |
| 65 | + print "---------- 1.load data ------------" |
| 66 | + data = load_data(file_path) |
| 67 | + # 2、KMeans++的聚类中心初始化方法 |
| 68 | + print "---------- 2.K-Means++ generate centers ------------" |
| 69 | + centroids = get_centroids(data, k) |
| 70 | + # 3、聚类计算 |
| 71 | + print "---------- 3.kmeans ------------" |
| 72 | + subCenter = kmeans(data, k, centroids) |
| 73 | + # 4、保存所属的类别文件 |
| 74 | + print "---------- 4.save subCenter ------------" |
| 75 | + save_result("sub_pp", subCenter) |
| 76 | + # 5、保存聚类中心 |
| 77 | + print "---------- 5.save centroids ------------" |
| 78 | + save_result("center_pp", centroids) |
0 commit comments