代码示例
# -*- coding: utf-8 -*-
"""
Colour quantisation of an image with K-means (vector quantisation).

Every pixel is treated as an RGB point in [0, 1]^3.  K-means picks
``num_vq`` representative colours, and each pixel is replaced by the
centre of the cluster it is assigned to.

Created on Fri Sep 21 15:37:26 2018

@author: zhen
"""
from PIL import Image
import numpy as np
from sklearn.cluster import KMeans
import matplotlib
import matplotlib.pyplot as plt


def restore_image(cb, cluster, shape):
    """Rebuild an image from a codebook and per-pixel cluster labels.

    Parameters
    ----------
    cb : array-like, shape (k, channels)
        Codebook holding one colour per cluster, e.g. ``KMeans.cluster_centers_``.
    cluster : array-like of int, length >= rows * cols
        Cluster label of every pixel, in row-major order.
    shape : tuple of (rows, cols, channels)
        Shape of the image to produce.

    Returns
    -------
    numpy.ndarray
        float64 array of the given ``shape``.  Pixel ``(r, c)`` equals
        ``cb[cluster[r * cols + c]]``.
    """
    row, col, dummy = shape
    # Only the first rows*cols labels are used, as in the original per-pixel loop.
    labels = np.asarray(cluster)[:row * col]
    # Fancy indexing looks up all pixels at once instead of a Python double loop.
    return np.asarray(cb, dtype=np.float64)[labels].reshape(row, col, dummy)


def show_scatter(a):
    """Plot the colour distribution of a set of RGB pixels.

    Figure 1 is a 3-D scatter of a 10x10x10 RGB histogram.  Marker size is
    proportional to bin frequency, normalised to the most populated bin.
    Figure 2 plots the non-empty bin frequencies sorted in descending
    order.  The function blocks in ``plt.show()`` until the windows are
    closed.

    Parameters
    ----------
    a : array-like, shape (n_pixels, 3)
        RGB values scaled to [0, 1].
    """
    # Importing Axes3D registers the '3d' projection on matplotlib < 3.2.
    # It is a harmless no-op on newer versions.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    N = 10
    density, _ = np.histogramdd(a, bins=[N, N, N], range=[(0, 1), (0, 1), (0, 1)])
    density /= density.max()
    x = y = z = np.arange(N)
    # indexing='ij' makes d[0][i, j, k] == i, matching density[i, j, k]
    # (i = red bin, j = green bin, k = blue bin).  The default 'xy'
    # indexing swaps the first two axes, which would plot red counts on
    # the green axis and vice versa.
    d = np.meshgrid(x, y, z, indexing='ij')
    fig = plt.figure(1, facecolor='w')
    ax = fig.add_subplot(111, projection='3d')
    # No `cmap` here: without a `c` array matplotlib ignores it (and warns).
    ax.scatter(d[0], d[1], d[2], s=100 * density, marker='o', depthshade=True)
    ax.set_xlabel(u'红')
    ax.set_ylabel(u'绿')
    ax.set_zlabel(u'蓝')
    plt.title(u'图像颜色三维频数分布', fontsize=20)

    plt.figure(2, facecolor='w')
    den = density[density > 0]
    den = np.sort(den)[::-1]
    t = np.arange(len(den))
    plt.plot(t, den, 'r-', t, den, 'go', lw=2)
    plt.title(u'图像颜色频数分布', fontsize=18)
    plt.grid(True)
    plt.show()


def main(image_path='test2.png', num_vq=30, sample_ratio=0.7):
    """Quantise the colours of ``image_path`` with K-means and display the result.

    Parameters
    ----------
    image_path : str
        Image file to read.  Any PIL-readable mode is converted to 8-bit RGB.
    num_vq : int
        Number of K-means clusters, i.e. the number of colours kept.  The
        original experiment tried 2, 6 and 30.
    sample_ratio : float
        Fraction of the pixel count drawn at random (with replacement)
        to fit the cluster centres.
    """
    # SimHei provides the Chinese glyphs used in the plot titles and labels.
    matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
    matplotlib.rcParams['axes.unicode_minus'] = False

    # convert('RGB') handles grayscale, palette, RGBA and CMYK inputs alike.
    # It also guarantees 8-bit channels, so dividing by 255 maps to [0, 1].
    # The old `[:, :, :3]` slice failed on 2-D grayscale/palette arrays.
    with Image.open(image_path) as im:
        image = np.array(im.convert('RGB'), dtype=np.float64) / 255
    image_v = image.reshape((-1, 3))
    kmeans = KMeans(n_clusters=num_vq, init='k-means++')
    show_scatter(image_v)

    N = image_v.shape[0]  # total number of pixels
    # Fit the cluster centres on a random subset of pixels to save time,
    # then assign every pixel to its nearest centre.
    idx = np.random.randint(0, N, size=int(N * sample_ratio))
    image_sample = image_v[idx]
    kmeans.fit(image_sample)
    result = kmeans.predict(image_v)  # cluster label of every pixel
    # Format the values into the message.  Passing them as a second print
    # argument printed a literal "%s".
    print("result:%s" % (result,))
    print("result central:%s" % (kmeans.cluster_centers_,))

    plt.figure(figsize=(15, 8), facecolor='w')
    plt.subplot(211)
    plt.axis('off')
    plt.title(u'原始图片', fontsize=18)
    plt.imshow(image)
    # plt.savefig('原始图片.png')

    plt.subplot(212)
    vq_image = restore_image(kmeans.cluster_centers_, result, image.shape)
    plt.axis('off')
    plt.title(u'聚类个数:%d' % num_vq, fontsize=20)
    plt.imshow(vq_image)
    # plt.savefig('矢量化图片.png')
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    main()
总结：当聚类个数 num_vq（即 K-means 中的 k 值）较少时，算法运算速度快，但颜色量化效果较差；当聚类个数较多时，量化效果更好，但运算速度变慢，且容易过拟合。因此，k 值的选择对聚类效果的影响极其明显，需要根据实际需求权衡。