K-Means 算法(Java)

K-Means 算法的原理见我的上一篇文章。本文介绍 K-Means 的 Java 实现,实现思路参考了 Python 版本。

一、数据点的实现

package com.meachine.learning.kmeans;

import java.util.ArrayList;

/**
 * A data point holding n numeric coordinates.
 *
 * <p>Each instance takes the next value of a class-wide counter as its id.
 * A freshly created point belongs to no cluster ({@code clusterId == -1})
 * and its best known distance to a cluster center is positive infinity.
 */
public class Point {
    // Class-wide counter used to hand out ids; it is not synchronized.
    private static int num;
    private int id;
    private int dimensioNum; // number of coordinates currently stored
    private ArrayList<Double> values;
    private int clusterId = -1;
    // Must compare greater than every real distance. The previous seed,
    // Integer.MAX_VALUE, is only about 2.1e9, so a larger first distance
    // could never be accepted as the minimum.
    private double minDist = Double.POSITIVE_INFINITY;

    /** Creates an empty point and assigns it the next sequential id. */
    public Point() {
        id = ++num;
        values = new ArrayList<>();
    }

    /**
     * Appends one coordinate to this point and grows its dimension by one.
     *
     * @param e the coordinate value to append
     */
    public void add(double e) {
        values.add(e);
        dimensioNum++;
    }

    /** @return the sequential id assigned at construction time */
    public int getId() {
        return id;
    }

    /** @return the number of coordinates held by this point */
    public int getDimensioNum() {
        return dimensioNum;
    }

    /** @return the live coordinate list (not a copy) */
    public ArrayList<Double> getValues() {
        return values;
    }

    /**
     * Replaces the coordinate list and keeps the dimension count in sync
     * with it.
     *
     * @param values the new coordinates; {@code null} yields dimension 0
     */
    public void setValues(ArrayList<Double> values) {
        this.values = values;
        this.dimensioNum = (values == null) ? 0 : values.size();
    }

    /** @return the id of the owning cluster, or -1 when unassigned */
    public int getClusterId() {
        return clusterId;
    }

    /** @param clusterId the id of the cluster this point now belongs to */
    public void setClusterId(int clusterId) {
        this.clusterId = clusterId;
    }

    /** @return the smallest distance recorded so far; infinity initially */
    public double getMinDist() {
        return minDist;
    }

    /** @param minDist the distance to the nearest cluster center */
    public void setMinDist(double minDist) {
        this.minDist = minDist;
    }
}

二、数据簇的实现

package com.meachine.learning.kmeans;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;

/**
 * A cluster.<br>
 * Holds the basic bookkeeping for one group of points: its id, a member
 * count and the point that currently acts as its center.
 */
public class Cluster {
    // Identifier of this cluster.
    private int clusterId;
    // Number of points that belong to this cluster.
    private int numOfPoints;
    // The point describing the cluster's center.
    private Point center;

    /**
     * Creates a cluster that has an id but no center yet.
     *
     * @param id the cluster id
     */
    public Cluster(int id) {
        this(id, null);
    }

    /**
     * Creates a cluster with the given id and center; the member count
     * starts at zero.
     *
     * @param id     the cluster id
     * @param center the initial center point
     */
    public Cluster(int id, Point center) {
        this.center = center;
        this.clusterId = id;
    }
    // ---------- setters and getters omitted ----------
}

三、计算数据点距离

package com.meachine.learning.kmeans;

import java.util.List;

/**
 * Contract for measuring the distance between two coordinate lists.
 *
 * @param <T> element type of the coordinate lists
 */
public interface IDistance<T> {
    /**
     * Computes the distance between two coordinate lists.
     *
     * @param p1 coordinates of the first point
     * @param p2 coordinates of the second point
     * @return the distance between {@code p1} and {@code p2}
     */
    double getDis(List<T> p1, List<T> p2);
}

  

package com.meachine.learning.kmeans;

import java.util.List;

/**
 * Euclidean distance between two lists of numbers.
 *
 * @param <T> numeric element type; every element is read through
 *            {@link Number#doubleValue()}
 */
public class OujilidDistance<T extends Number> implements IDistance<T> {

    /**
     * Returns the square root of the summed squared per-coordinate
     * differences of {@code a} and {@code b}.
     *
     * @param a coordinates of the first point
     * @param b coordinates of the second point
     * @return the Euclidean distance; 0 when both lists are empty
     * @throws IllegalArgumentException if the lists differ in size
     */
    public double getDis(List<T> a, List<T> b) {
        final int dims = a.size();
        if (dims != b.size()) {
            throw new IllegalArgumentException("Size not compatible!");
        }

        double sumOfSquares = 0;
        int axis = 0;
        while (axis < dims) {
            double delta = a.get(axis).doubleValue() - b.get(axis).doubleValue();
            sumOfSquares += Math.pow(delta, 2);
            axis++;
        }
        return Math.sqrt(sumOfSquares);
    }

}

四、K-Means算法

  

package com.meachine.learning.kmeans;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * K-Means clustering over points read from a text file.
 *
 * <p>Each non-blank line of the file is one point whose coordinates are
 * separated by whitespace. Call {@link #init()} to load the data and pick
 * the initial centers, then {@link #clustering()} to run the iterations.
 *
 * @author Cang
 *
 */
public class KMeans {
    // Data file used when the caller does not supply one.
    private static final String DEFAULT_DATA_FILE = "D:/testSet.txt";

    // Number of clusters.
    private int k;
    // Dimensionality, i.e. how many coordinates each point has.
    private int dimensioNum;
    // Upper bound on the number of assign/update passes.
    private int maxItrNum = 100;
    private IDistance<Double> distance;
    private List<Point> points;
    private List<Cluster> clusters = new ArrayList<Cluster>();
    private String dataFileName = DEFAULT_DATA_FILE;

    /**
     * Creates a clusterer that reads the default data file.
     *
     * @param k number of clusters
     */
    public KMeans(int k) {
        this.k = k;
    }

    /**
     * Creates a clusterer that reads the given data file.
     *
     * @param k            number of clusters
     * @param dataFileName path of the whitespace-separated data file
     */
    public KMeans(int k, String dataFileName) {
        this.k = k;
        this.dataFileName = dataFileName;
    }

    /**
     * Loads the data set, installs the Euclidean distance and chooses the
     * initial cluster centers.
     *
     * @throws IllegalStateException if the file cannot be read or contains
     *                               no data points
     */
    public void init() {
        points = loadDataSet(dataFileName);
        if (points.isEmpty()) {
            throw new IllegalStateException("No data points loaded from " + dataFileName);
        }
        distance = new OujilidDistance<Double>();
        initCluster();
    }

    /**
     * Reads one point per non-blank line of {@code fileName}.
     *
     * @param fileName path of the data file
     * @return the points in file order
     * @throws IllegalStateException    if reading the file fails
     * @throws IllegalArgumentException if a line has a different number of
     *                                  coordinates than the first data line
     * @throws NumberFormatException    if a field is not a valid double
     */
    private List<Point> loadDataSet(String fileName) {
        List<Point> loaded = new ArrayList<>();
        // try-with-resources closes the reader on the error path as well.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String line;
            int lineNo = 0;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // tolerate blank lines, e.g. a trailing newline
                }
                String[] fields = trimmed.split("\\s+");
                if (loaded.isEmpty()) {
                    dimensioNum = fields.length;
                } else if (fields.length != dimensioNum) {
                    throw new IllegalArgumentException("Line " + lineNo + " of " + fileName
                            + " has " + fields.length + " values, expected " + dimensioNum);
                }
                Point point = new Point();
                for (String field : fields) {
                    point.add(Double.parseDouble(field));
                }
                loaded.add(point);
            }
        } catch (IOException e) {
            // Fail here with the cause attached instead of printing the
            // stack trace and continuing with missing data.
            throw new IllegalStateException("Failed to read data set " + fileName, e);
        }
        return loaded;
    }

    /**
     * Picks k data points at random as the initial cluster centers.
     *
     * <p>Centers are drawn without replacement (partial Fisher-Yates
     * shuffle of the point indices) so that two clusters do not start on
     * the same data point. Only when k exceeds the number of points are the
     * surplus centers drawn with replacement.
     */
    private void initCluster() {
        clusters.clear(); // a repeated init() must not accumulate clusters
        Random ran = new Random();
        int n = points.size();
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        for (int slot = 0; slot < k; slot++) {
            int chosen;
            if (slot < n) {
                int swapWith = slot + ran.nextInt(n - slot);
                chosen = order[swapWith];
                order[swapWith] = order[slot];
                order[slot] = chosen;
            } else {
                chosen = ran.nextInt(n);
            }
            Cluster c = new Cluster(slot + 1); // cluster ids start at 1
            c.setCenter(points.get(chosen));
            clusters.add(c);
        }
    }

    /**
     * Runs K-Means: repeatedly assigns every point to its nearest center
     * and recomputes the centers, until no point changes cluster or
     * {@code maxItrNum} passes have been made. The centers in effect during
     * each pass are printed to standard output.
     */
    public void clustering() {
        boolean finished = false;
        int count = 0;
        while (!finished) {
            // Converged when an assignment pass moves no point.
            finished = !assignPoints();
            System.out.println("Cluster center info:");
            for (Cluster cluster : clusters) {
                System.out.println(cluster.getCenter().getValues());
            }
            // Move every center to the mean of its members.
            changeCentroids();
            // Stop after exactly maxItrNum passes.
            if (++count >= maxItrNum) {
                finished = true;
            }
        }
    }

    /**
     * Assigns each point to the cluster whose center is nearest.
     *
     * <p>The nearest center is searched from scratch for every point on
     * every pass; a distance remembered from an earlier pass refers to a
     * center that has since moved and must not be compared against.
     *
     * @return {@code true} if at least one point changed cluster
     */
    private boolean assignPoints() {
        boolean changed = false;
        for (Point point : points) {
            double best = Double.POSITIVE_INFINITY;
            int bestId = point.getClusterId();
            for (Cluster cluster : clusters) {
                double len = distance.getDis(cluster.getCenter().getValues(),
                        point.getValues());
                if (len < best) {
                    best = len;
                    bestId = cluster.getClusterId();
                }
            }
            if (bestId != point.getClusterId()) {
                point.setClusterId(bestId);
                changed = true;
            }
            point.setMinDist(best);
        }
        return changed;
    }

    /**
     * Replaces each cluster's center with the per-coordinate mean of the
     * points assigned to it. A cluster with no members keeps its previous
     * center, which avoids a division by zero.
     */
    private void changeCentroids() {
        for (Cluster cluster : clusters) {
            // One accumulator per coordinate, so sums cannot leak between
            // dimensions.
            double[] sums = new double[dimensioNum];
            int members = 0;
            for (Point point : points) {
                if (point.getClusterId() != cluster.getClusterId()) {
                    continue;
                }
                members++;
                for (int i = 0; i < dimensioNum; i++) {
                    sums[i] += point.getValues().get(i);
                }
            }
            if (members == 0) {
                continue;
            }
            ArrayList<Double> newCenterValue = new ArrayList<Double>(dimensioNum);
            for (double sum : sums) {
                // Average over the cluster's own members, not all points.
                newCenterValue.add(sum / members);
            }
            Point newCenterPoint = new Point();
            newCenterPoint.setClusterId(cluster.getClusterId());
            newCenterPoint.setValues(newCenterValue);
            cluster.setCenter(newCenterPoint);
        }
    }

    /**
     * Clusters the data set into four groups.
     *
     * @param args optional: {@code args[0]} is the data file path; the
     *             default file is used when it is absent
     */
    public static void main(String[] args) {
        KMeans kmeans = args.length > 0 ? new KMeans(4, args[0]) : new KMeans(4);
        kmeans.init();
        kmeans.clustering();
    }
}

  

原文地址:https://www.cnblogs.com/codingexperience/p/5040942.html