Java并发编程之6——AQS如何实现CountDownLatch

前面写过一篇文章AQS源码分析的时候介绍过AQS是java并发编程的基础,Java并发包下面提供的同步工具类基本上都是以AQS构建的,一个同步类的实现主要分为三步…

CountDownLatch实现分析

定义同步工具类的步骤

前面写过一篇文章AQS源码分析的时候介绍过AQS是java并发编程的基础,Java并发包下面提供的同步工具类基本上都是以AQS构建的,一个同步类的实现主要分为三步:

  • 该同步工具类 ,定义内部类实现AQS
  • 定义该内部类的构造方法设置同步状态
  • 实现tryAcquire/tryRelease方法

CountDownLatch的实现分析

在CountDownLatch的实现分析中主要分析三个方法:

public CountDownLatch(int count)

这个方法是CountDownLatch的构造方法,它会指定同步器的状态

1
2
3
4
5
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
CountDownLatch的构造方法,在内部实例化了一个继承自AQS的子类。

Sync类的构造函数如下:

1
2
3
4
Sync(int count) {
setState(count);
}
Sync类的构造函数调用了AQS的setState设置同步器的状态

public void countDown()

修改同步状态,使同步器的状态减一。

1
2
3
public void countDown() {
sync.releaseShared(1);
}

countDown方法调用了内部Sync类的releaseShared方法释放共享锁,releaseShare方法作用是,让当前线程释放他持有的共享锁。实现如下:

1
2
3
4
5
6
7
8
9
10
public final boolean releaseShared(int arg) {
tryReleaseShared方法在AQS中强制子类去实现,这里的子类在一个循环中进行CAS操作更新同步状态(对同步状态减一),返回true更改同步状态之后,同步状态为0,等待共享锁的线程可以释放
false表示更改同步状态之后,同步状态大于0,这时候获取共享锁的线程必须等待
if (tryReleaseShared(arg)) {
当同步状态由CAS操作递减到0的时候,就释放头节点的后继节点对应的线程
doReleaseShared();
return true;
}
return false;
}

CountDownLatch中内部类Sync实现的tryReleaseShare方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
if (c == 0)
同步状态已经为0,直接返回false不用更新,这种状态一般是不合法的
return false;
int nextc = c-1;
CAS将同步状态减一,如果减之后同步状态为0返回true否则返回false。其中true表示更改同步状态后的同步状态为0,这是可以释放等待共享锁的线程。
false表示更改同步状态之后的同步状态大于0,这是获取共享锁的线程必须等待
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}

doReleaseShare是由AQS实现的,详情请参见AQS源码分析这里只列出主要逻辑。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
释放等待共享锁的线程
private void doReleaseShared() {
for (;;) {
获取头结点
Node h = head;
if (h != null && h != tail) {
获取头结点的状态
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
如果CAS更新头结点的状态成功就unparkSuccessor,如果失败就继续循环
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)){
continue;
}
激活头结点的后继节点
unparkSuccessor(h);
}
else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)){ 如果头结点的状态为0并且CAS将状态更改为-3成功就退出循环,否则继续循环
continue;
}
}
if (h == head)
break;
}
}

unparkSuccessor的实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
激活Node节点的后继节点对应的线程
private void unparkSuccessor(Node node) {
int ws = node.waitStatus;
if (ws < 0){
compareAndSetWaitStatus(node, ws, 0);
}
找到Node节点的一个有效的后继节点
Node s = node.next;
if (s == null || s.waitStatus > 0) {
s = null;
for (Node t = tail; t != null && t != node; t = t.prev)
if (t.waitStatus <= 0)
s = t;
}
if (s != null)
LockSupport.unpark(s.thread);
}

public void await()

作用是让当前线程等待。当同步计数器减为0的时候,当前线程被唤醒,继续执行。

1
2
3
4
public void await() throws InterruptedException {
以可响应中断的方式获取共享锁
sync.acquireSharedInterruptibly(1);
}

CountDownLatch的await方法调用的是Sync内部类的acquireShareInterruptibly方法,这里面是直接调用的父类AQS的方法。

1
2
3
4
5
6
7
8
9
10
可响应中断地获取同步状态
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
尝试获取共享锁如果锁计数器为0表示锁是可获取的,那么直接返回,如果锁计数器不为0表示锁是不可获取的,就以当前线程为节点构造等待队列
if (tryAcquireShared(arg) < 0){
doAcquireSharedInterruptibly(arg);
}
}

Sync类实现的tryAcquireShared方法。

1
2
3
4
尝试获取共享锁如果同步状态(锁计数器为0)表示锁是可以获取的返回1,否则返回-1
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

AQS的doAcquireShareInterruptibly。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
以当前线程构造一个等待的Node队列,在Node节点构造完成之后会再次尝试获取一次共享锁 在构造完成之后会再次尝试获取一次锁,
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
以当前线程构造等待的Node队列
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
注意 这里是在一个循环中执行的,第一次执行shouldParkAfterFailedAcquire会失败,因为此时Node的状态是0,但是会进行CAS操作将状态更新为-1,所以第二次循环判断的时候就可以
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg); 当同步状态为0的时候,表示锁是可以获取的此时返回1
if (r >= 0) {//如果锁是可以获取的,就将Node节点设置为头结点
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
判断当前线程是否应该挂起
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}

将当前线程添加到等待队列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
private Node addWaiter(Node mode) {
Node node = new Node(Thread.currentThread(), mode);
Node pred = tail;
if (pred != null) {
node.prev = pred;
if (compareAndSetTail(pred, node)) {
pred.next = node;
return node;
}
}
当队列为空的时候,
enq(node);
return node;
}

判断当前线程是否应该挂起

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
int ws = pred.waitStatus;
if (ws == Node.SIGNAL)
return true;
if (ws > 0) {
前驱节点对应的线程对应的线程被取消,继续往前找到
do {
node.prev = pred = pred.prev;
} while (pred.waitStatus > 0);
pred.next = node;
} else {
CAS更新状态
compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
}
return false;
}