并发工具CountDownLatch源码分析

CountDownLatch的作用类似于Thread.join()方法,但比join()更加灵活。

它可以等待多个线程(取决于实例化时声明的数量)都达到预期状态或者完成工作以后,通知其他正在等待的线程继续执行。

简单的说,Thread.join()是等待具体的一个线程执行完毕,CountDownLatch等待多个线程。

比如:如果需要统计4个文件中的内容行数,可以用4个线程分别执行,然后用一个线程等待统计结果,最后执行数据汇总。这样场景就适合使用CountDownLatch。

1、CountDownLatch中的内部类

private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);   // 更新AQS中的state
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

其实CountDownLatch的机制和ReentrantLock有点像,都是利用AQS(AbstractQueuedSynchronizer)来实现的。

CountDownLatch的内部类Sync继承AQS,重写了tryAcquireShared()方法和tryReleaseShared()方法。这里的重点是CountDownLatch的构造函数需要传入一个int值count,就是等待的线程数。这个count被Sync用来直接更新为AQS中的state

2、await()等待方法

//CountDownLatch
public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);  //强制获取共享变量
    }
//AQS
public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)            // 1:调用AQS中的tryAcquireShare()方法时,Sync重写了tryAcquireShared()方法,获取state,判断state是否为0;
            doAcquireSharedInterruptibly(arg);    // 2:如果state不为0,则返回-1,调用该方法,将线程加入队列,挂起线程。  
    }
//Sync
@Override
protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
//AQS
private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

3、countDown()等待方法

public void countDown() {
        sync.releaseShared(1);   //释放共享变量
    }
//AQS
public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }
//Sync
@Override
protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)   //线程数量为0,则表示无法再减,表示所有线程都执行完毕,就唤醒等待队列中的线程;
                    return false;
                int nextc = c-1;    //利用CAS算法将state减1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }

//插入countDown的原理图

举个例子吧:

public class CountDownLatchTest {
    private static CountDownLatch countDownLatch = new CountDownLatch(3);
    private static ThreadPoolExecutor threadPool = new ThreadPoolExecutor(5, 5,
            0L, TimeUnit.MILLISECONDS,
            new LinkedBlockingQueue<Runnable>(10));
    
    public static void main(String[] args) {
        //等待线程
        for (int i = 0; i < 2; i++) {
            String threadName = "等待线程 " + i;
            threadPool.execute(new Runnable() {
                
                @Override
                public void run() {
                    try {
                        System.out.println(threadName + " 正在等待...");
                        //等待
                        countDownLatch.await();
                        System.out.println(threadName + " 结束等待...");
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            });
        }
        //工作线程
        for (int i = 2; i < 5; i++) {
            String threadName = "工作线程 " + i;
            threadPool.execute(new Runnable() {
                
                @Override
                public void run() {
                    try {
                        System.out.println(threadName + " 进入...");
                        //沉睡1秒
                        TimeUnit.MILLISECONDS.sleep(1000);
                        System.out.println(threadName + " 完成...");
                        //通知
                        countDownLatch.countDown();
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            });
        }
        
        threadPool.shutdown();
    }
}

执行结果:

等待线程 1 正在等待...
等待线程 0 正在等待...
工作线程 2 进入...
工作线程 3 进入...
工作线程 4 进入...
工作线程 3 完成...
工作线程 2 完成...
工作线程 4 完成...
等待线程 0 结束等待...
等待线程 1 结束等待...

从结果也能看到,等待线程先执行,调用countDownLatch.await()方法开始等待。

每个工作线程工作完成以后,都调用countDownLatch.countDown()方法,告知自己的任务完成。countDownLatch初始参数为3,所以3个工作线程都告知自己结束以后,等待线程才开始工作。

 参考:

https://www.cnblogs.com/sunshine-ground-poems/p/10384453.html

Over......

原文地址:https://www.cnblogs.com/gjmhome/p/14396139.html