diff --git a/CHANGES b/CHANGES index 1bc7cef600..e40bd58ae8 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Add an optional lock_name attribute to LockError. * Fix return types for `get`, `set_path` and `strappend` in JSONCommands * Connection.register_connect_callback() is made public. * Fix async `read_response` to use `disable_decoding`. diff --git a/redis/exceptions.py b/redis/exceptions.py index 7cf15a7d07..ddb4041da3 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -81,7 +81,10 @@ class LockError(RedisError, ValueError): "Errors acquiring or releasing a lock" # NOTE: For backwards compatibility, this class derives from ValueError. # This was originally chosen to behave like threading.Lock. - pass + + def __init__(self, message, lock_name=None): + self.message = message + self.lock_name = lock_name class LockNotOwnedError(LockError): diff --git a/redis/lock.py b/redis/lock.py index 4cca102d10..cae7f27ea1 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -157,7 +157,10 @@ def register_scripts(self) -> None: def __enter__(self) -> "Lock": if self.acquire(): return self - raise LockError("Unable to acquire lock within the time specified") + raise LockError( + "Unable to acquire lock within the time specified", + lock_name=self.name, + ) def __exit__( self, @@ -248,7 +251,7 @@ def release(self) -> None: """ expected_token = self.local.token if expected_token is None: - raise LockError("Cannot release an unlocked lock") + raise LockError("Cannot release an unlocked lock", lock_name=self.name) self.local.token = None self.do_release(expected_token) @@ -256,7 +259,10 @@ def do_release(self, expected_token: str) -> None: if not bool( self.lua_release(keys=[self.name], args=[expected_token], client=self.redis) ): - raise LockNotOwnedError("Cannot release a lock that's no longer owned") + raise LockNotOwnedError( + "Cannot release a lock that's no longer owned", + lock_name=self.name, + ) def extend(self, additional_time: int, replace_ttl: bool = False) -> bool: """ @@ -270,9 +276,9 @@ def extend(self, additional_time: int, replace_ttl: bool = False) -> bool: `additional_time`. """ if self.local.token is None: - raise LockError("Cannot extend an unlocked lock") + raise LockError("Cannot extend an unlocked lock", lock_name=self.name) if self.timeout is None: - raise LockError("Cannot extend a lock with no timeout") + raise LockError("Cannot extend a lock with no timeout", lock_name=self.name) return self.do_extend(additional_time, replace_ttl) def do_extend(self, additional_time: int, replace_ttl: bool) -> bool: @@ -284,7 +290,10 @@ def do_extend(self, additional_time: int, replace_ttl: bool) -> bool: client=self.redis, ) ): - raise LockNotOwnedError("Cannot extend a lock that's no longer owned") + raise LockNotOwnedError( + "Cannot extend a lock that's no longer owned", + lock_name=self.name, + ) return True def reacquire(self) -> bool: @@ -292,9 +301,12 @@ def reacquire(self) -> bool: Resets a TTL of an already acquired lock back to a timeout value. """ if self.local.token is None: - raise LockError("Cannot reacquire an unlocked lock") + raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name) if self.timeout is None: - raise LockError("Cannot reacquire a lock with no timeout") + raise LockError( + "Cannot reacquire a lock with no timeout", + lock_name=self.name, + ) return self.do_reacquire() def do_reacquire(self) -> bool: @@ -304,5 +316,8 @@ def do_reacquire(self) -> bool: keys=[self.name], args=[self.local.token, timeout], client=self.redis ) ): - raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned") + raise LockNotOwnedError( + "Cannot reacquire a lock that's no longer owned", + lock_name=self.name, + ) return True diff --git a/tests/test_lock.py b/tests/test_lock.py index 72af87fa81..5c804b426e 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -242,6 +242,13 @@ def test_context_manager_reacquiring_lock_no_longer_owned_raises_error(self, r): with self.get_lock(r, "foo", timeout=10, blocking=False): r.set("foo", "a") + def test_lock_error_gives_correct_lock_name(self, r): + r.set("foo", "bar") + with pytest.raises(LockError) as excinfo: + with self.get_lock(r, "foo", blocking_timeout=0.1): + pass + assert excinfo.value.lock_name == "foo" + class TestLockClassSelection: def test_lock_class_argument(self, r):