K-means

原理入门视频:https://www.bilibili.com/video/av14601364/

实现基本功能：从 txt 文件中读取数据，根据给定的 K 值进行聚类。

Java代码:

package kmeans;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;

/**
 * Minimal K-means clustering of numeric points loaded from a whitespace-separated text file.
 *
 * <p>Every non-blank line of the data file is one point, and all points must have the same number
 * of coordinates. Centers are seeded with {@code k} data points drawn at random without
 * repetition. They are then refined by alternating an assignment step ({@code clusterSet}) and a
 * mean-update step ({@code setNewCenter}). This stops once the sum of squared errors (SSE) changes
 * by less than {@code EPSILON} between two consecutive iterations.
 */
public class Kmeans {

    /** Data file read by {@link #initClist()}, resolved against the working directory. */
    private static final String DEFAULT_DATA_FILE = "simple_k-means.txt";
    /** Convergence threshold on the absolute change of the SSE between two iterations. */
    private static final double EPSILON = 0.001;
    /** Hard cap on center updates, so the loop terminates even if the SSE keeps fluctuating. */
    private static final int MAX_ITERATIONS = 1000;

    private int k; // number of clusters
    private int m; // number of center updates performed (iteration count)
    private int len; // number of loaded data points
    private final ArrayList<double[]> center; // current center of each cluster
    private final ArrayList<double[]> clist; // all loaded data points
    private final ArrayList<ArrayList<double[]>> cluster; // points assigned to each cluster
    private final ArrayList<Double> jc; // SSE recorded after each assignment step

    /**
     * Creates a clusterer for {@code k} clusters.
     *
     * @param k requested number of clusters; values {@code <= 0} are replaced by 1
     */
    public Kmeans(int k) {
        this.k = normalizeK(k);
        m = 0;
        center = new ArrayList<double[]>();
        clist = new ArrayList<double[]>();
        cluster = new ArrayList<ArrayList<double[]>>();
        jc = new ArrayList<Double>();
    }

    /** Applies the lower bound shared by the constructor and the setter: {@code <= 0} becomes 1. */
    private static int normalizeK(int k) {
        return k <= 0 ? 1 : k;
    }

    /** Returns the number of clusters (lowered to the number of points if it exceeded it). */
    public int getK() {
        return k;
    }

    /**
     * Sets the number of clusters.
     *
     * @param k requested number of clusters; values {@code <= 0} are replaced by 1
     */
    public void setK(int k) {
        this.k = normalizeK(k);
    }

    /** Returns how many center updates the last run performed. */
    public int getM() {
        return m;
    }

    /** Returns the live list of cluster centers (not a copy). */
    public ArrayList<double[]> getCenter() {
        return center;
    }

    /** Returns the live list of loaded data points (not a copy). */
    public ArrayList<double[]> getClist() {
        return clist;
    }

    /** Returns the live per-cluster point lists (not a copy); list i belongs to center i. */
    public ArrayList<ArrayList<double[]>> getCluster() {
        return cluster;
    }

    /** Loads the data points from {@code simple_k-means.txt}; see {@link #initClist(String)}. */
    public void initClist() {
        initClist(DEFAULT_DATA_FILE);
    }

    /**
     * Replaces the loaded data points with those read from {@code path}.
     *
     * <p>Each non-blank line holds the coordinates of one point. Coordinates are separated by any
     * run of whitespace (spaces or tabs), and leading/trailing whitespace is ignored. If the file
     * cannot be opened or read, the error is printed to standard error, and the points read so far
     * are kept. When at least one point was loaded and {@code k} exceeds the number of points,
     * {@code k} is lowered to that number.
     *
     * @param path data file to read
     * @throws NumberFormatException if a coordinate is not a valid number
     * @throws IllegalArgumentException if a line's coordinate count differs from the first point's
     */
    public void initClist(String path) {
        clist.clear();
        try (BufferedReader br = new BufferedReader(new FileReader(path))) {
            String line;
            int lineNo = 0;
            while ((line = br.readLine()) != null) {
                lineNo++;
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // tolerate blank lines such as a trailing newline at EOF
                }
                // "\\s+" is the regex "one or more whitespace characters". A bare "\s+" literal
                // is a single-space pattern (Java 15+) or a compile error (older Java).
                String[] tokens = trimmed.split("\\s+");
                double[] point = new double[tokens.length];
                for (int i = 0; i < tokens.length; i++) {
                    point[i] = Double.parseDouble(tokens[i]);
                }
                if (!clist.isEmpty() && clist.get(0).length != point.length) {
                    throw new IllegalArgumentException("line " + lineNo + " of " + path + " has "
                            + point.length + " coordinates, expected " + clist.get(0).length);
                }
                clist.add(point);
            }
        } catch (FileNotFoundException e) {
            System.err.println(e.toString());
        } catch (IOException e) {
            System.err.println(e.toString());
        }
        len = clist.size();
        if (len > 0 && k > len) {
            k = len;
        }
    }

    /**
     * Seeds the centers with {@code k} data points chosen at random, each index at most once.
     *
     * <p>The indices are drawn with a partial Fisher-Yates shuffle. Repeating an index would create
     * two identical centers; ties go to the first center, so the later duplicate would start with
     * an empty cluster. Each center is a copy, so updating centers never aliases the input data.
     *
     * @throws IllegalStateException if no data points have been loaded
     */
    public void initCenter() {
        if (len == 0) {
            throw new IllegalStateException("no data points loaded; call initClist() first");
        }
        if (k > len) {
            k = len;
        }
        center.clear();
        Random random = new Random();
        int[] order = new int[len];
        for (int i = 0; i < len; i++) {
            order[i] = i;
        }
        for (int i = 0; i < k; i++) {
            int j = i + random.nextInt(len - i);
            int tmp = order[i];
            order[i] = order[j];
            order[j] = tmp;
            center.add(clist.get(order[i]).clone());
        }
    }

    /** Resets {@code cluster} to {@code k} empty lists, one per center. */
    public void initCluster() {
        cluster.clear();
        for (int i = 0; i < k; i++) {
            cluster.add(new ArrayList<double[]>());
        }
    }

    /**
     * Returns the squared Euclidean distance between two points of the same dimension.
     *
     * @param element a data point
     * @param centerPoint a cluster center with at least {@code element.length} coordinates
     * @return sum over all coordinates of the squared differences
     */
    private double getSumSquare(double[] element, double[] centerPoint) {
        double sum = 0;
        for (int d = 0; d < element.length; d++) {
            double diff = element[d] - centerPoint[d];
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Returns the index of the smallest entry; the first such index wins on ties.
     *
     * <p>Starting from entry 0 instead of a fixed sentinel keeps this correct for arbitrarily
     * large distances.
     *
     * @param distance non-empty array of distances to each center
     * @return index of the nearest center
     */
    private int minDistance(double[] distance) {
        int best = 0;
        for (int i = 1; i < distance.length; i++) {
            if (distance[i] < distance[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Assigns every data point to the cluster of its nearest center. */
    private void clusterSet() {
        double[] distance = new double[k];
        for (int i = 0; i < len; i++) {
            double[] point = clist.get(i);
            for (int j = 0; j < k; j++) {
                // Squared distances have the same nearest center as Euclidean ones; no sqrt needed.
                distance[j] = getSumSquare(point, center.get(j));
            }
            cluster.get(minDistance(distance)).add(point);
        }
    }

    /** Computes the SSE of the current assignment against the current centers and records it. */
    private void countRule() {
        double sse = 0;
        for (int i = 0; i < cluster.size(); i++) {
            double[] c = center.get(i);
            for (double[] point : cluster.get(i)) {
                sse += getSumSquare(point, c);
            }
        }
        jc.add(sse);
    }

    /** Moves each non-empty cluster's center to the mean of its points; empty clusters keep theirs. */
    private void setNewCenter() {
        for (int i = 0; i < k; i++) {
            ArrayList<double[]> members = cluster.get(i);
            int n = members.size();
            if (n == 0) {
                continue;
            }
            double[] mean = new double[members.get(0).length];
            for (double[] point : members) {
                for (int d = 0; d < mean.length; d++) {
                    mean[d] += point[d];
                }
            }
            for (int d = 0; d < mean.length; d++) {
                mean[d] = mean[d] / n;
            }
            center.set(i, mean);
        }
    }

    /**
     * Runs the full algorithm: load data, seed centers, then iterate until convergence.
     *
     * <p>On return, {@code cluster} holds the assignment produced by the final {@code center}s.
     */
    private void kmeans() {
        m = 0;
        jc.clear();
        initClist();
        initCenter();
        initCluster();
        while (true) {
            clusterSet();
            countRule();
            // Compare the magnitude of the SSE change. A signed "new - old < EPSILON" test is
            // already true whenever the SSE decreases, which would stop after a single update.
            if (m != 0
                    && (Math.abs(jc.get(m - 1) - jc.get(m)) < EPSILON || m >= MAX_ITERATIONS)) {
                break;
            }
            setNewCenter();
            m++;
            initCluster();
        }
    }

    /** Formats a point as {@code [x,y,...]}, one {@code Double.toString} value per coordinate. */
    private static String formatPoint(double[] point) {
        StringBuilder sb = new StringBuilder("[");
        for (int d = 0; d < point.length; d++) {
            if (d > 0) {
                sb.append(',');
            }
            sb.append(point[d]);
        }
        return sb.append(']').toString();
    }

    public static void main(String[] args) {
        Kmeans km = new Kmeans(2); // number of clusters
        km.kmeans();
        int count = km.getM();
        ArrayList<double[]> center = km.getCenter();
        ArrayList<ArrayList<double[]>> cluster = km.getCluster();
        System.out.println("迭代次数: " + count);
        System.out.println("----------质心:------------------");
        for (double[] c : center) {
            System.out.println(formatPoint(c));
        }
        System.out.println("----------聚类结果:--------------");
        for (ArrayList<double[]> members : cluster) {
            StringBuilder line = new StringBuilder();
            for (double[] point : members) {
                line.append(formatPoint(point)).append(' ');
            }
            System.out.println(line);
        }
    }
}
java-Kmeans

传统 K-means 需要事先给定 K 值。初始化中心点有两种常见方法：一种是从现有的数据点中随机选取 K 个彼此尽量远离的点；另一种是根据实际问题，人工指定 K 个初始点。

K-means 可能出现以下问题：过早收敛，陷入局部最优；某个中心点可能分配不到任何点，从而形成空簇。

为了解决这些问题，人们提出了多种改进的 K-means 算法（例如改进初始中心选取的 K-means++），有需要的读者可以进一步了解。

原文地址:https://www.cnblogs.com/flyuz/p/9041240.html