什么是CountDownLatch?

CountDownLatch是Java的一个同步类,用于协作多线程,同时也是一个共享锁

CountDownLatch使用场景

等待所有线程完成任务

主线程等异步线程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
public static void main(String[] args) throws InterruptedException {
CountDownLatch countDownLatch = new CountDownLatch(3);
for (int i = 0; i < 3; i++) {
new MyThread(countDownLatch).start();
}
countDownLatch.await();
System.out.println("3个线程都完成任务了");
}

static class MyThread extends Thread{

CountDownLatch countDownLatch;

public MyThread(CountDownLatch countDownLatch) {
this.countDownLatch = countDownLatch;
}

@Override
public void run() {
try {
// 每个线程任务时间不同
Thread.sleep(1000 + new Random().nextInt(1000));
System.out.println(Thread.currentThread().getName());
countDownLatch.countDown();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}

多线程一起开始

异步线程等主线程的指令

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
public static void main(String[] args) throws InterruptedException {
CountDownLatch countDownLatch = new CountDownLatch(1);
for (int i = 0; i < 3; i++) {
new MyThread(countDownLatch).start();
}
Thread.sleep(1000);
countDownLatch.countDown();
System.out.println("3个线程一起开始");
}

static class MyThread extends Thread{

CountDownLatch countDownLatch;

public MyThread(CountDownLatch countDownLatch) {
this.countDownLatch = countDownLatch;
}

@Override
public void run() {
try {
// 每个线程等待
countDownLatch.await();
System.out.println(Thread.currentThread().getName());
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}

CountDownLatch源码解析

很多子方法在AQS内部有类似或者一样的实现,如果你了解过AQS的源码,那么CountDownLatch对于你而言应该问题不大,只需要理解好自定义同步器的tryAcquireShared()的实现就行,主要意思就是state != 0的时候线程都会被包装成节点加入到同步队列。

构造器

内部实现了自己的同步器,设置state的初值

1
2
3
4
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

await()

让线程等待

1
2
3
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

acquireSharedInterruptibly()

1
2
3
4
5
6
7
8
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// 尝试获取共享锁
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}

tryAcquireShared():尝试获取共享锁

1
2
3
4
protected int tryAcquireShared(int acquires) {
// 只要state不是0,都返回-1
return (getState() == 0) ? 1 : -1;
}

doAcquireSharedInterruptibly():让所有的线程都阻塞在同步队列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
// 把节点加入同步队列,设置waitstatus == SHARED
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
// 死循环是为了shouldParkAfterFailedAcquire()返回false,说明是头节点的后继节点真正入队前,还会再次调用tryAcquireShared()尝试获取锁
for (;;) {
// 返回节点的前驱节点,队列为空会抛异常
final Node p = node.predecessor();
// 前驱节点是头节点,尝试获取共享锁
if (p == head) {
int r = tryAcquireShared(arg);
// 返回的结果>=0,说明获取成功,实际上如果不调用countDown()返回的state都是>0的,导致r一直是-1,无法获取锁,全部节点内的线程都会 被阻塞
if (r >= 0) {
// 更改头节点为当前节点,并且释放后续的节点
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// 当前节点的线程是否应该暂停,只有前驱节点的waitstatus == SIGNAL才会返回true,否则会再次循环尝试获取共享锁
if (shouldParkAfterFailedAcquire(p, node) &&
// shouldParkAfterFailedAcquire返回true,调用LockSupport.park(this)阻塞本节点的线程
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
// for循环内的某一步骤发生了异常,会把当前节点移除出同步队列
if (failed)
cancelAcquire(node);
}
}

countDown()

减少state的值

1
2
3
public void countDown() {
sync.releaseShared(1);
}

releaseShared()

1
2
3
4
5
6
7
8
public final boolean releaseShared(int arg) {
// 尝试释放共享锁
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}

tryReleaseShared():尝试释放共享锁

1
2
3
4
5
6
7
8
9
10
11
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}

doReleaseShared():唤醒后继节点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
private void doReleaseShared() {
// 死循环,直到调用unparkSuccessor(h)唤醒后继节点、或者设置头节点waitstatus == PROPAGATE
for (;;) {
Node h = head;
// 队列不空
if (h != null && h != tail) {
int ws = h.waitStatus;
// 头节点的waitstatus == SIGNAL
if (ws == Node.SIGNAL) {
// 设置头节点的waitstatus == 0
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
// 设置成功就可以唤醒后继节点
unparkSuccessor(h);
}
// 当头节点waitstatus == 0,并且设置头节点waitstatus == PROPAGATE也可以返回
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}

总结

countDownLatch有一个初始的state,只有当state == 0,才会让阻塞在同步队列的节点被逐一释放,对于所有的节点,如果state != 0都无法获取锁,也就是说所有的节点从一开始都会被阻塞。