From 1ae9c12a973772146f2255ecb922415b3cd0ff5e Mon Sep 17 00:00:00 2001 From: Taylor Hakes Date: Mon, 17 Sep 2018 16:38:19 -0400 Subject: [PATCH 1/4] Update behavior of lock to behave closer to redis lock --- fakeredis.py | 26 ++++++++++++++++++++------ test_fakeredis.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/fakeredis.py b/fakeredis.py index 921e3ff..c2e1a21 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -8,15 +8,15 @@ from datetime import datetime, timedelta import operator import sys -import threading import time import types import re import functools from itertools import count, islice +from uuid import uuid4 import redis -from redis.exceptions import ResponseError +from redis.exceptions import ResponseError, LockError import redis.client try: @@ -315,8 +315,8 @@ class _Lock(object): def __init__(self, redis, name, timeout): self.redis = redis self.name = name - self.lock = threading.Lock() - redis.set(name, self, ex=timeout) + self.timeout = timeout + self.id = None def __enter__(self): self.acquire() @@ -326,11 +326,25 @@ def __exit__(self, exc_type, exc_value, traceback): self.release() def acquire(self, blocking=True, blocking_timeout=None): - return self.lock.acquire(blocking) + token = str(uuid4()) + acquired = bool(self.redis.set(self.name, token, nx=True, ex=self.timeout)) + if not acquired and blocking: + raise ValueError('fakeredis can\'t do blocking locks') + + if acquired: + self.id = token + + return acquired def release(self): - self.lock.release() + if self.id is None: + raise LockError("Cannot release an unlocked lock") + + if _decode(self.redis.get(self.name)) != self.id: + raise LockError("Cannot extend a lock that's no longer owned") + self.redis.delete(self.name) + self.id = None def _check_conn(func): diff --git a/test_fakeredis.py b/test_fakeredis.py index 715559c..2ccdf13 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -3916,6 +3916,38 @@ def test_lock(self): self.assertTrue(self.redis.exists('bar')) self.assertFalse(self.redis.exists('bar')) + def test_acquiring_lock_twice(self): + lock = self.redis.lock('foo') + self.assertTrue(lock.acquire(blocking=False)) + self.assertFalse(lock.acquire(blocking=False)) + + def test_acquiring_lock_different_lock(self): + lock1 = self.redis.lock('foo') + lock2 = self.redis.lock('foo') + self.assertTrue(lock1.acquire(blocking=False)) + self.assertFalse(lock2.acquire(blocking=False)) + + def test_acquiring_lock_different_lock_release(self): + lock1 = self.redis.lock('foo') + lock2 = self.redis.lock('foo') + self.assertTrue(lock1.acquire(blocking=False)) + self.assertFalse(lock2.acquire(blocking=False)) + + # Test only releasing lock1 actually releases the lock + with self.assertRaises(redis.exceptions.LockError): + lock2.release() + self.assertFalse(lock2.acquire(blocking=False)) + lock1.release() + + # Locking with lock2 now has the lock + self.assertTrue(lock2.acquire(blocking=False)) + self.assertFalse(lock1.acquire(blocking=False)) + + def test_nested_lock(self): + with self.redis.lock('bar'): + acquired = self.redis.lock('bar').acquire(blocking=False) + self.assertFalse(acquired) + class DecodeMixin(object): decode_responses = True From 7e82bce46a878240106c27aeccab0a03db956d70 Mon Sep 17 00:00:00 2001 From: Taylor Hakes Date: Mon, 17 Sep 2018 17:21:01 -0400 Subject: [PATCH 2/4] Added test for no longer owned lock --- test_fakeredis.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test_fakeredis.py b/test_fakeredis.py index 2ccdf13..0c03831 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -3948,6 +3948,13 @@ def test_nested_lock(self): acquired = self.redis.lock('bar').acquire(blocking=False) self.assertFalse(acquired) + def test_lock_no_longer_owned(self): + lock = self.redis.lock('bar') + lock.acquire() + self.redis.delete('bar') + with self.assertRaises(redis.exceptions.LockError): + lock.release() + class DecodeMixin(object): decode_responses = True From 60ddfe0eddeb3a9a8a79086085e81203166a28ca Mon Sep 17 00:00:00 2001 From: Taylor Hakes Date: Tue, 18 Sep 2018 09:04:57 -0400 Subject: [PATCH 3/4] Fixed error message on release --- fakeredis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fakeredis.py b/fakeredis.py index c2e1a21..8093437 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -341,7 +341,7 @@ def release(self): raise LockError("Cannot release an unlocked lock") if _decode(self.redis.get(self.name)) != self.id: - raise LockError("Cannot extend a lock that's no longer owned") + raise LockError("Cannot release a lock that's no longer owned") self.redis.delete(self.name) self.id = None From 4a24a1f00eceb00d43ef4484f3ddff198d6299f2 Mon Sep 17 00:00:00 2001 From: Taylor Hakes Date: Tue, 18 Sep 2018 12:27:40 -0400 Subject: [PATCH 4/4] Add thread safety to lock --- fakeredis.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/fakeredis.py b/fakeredis.py index 8093437..4b91314 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -12,6 +12,7 @@ import types import re import functools +import threading from itertools import count, islice from uuid import uuid4 @@ -311,6 +312,9 @@ def _compile_pattern(pattern): return re.compile(regex, re.S) +threading_lock = threading.Lock() + + class _Lock(object): def __init__(self, redis, name, timeout): self.redis = redis @@ -327,7 +331,8 @@ def __exit__(self, exc_type, exc_value, traceback): def acquire(self, blocking=True, blocking_timeout=None): token = str(uuid4()) - acquired = bool(self.redis.set(self.name, token, nx=True, ex=self.timeout)) + with threading_lock: + acquired = bool(self.redis.set(self.name, token, nx=True, ex=self.timeout)) if not acquired and blocking: raise ValueError('fakeredis can\'t do blocking locks') @@ -340,11 +345,12 @@ def release(self): if self.id is None: raise LockError("Cannot release an unlocked lock") - if _decode(self.redis.get(self.name)) != self.id: - raise LockError("Cannot release a lock that's no longer owned") + with threading_lock: + if _decode(self.redis.get(self.name)) != self.id: + raise LockError("Cannot release a lock that's no longer owned") - self.redis.delete(self.name) - self.id = None + self.redis.delete(self.name) + self.id = None def _check_conn(func):