diff --git a/src/main/java/io/lettuce/core/protocol/SharedLock.java b/src/main/java/io/lettuce/core/protocol/SharedLock.java index eeb21e28a0..a99f153d0e 100644 --- a/src/main/java/io/lettuce/core/protocol/SharedLock.java +++ b/src/main/java/io/lettuce/core/protocol/SharedLock.java @@ -26,6 +26,8 @@ class SharedLock { private final Lock lock = new ReentrantLock(); + private final ThreadLocal sharedCnt = ThreadLocal.withInitial(() -> 0); + private volatile long writers = 0; private volatile Thread exclusiveLockOwner; @@ -45,6 +47,7 @@ void incrementWriters() { if (WRITERS.get(this) >= 0) { WRITERS.incrementAndGet(this); + sharedCnt.set(sharedCnt.get() + 1); return; } } @@ -63,6 +66,7 @@ void decrementWriters() { } WRITERS.decrementAndGet(this); + sharedCnt.set(sharedCnt.get() - 1); } /** @@ -113,6 +117,7 @@ T doExclusive(Supplier supplier) { private void lockWritersExclusive() { if (exclusiveLockOwner == Thread.currentThread()) { + WRITERS.decrementAndGet(this); return; } @@ -124,6 +129,11 @@ private void lockWritersExclusive() { exclusiveLockOwner = Thread.currentThread(); return; } + // reentrant exclusive lock + if (WRITERS.compareAndSet(this, sharedCnt.get(), -1)) { + exclusiveLockOwner = Thread.currentThread(); + return; + } } } finally { lock.unlock(); @@ -136,7 +146,7 @@ private void lockWritersExclusive() { private void unlockWritersExclusive() { if (exclusiveLockOwner == Thread.currentThread()) { - if (WRITERS.compareAndSet(this, -1, 0)) { + if (WRITERS.compareAndSet(this, -1, sharedCnt.get())) { exclusiveLockOwner = null; } } diff --git a/src/test/java/io/lettuce/core/protocol/SharedLockTest.java b/src/test/java/io/lettuce/core/protocol/SharedLockTest.java new file mode 100644 index 0000000000..01b16c3f7c --- /dev/null +++ b/src/test/java/io/lettuce/core/protocol/SharedLockTest.java @@ -0,0 +1,43 @@ +package io.lettuce.core.protocol; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class SharedLockTest { + + @Test + public void safety_on_reentrant_lock_exclusive_on_writers() throws InterruptedException { + final SharedLock sharedLock = new SharedLock(); + CountDownLatch cnt = new CountDownLatch(1); + try { + // 共享锁 + sharedLock.incrementWriters(); + + String result = sharedLock.doExclusive(() -> { + return sharedLock.doExclusive(() -> { + return "ok"; + }); + }); + if ("ok".equals(result)) { + cnt.countDown(); + } + } finally { + sharedLock.decrementWriters(); + } + + cnt.await(1, TimeUnit.SECONDS); + + // verify writers won't be negative after finally decrementWriters + String result = sharedLock.doExclusive(() -> { + return sharedLock.doExclusive(() -> { + return "ok"; + }); + }); + + Assertions.assertEquals("ok", result); + } + +}