// K-nearest neighbors (KNN) algorithm

/**
 * Command-line driver: loads labelled examples from {@code data.txt} and
 * query rows from {@code test.txt}, classifies every query row with
 * {@link KNN#knn} and prints the row followed by the predicted category.
 */
public class KnnTest
{
    /**
     * Reads a whitespace-separated numeric text file and appends one
     * {@code List<Double>} per non-blank line to {@code list}.
     *
     * <p>Tokens are separated by any run of whitespace, so leading, trailing
     * or repeated spaces and tabs do not produce empty tokens.
     *
     * <p>I/O failures are reported with {@code printStackTrace()} and the rows
     * read so far stay in {@code list}; the reader is always closed.
     *
     * @param path file to read
     * @param list destination that receives the parsed rows
     * @throws NumberFormatException if a token is not a valid double
     */
    public static void readFileToList(String path, List<List<Double>> list)
    {
        BufferedReader br = null;

        try {
            br = new BufferedReader(new FileReader(path));
            String line;
            // readLine() returns null at end of stream, which is the reliable
            // end-of-file test (ready() only says whether a read would block).
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] tokens = trimmed.split("\\s+");
                List<Double> box = new ArrayList<Double>(tokens.length);

                for (String num : tokens) {
                    box.add(Double.parseDouble(num));
                }
                list.add(box);
            }
        }
        catch (IOException ex) {
            ex.printStackTrace();
        }
        finally {
            // Release the file handle even when reading or parsing failed.
            if (br != null) {
                try {
                    br.close();
                }
                catch (IOException ex) {
                    ex.printStackTrace();
                }
            }
        }
    }


    /**
     * Classifies every row of {@code test.txt} against {@code data.txt} using
     * the two nearest neighbours, printing the row values followed by the
     * predicted category rounded to an integer.
     *
     * @param args unused
     */
    public static void main(String[] args)
    {
        int k = 2;  // number of neighbours that vote
        String dataFile = "data.txt";
        String testFile = "test.txt";

        KNN knn = new KNN();

        try {
            List<List<Double>> dataList = new ArrayList<List<Double>>();
            List<List<Double>> testList = new ArrayList<List<Double>>();

            readFileToList(dataFile, dataList);
            readFileToList(testFile, testList);

            for (List<Double> test : testList) {
                for (Double d : test) {
                    System.out.print(d + " ");
                }

                String category = knn.knn(dataList, test, k);
                // The category is the text form of a double (e.g. "1.0");
                // print it as a whole number.
                System.out.println(Math.round(Float.parseFloat(category)));
            }
        }
        catch (Exception ex) {
            // Top-level boundary: report any failure and exit normally.
            ex.printStackTrace();
        }
    }
}


/**
 * K-nearest-neighbour classifier.
 *
 * <p>Each example row holds the feature values followed by the category label
 * as its last element. A query row is classified by majority vote among the
 * {@code k} examples with the smallest squared Euclidean distance to it.
 */
class KNN
{
    /**
     * Orders nodes by descending distance so that the head of a
     * {@link PriorityQueue} is the farthest of the currently kept
     * candidates, i.e. the one to evict when a closer example is found.
     */
    private static final Comparator<Node> comparator = new Comparator<Node>()
    {
        public int compare(Node n1, Node n2)
        {
            return Double.compare(n2.getDistans(), n1.getDistans());
        }
    };

    /**
     * Classifies {@code test} by majority vote among its nearest examples.
     *
     * <p>If {@code k} exceeds the number of examples, all examples vote.
     * A tie in the vote is resolved in favour of the category of the nearest
     * tied neighbour.
     *
     * @param example labelled rows; the last element of each row is its category
     * @param test    query row; its elements are compared position by position
     *                with the leading elements of each example row
     * @param k       number of neighbours that vote, at least 1
     * @return the winning category, as the {@code toString()} of the label value
     * @throws IllegalArgumentException if {@code k < 1} or {@code example} is empty
     */
    public String knn(List<List<Double>> example, List<Double> test, int k)
    {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        if (example.isEmpty()) {
            throw new IllegalArgumentException("example set must not be empty");
        }

        int capacity = Math.min(k, example.size());
        PriorityQueue<Node> pq = new PriorityQueue<Node>(capacity, comparator);

        // Single pass over the examples: fill the heap up to capacity, then
        // replace the farthest kept candidate whenever a closer one appears.
        // Every example is considered exactly once, so none can vote twice.
        for (int i = 0; i < example.size(); i++) {
            List<Double> list = example.get(i);
            double distans = calDistans(test, list);

            if (pq.size() < capacity) {
                pq.add(new Node(i, distans, list.get(list.size() - 1).toString()));
            }
            else if (pq.peek().getDistans() > distans) {
                pq.remove();
                pq.add(new Node(i, distans, list.get(list.size() - 1).toString()));
            }
        }

        return getMostCategory(pq);
    }

    /**
     * Drains {@code pq} and returns the category with the most votes.
     *
     * @param pq neighbours ordered farthest-first; emptied by this call
     * @return the most frequent category; on a tie, the tied category whose
     *         member is nearest; {@code null} if {@code pq} is empty
     */
    private String getMostCategory(PriorityQueue<Node> pq)
    {
        // The queue yields the farthest node first, so fill the array from
        // the back to end up with the nearest node at index 0.
        Node[] nearestFirst = new Node[pq.size()];
        for (int i = nearestFirst.length - 1; i >= 0; i--) {
            nearestFirst[i] = pq.remove();
        }

        Map<String, Integer> votes = new HashMap<String, Integer>();
        for (Node node : nearestFirst) {
            Integer seen = votes.get(node.getCategory());
            votes.put(node.getCategory(), seen == null ? 1 : seen + 1);
        }

        // Scan nearest-first with a strict comparison so that, among equally
        // voted categories, the one seen first (the nearest) wins.
        String winner = null;
        int best = 0;
        for (Node node : nearestFirst) {
            int count = votes.get(node.getCategory());
            if (count > best) {
                best = count;
                winner = node.getCategory();
            }
        }

        return winner;
    }


    /**
     * Returns the squared Euclidean distance between two rows, computed over
     * the first {@code list1.size()} positions (no square root is taken,
     * which does not affect the nearest-neighbour ordering).
     *
     * @param list1 row that determines how many positions are compared
     * @param list2 row with at least {@code list1.size()} elements
     * @return the sum of squared element-wise differences
     * @throws IndexOutOfBoundsException if {@code list2} is shorter than {@code list1}
     */
    public double calDistans(List<Double> list1, List<Double> list2)
    {
        double result = 0.00;

        for (int i = 0; i < list1.size(); i++) {
            double diff = list1.get(i) - list2.get(i);
            result += diff * diff;
        }

        return result;
    }


    /**
     * A candidate neighbour: the index of an example row, its distance to the
     * query row, and the row's category label.
     */
    static class Node
    {
        private int index;
        private double distans;
        private String category;

        /**
         * @param index    position of the example row in the example list
         * @param distans  distance from the query row to this example
         * @param category category label of the example
         */
        public Node(int index, double distans, String category)
        {
            this.index = index;
            this.distans = distans;
            this.category = category;
        }

        /** @return position of the example row in the example list */
        public int getIndex()
        {
            return index;
        }

        /** @param index position of the example row in the example list */
        public void setIndex(int index)
        {
            this.index = index;
        }

        /** @return distance from the query row to this example */
        public double getDistans()
        {
            return distans;
        }

        /** @param distans distance from the query row to this example */
        public void setDistans(double distans)
        {
            this.distans = distans;
        }

        /** @return category label of the example */
        public String getCategory()
        {
            return category;
        }

        /** @param category category label of the example */
        public void setCategory(String category)
        {
            this.category = category;
        }
    }
}
// Original source: https://www.cnblogs.com/rilley/p/2690098.html