Parallel Decision Tree

Decision tree learners such as C4.5 are easy to parallelize. The following is an example.

This is a non-parallel version:

/**
 * Trains the tree on a single thread, expanding nodes in first-in, first-out order.
 *
 * <p>Every sample is first handed to the model. The root chosen by the model is then
 * placed on a work queue, and each node taken from the queue is split unless it has
 * reached the model's maximum depth or the selected split carries no feature id.
 * A child produced by a split is attached to its parent, registered with the model and
 * queued for further splitting only when it is non-null and its sample size exceeds
 * the model's minimum sample size per node.
 *
 * @param dataset training samples; each element is cast to {@code MapBasedBinarySample}
 */
public void learnFromDataSet(Iterable<Sample<FK, FV, Boolean>> dataset){
        for(Sample example : dataset){
            model.addSample((MapBasedBinarySample<FK, FV>)example);
        }

        Queue<TreeNode<FK, FV>> frontier = new LinkedList<TreeNode<FK, FV>>();
        TreeNode<FK, FV> rootNode = model.selectRootTreeNode();
        model.addTreeNode(rootNode);
        frontier.add(rootNode);

        while (!frontier.isEmpty()){
            TreeNode<FK, FV> current = frontier.poll();

            // Depth limit reached: this node stays a leaf.
            if(current.getDepth() >= model.getMaxDepth()){
                continue;
            }

            // No feature selected for this node: it stays a leaf.
            FeatureSplit<FK> split = model.selectFeature(current);
            if(split.getFeatureId() == null){
                continue;
            }
            current.setFeatureSplit(split);

            Pair<TreeNode<FK,FV>, TreeNode<FK, FV>> halves =
                    model.newTreeNode(current, split);
            TreeNode<FK, FV> left = halves.getKey();
            TreeNode<FK, FV> right = halves.getValue();

            // Only children holding more samples than the minimum are kept and expanded.
            if(left != null
                    && left.getSampleSize() > model.getMinSampleSizeInNode()){
                current.setLeft(left);
                model.addTreeNode(left);
                frontier.add(left);
            }
            if(right != null
                    && right.getSampleSize() > model.getMinSampleSizeInNode()){
                current.setRight(right);
                model.addTreeNode(right);
                frontier.add(right);
            }
        }
    }

And this is a parallel version:

/**
 * Runnable that splits one tree node and hands qualifying children back to a shared
 * work queue.
 *
 * <p>The task does nothing when the node has reached the model's maximum depth or when
 * the selected split carries no feature id. Otherwise it records the split on the node,
 * asks the model for the two children, and for each child that is non-null and whose
 * sample size exceeds the model's minimum, attaches it to the node, registers it with
 * the model and adds it to the queue.
 *
 * <p>NOTE(review): {@code run()} calls into the enclosing instance's {@code model};
 * when several of these tasks run at once the model must tolerate concurrent calls -
 * confirm against the model implementation.
 */
public class NodeSplitThread implements Runnable{
        private final TreeNode<FK, FV> node;
        private final Queue<TreeNode<FK, FV>> Q;

        /**
         * @param node the node this task will try to split
         * @param Q    queue that receives children needing further splitting
         */
        public NodeSplitThread(TreeNode<FK, FV> node, Queue<TreeNode<FK, FV>> Q){
            this.node = node;
            this.Q = Q;
        }

        @Override
        public void run() {
            if(node.getDepth() >= model.getMaxDepth()){
                return;
            }
            FeatureSplit<FK> split = model.selectFeature(node);
            if(split.getFeatureId() == null){
                return;
            }
            node.setFeatureSplit(split);

            Pair<TreeNode<FK,FV>, TreeNode<FK, FV>> halves = model.newTreeNode(node, split);
            TreeNode<FK, FV> left = halves.getKey();
            TreeNode<FK, FV> right = halves.getValue();

            if(isLargeEnough(left)){
                node.setLeft(left);
                register(left);
            }
            if(isLargeEnough(right)){
                node.setRight(right);
                register(right);
            }
        }

        /** Returns true when the child exists and holds more samples than the model's minimum. */
        private boolean isLargeEnough(TreeNode<FK, FV> child){
            return child != null && child.getSampleSize() > model.getMinSampleSizeInNode();
        }

        /** Registers the child with the model and queues it for further splitting. */
        private void register(TreeNode<FK, FV> child){
            model.addTreeNode(child);
            Q.add(child);
        }
    }

    /**
     * Removes up to {@code n} nodes from the head of the queue and returns them in the
     * order they were polled.
     *
     * <p>Draining stops early as soon as the queue has nothing left to give. The result
     * never contains {@code null}: emptiness is detected from the return value of
     * {@link Queue#poll()} rather than from a separate {@code isEmpty()} check, so the
     * method stays correct even if another thread drains the queue at the same time.
     *
     * @param Q the queue to drain; must not be {@code null}
     * @param n the maximum number of nodes to remove; values of zero or less yield an
     *          empty list
     * @return a new mutable list holding between 0 and {@code n} polled nodes
     */
    public List<TreeNode<FK, FV>> pollTopN(Queue<TreeNode<FK, FV>> Q, int n){
        List<TreeNode<FK, FV>> ret = new ArrayList<TreeNode<FK, FV>>();
        for(int i = 0; i < n; ++i){
            // poll() returns null when the queue is empty; checking isEmpty() first and
            // polling afterwards is a check-then-act race on a concurrent queue.
            TreeNode<FK, FV> node = Q.poll();
            if(node == null){
                break;
            }
            ret.add(node);
        }
        return ret;
    }

    /**
     * Trains the tree using a fixed pool of worker threads.
     *
     * <p>Every sample is first handed to the model and the root chosen by the model is
     * placed on a shared work queue. The queue is then processed in batches: up to
     * {@code batchSize} nodes are polled, one {@link NodeSplitThread} is submitted per
     * node, and the calling thread waits for the whole batch to finish before polling the
     * next one. Children produced by the workers are added to the same queue, so the loop
     * ends when a batch produces no further nodes.
     *
     * <p>If the calling thread is interrupted while waiting, its interrupt status is
     * restored, outstanding work is cancelled and the method returns with whatever part of
     * the tree has been built so far. The worker pool is always shut down before the
     * method exits.
     *
     * <p>NOTE(review): workers call into the shared {@code model} concurrently; the model
     * must tolerate that - confirm against the model implementation.
     *
     * @param dataset training samples; each element is cast to {@code MapBasedBinarySample}
     * @throws IllegalStateException if splitting a node fails; the worker's exception is
     *         preserved in the cause chain
     */
    @Override
    public void learnFromDataSet(Iterable<Sample<FK, FV, Boolean>> dataset){

        for(Sample sample : dataset){
            model.addSample((MapBasedBinarySample<FK, FV>)sample);
        }
        Queue<TreeNode<FK, FV>> Q = new ConcurrentLinkedQueue<TreeNode<FK, FV>>();
        TreeNode<FK, FV> root = model.selectRootTreeNode();
        model.addTreeNode(root);
        Q.add(root);

        // One constant drives both the pool size and the number of nodes taken per batch.
        final int batchSize = 10;
        ExecutorService threadPool = Executors.newFixedThreadPool(batchSize);
        try {
            while (!Q.isEmpty()){
                List<TreeNode<FK, FV>> nodes = pollTopN(Q, batchSize);
                List<Future<?>> tasks = new ArrayList<Future<?>>(nodes.size());
                for(TreeNode<FK, FV> node : nodes){
                    tasks.add(threadPool.submit(new NodeSplitThread(node, Q)));
                }
                // Wait for the whole batch so that every child it produced is queued
                // before the emptiness check of the next iteration.
                for(Future<?> task : tasks){
                    try {
                        task.get();
                    } catch (InterruptedException e) {
                        // Restore the flag for the caller and stop training; the finally
                        // block cancels whatever is still running.
                        Thread.currentThread().interrupt();
                        return;
                    } catch (ExecutionException e) {
                        // A failed split would otherwise leave a silently incomplete tree.
                        throw new IllegalStateException("Failed to split tree node", e);
                    }
                }
            }
        } finally {
            // On the normal path every submitted task has already completed, so this only
            // releases the idle workers; on the error paths it also cancels pending work.
            threadPool.shutdownNow();
        }
    }

 http://xlvector.net/blog/?p=896

原文地址:https://www.cnblogs.com/549294286/p/3270000.html