forkjoin框架

forkjoin框架

一:简介

从JDK1.7开始,Java提供Fork/Join框架用于并行执行任务,它的思想就是讲一个大任务分割成若干小任务,最终汇总每个小任务的结果得到这个大任务的结果。

这种思想和MapReduce很像(input --> split --> map --> reduce --> output)

主要有两步:

  • 第一、任务切分;
  • 第二、结果合并

它的模型大致是这样的:线程池中的每个线程都有自己的工作队列(PS:这一点和ThreadPoolExecutor不同,ThreadPoolExecutor是所有线程公用一个工作队列,所有线程都从这个工作队列中取任务),当自己队列中的任务都完成以后,会从其它线程的工作队列中偷一个任务执行,这样可以充分利用资源。下面盗一张图来宏观展示一下:

img

二:forkjoin中定义的角色

ForkJoinPool:充当fork/join框架里面的管理者,最原始的任务都要交给它才能处理。它负责控制整个fork/join有多少个workerThread,workerThread的创建,激活都是由它来掌控。它还负责workQueue队列的创建和分配,每当创建一个workerThread,为它负责分配相应的workQueue。然后它把接到的活都交给workerThread去处理,它可以说是整个frok/join的容器。

ForkJoinWorkerThread:fork/join里面真正干活的"工人",本质是一个线程。里面有一个ForkJoinPool.WorkQueue的队列存放着它要干的活,接活之前它要向ForkJoinPool注册(registerWorker),拿到相应的workQueue。然后就从workQueue里面拿任务出来处理。它是依附于ForkJoinPool而存活,如果ForkJoinPool的销毁了,它也会跟着结束。

我们所说的forkjoin的工作窃取,那么究竟是怎么窃取的呢?我么你分析一下任务是由workThread来窃取的,workThread是一个线程,线程的执行逻辑都是在run里面,所以任务的窃取逻辑一定在run()中可以找的到。

public void run() { //线程run方法
       if (workQueue.array == null) { // only run once
           Throwable exception = null;
           try {
               onStart();
               pool.runWorker(workQueue);  //在这里处理任务队列!
           } catch (Throwable ex) {
               exception = ex;
           } finally {
               try {
                   onTermination(exception);
               } catch (Throwable ex) {
                   if (exception == null)
                       exception = ex;
               } finally {
                   pool.deregisterWorker(this, exception);
               }
           }
       }
   }
   
   
  final void runWorker(WorkQueue w) {
       w.growArray();                   // allocate queue  进行队列的初始化
       int seed = w.hint;               // initially holds randomization hint
       int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
       for (ForkJoinTask<?> t;;) { //又是无限循环处理任务!
           if ((t = scan(w, r)) != null) //在这里获取任务!
               w.runTask(t);
           else if (!awaitWork(w, r))
               break;
           r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
       }
   }

窃取逻辑的主要代码:(scan)任务的窃取从workerThread运行的那一刻就开始了,先随机选中一条队列看看能不能窃取到下一条队列,如果都窃取不到就返回null。

/**
     * Scans for and tries to steal a top-level task. Scans start at a
     * random location, randomly moving on apparent contention,
     * otherwise continuing linearly until reaching two consecutive
     * empty passes over all queues with the same checksum (summing
     * each base index of each queue, that moves on each steal), at
     * which point the worker tries to inactivate and then re-scans,
     * attempting to re-activate (itself or some other worker) if
     * finding a task; otherwise returning null to await work.  Scans
     * otherwise touch as little memory as possible, to reduce
     * disruption on other scanning threads.
     *
     * @param w the worker (via its WorkQueue)
     * @param r a random seed
     * @return a task, or null if none found
     */
    private ForkJoinTask<?> scan(WorkQueue w, int r) {
        WorkQueue[] ws; int m;
        if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
            int ss = w.scanState;                     // initially non-negative
            for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
                WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
                int b, n; long c;
                if ((q = ws[k]) != null) {   //随机选中了非空队列 q
                    if ((n = (b = q.base) - q.top) < 0 &&
                        (a = q.array) != null) {      // non-empty
                        long i = (((a.length - 1) & b) << ASHIFT) + ABASE;  //从尾部出队,b是尾部下标
                        if ((t = ((ForkJoinTask<?>)
                                  U.getObjectVolatile(a, i))) != null &&
                            q.base == b) {
                            if (ss >= 0) {
                                if (U.compareAndSwapObject(a, i, t, null)) { //利用cas出队
                                    q.base = b + 1;
                                    if (n < -1)       // signal others
                                        signalWork(ws, q);
                                    return t;  //出队成功,成功窃取一个任务!
                                }
                            }
                            else if (oldSum == 0 &&   // try to activate 队列没有激活,尝试激活
                                     w.scanState < 0)
                                tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
                        }
                        if (ss < 0)                   // refresh
                            ss = w.scanState;
                        r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
                        origin = k = r & m;           // move and rescan
                        oldSum = checkSum = 0;
                        continue;
                    }
                    checkSum += b;
                }<br>                //k = k + 1表示取下一个队列 如果(k + 1) & m == origin表示 已经遍历完所有队列了
                if ((k = (k + 1) & m) == origin) {    // continue until stable 
                    if ((ss >= 0 || (ss == (ss = w.scanState))) &&
                        oldSum == (oldSum = checkSum)) {
                        if (ss < 0 || w.qlock < 0)    // already inactive
                            break;
                        int ns = ss | INACTIVE;       // try to inactivate
                        long nc = ((SP_MASK & ns) |
                                   (UC_MASK & ((c = ctl) - AC_UNIT)));
                        w.stackPred = (int)c;         // hold prev stack top
                        U.putInt(w, QSCANSTATE, ns);
                        if (U.compareAndSwapLong(this, CTL, c, nc))
                            ss = ns;
                        else
                            w.scanState = ss;         // back out
                    }
                    checkSum = 0;
                }
            }
        }
        return null;
    }

ForkJoinPool.WorkQueue: 双端队列就是它,它负责存储接收的任务。

ForkJoinTask:代表fork/join里面任务类型,我们一般用它的两个子类RecursiveTask、RecursiveAction。这两个区别在于RecursiveTask任务是有返回值,RecursiveAction没有返回值。任务的处理逻辑包括任务的切分都集中在compute()方法里面。

此外还有fork()方法:在当前线程运行的线程池中安排一个异步执行。简单的理解就是在创建一个子任务。

join()方法:当任务完成的时候返回计算结果。

invoke()方法:开始执行任务,如果必要,等待计算完成。

RecursiveAction ()方法:一个递归无结果的ForkJoinTask(没有返回值)

RecursiveTask 一个递归有结果的ForkJoinTask(有返回值)

三:工作窃取

工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。工作窃取的运行流程图如下:

img

那么为什么需要使用工作窃取算法呢?

假如我们需要做一个比较大的任务,我们可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,于是把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应,比如A线程负责处理A队列里的任务。但是有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。

工作窃取算法的优点是充分利用线程进行并行计算,并减少了线程间的竞争,其缺点是在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。

使用示例:

package com.duoxiancheng.juc;

import java.util.concurrent.RecursiveTask;

public class ForkJoinDemo extends RecursiveTask<Long> {
    private Long start;
    private Long end;
    /** 临界值 */
    private static final Long temp = 10000L;
    public ForkJoinDemo(Long start,Long end) {
        this.start = start;
        this.end = end;
    }
    
    @Override
    protected Long compute() {
        /** 超过中间值,就分配任务 */
        if(end - start < temp) {
            Long sum = 0L;
            for(Long i = start;i <= end;i++) {
                sum += i;
            }
            return sum;
        } else {
            /** 获取中间值 */
            Long middle = (end + start) / 2;
            ForkJoinDemo right = new ForkJoinDemo(start,middle);
            /** 开启分支计算任务 */
            right.fork();
            ForkJoinDemo left = new ForkJoinDemo(middle+1,end);
            /** 开启分支计算任务 */  
            left.fork();
            /** 合并结果 */
            return right.join() + left.join();
        }
       
       
    }
}

package com.cjs.boot.demo;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ForkJoinTaskDemo {

    private class SumTask extends RecursiveTask<Integer> {

        private static final int THRESHOLD = 20;

        private int arr[];
        private int start;
        private int end;

        public SumTask(int[] arr, int start, int end) {
            this.arr = arr;
            this.start = start;
            this.end = end;
        }

        /**
         * 小计
         */
        private Integer subtotal() {
            Integer sum = 0;
            for (int i = start; i < end; i++) {
                sum += arr[i];
            }
            System.out.println(Thread.currentThread().getName() + ": ∑(" + start + "~" + end + ")=" + sum);
            return sum;
        }

        @Override
        protected Integer compute() {

            if ((end - start) <= THRESHOLD) {
                return subtotal();
            }else {
                int middle = (start + end) / 2;
                SumTask left = new SumTask(arr, start, middle);
                SumTask right = new SumTask(arr, middle, end);
                left.fork();
                right.fork();

                return left.join() + right.join();
            }
        }
    }

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        int[] arr = new int[100];
        for (int i = 0; i < 100; i++) {
            arr[i] = i + 1;
        }
		// 创建一个ForkJoinPool线程池,用来存放任务
        ForkJoinPool pool = new ForkJoinPool();
        // 返回有结果的任务,RecursiveTask
        ForkJoinTask<Integer> result = pool.submit(new ForkJoinTaskDemo().new SumTask(arr, 0, arr.length));
        System.out.println("最终计算结果: " + result.invoke());
        pool.shutdown();
    }

}
参考链接:
https://www.cnblogs.com/cjsblog/p/9078341.html
https://www.cnblogs.com/linlinismine/p/9295701.html
原文地址:https://www.cnblogs.com/clover-forever/p/13526472.html