diff --git a/docs/conf.py b/docs/conf.py index a178ad7e..fefe4d50 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -55,6 +55,7 @@ ("py:class", "Optional"), ("py:class", "Tuple"), ("py:class", "Union"), + ("py:class", "redis.Redis"), ] # Add any Sphinx extension module names here, as strings. They can be diff --git a/setup.py b/setup.py index 239c8b7d..a4b6698c 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,10 @@ def main(): readme_filename = os.path.join(package_root, "README.md") with io.open(readme_filename, encoding="utf-8") as readme_file: readme = readme_file.read() - dependencies = ["google-cloud-datastore >= 1.7.0"] + dependencies = [ + "google-cloud-datastore >= 1.7.0", + "redis", + ] setuptools.setup( name="google-cloud-ndb", diff --git a/src/google/cloud/ndb/global_cache.py b/src/google/cloud/ndb/global_cache.py index 987b35b8..7cb698cc 100644 --- a/src/google/cloud/ndb/global_cache.py +++ b/src/google/cloud/ndb/global_cache.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""GlobalCache interface and its implementations.""" + import abc +import collections +import os import time +import uuid -"""GlobalCache interface and its implementations.""" +import redis as redis_module class GlobalCache(abc.ABC): @@ -160,3 +165,110 @@ def compare_and_swap(self, items, expires=None): current_value = self.cache.get(key) if watch_value == current_value: self.cache[key] = (new_value, expires) + + +_Pipeline = collections.namedtuple("_Pipeline", ("pipe", "id")) + + +class RedisCache(GlobalCache): + """Redis implementation of the :class:`GlobalCache`. + + This is a synchronous implementation. The idea is that calls to Redis + should be fast enough not to warrant the added complexity of an + asynchronous implementation. + + Args: + redis (redis.Redis): Instance of Redis client to use. 
+ """ + + @classmethod + def from_environment(cls): + """Generate a :class:`RedisCache` from an environment variable. + + This class method looks for the ``REDIS_CACHE_URL`` environment + variable and, if it is set, passes its value to ``Redis.from_url`` to + construct a ``Redis`` instance which is then used to instantiate a + ``RedisCache`` instance. + + Returns: + Optional[RedisCache]: A :class:`RedisCache` instance or + :data:`None`, if ``REDIS_CACHE_URL`` is not set in the + environment. + """ + url = os.environ.get("REDIS_CACHE_URL") + if url: + return cls(redis_module.Redis.from_url(url)) + + def __init__(self, redis): + self.redis = redis + self.pipes = {} + + def get(self, keys): + """Implements :meth:`GlobalCache.get`.""" + res = self.redis.mget(keys) + return res + + def set(self, items, expires=None): + """Implements :meth:`GlobalCache.set`.""" + self.redis.mset(items) + if expires: + for key in items.keys(): + self.redis.expire(key, expires) + + def delete(self, keys): + """Implements :meth:`GlobalCache.delete`.""" + self.redis.delete(*keys) + + def watch(self, keys): + """Implements :meth:`GlobalCache.watch`.""" + pipe = self.redis.pipeline() + pipe.watch(*keys) + holder = _Pipeline(pipe, str(uuid.uuid4())) + for key in keys: + self.pipes[key] = holder + + def compare_and_swap(self, items, expires=None): + """Implements :meth:`GlobalCache.compare_and_swap`.""" + pipes = {} + mappings = {} + results = {} + remove_keys = [] + + # get associated pipes + for key, value in items.items(): + remove_keys.append(key) + if key not in self.pipes: + continue + + pipe = self.pipes[key] + pipes[pipe.id] = pipe + mapping = mappings.setdefault(pipe.id, {}) + mapping[key] = value + + # execute transaction for each pipe + for pipe_id, mapping in mappings.items(): + pipe = pipes[pipe_id].pipe + try: + pipe.multi() + pipe.mset(mapping) + if expires: + for key in mapping.keys(): + pipe.expire(key, expires) + pipe.execute() + + except redis_module.exceptions.WatchError: + 
pass + + finally: + pipe.reset() + + # get keys associated to pipes but not updated + for key, pipe in self.pipes.items(): + if pipe.id in pipes: + remove_keys.append(key) + + # remove keys + for key in remove_keys: + self.pipes.pop(key, None) + + return results diff --git a/tests/system/test_crud.py b/tests/system/test_crud.py index 0816b62e..cf1fe766 100644 --- a/tests/system/test_crud.py +++ b/tests/system/test_crud.py @@ -18,6 +18,7 @@ import datetime import functools import operator +import os import threading from unittest import mock @@ -32,6 +33,8 @@ from tests.system import KIND, eventually +USE_REDIS_CACHE = bool(os.environ.get("REDIS_CACHE_URL")) + def _equals(n): return functools.partial(operator.eq, n) @@ -110,6 +113,40 @@ class SomeKind(ndb.Model): assert entity.baz == "night" +@pytest.mark.skipif(not USE_REDIS_CACHE, reason="Redis is not configured") +def test_retrieve_entity_with_redis_cache(ds_entity, client_context): + entity_id = test_utils.system.unique_resource_id() + ds_entity(KIND, entity_id, foo=42, bar="none", baz=b"night") + + class SomeKind(ndb.Model): + foo = ndb.IntegerProperty() + bar = ndb.StringProperty() + baz = ndb.StringProperty() + + global_cache = global_cache_module.RedisCache.from_environment() + with client_context.new(global_cache=global_cache).use() as context: + context.set_global_cache_policy(None) # Use default + + key = ndb.Key(KIND, entity_id) + entity = key.get() + assert isinstance(entity, SomeKind) + assert entity.foo == 42 + assert entity.bar == "none" + assert entity.baz == "night" + + cache_key = _cache.global_cache_key(key._key) + assert global_cache.redis.get(cache_key) is not None + + patch = mock.patch("google.cloud.ndb._datastore_api._LookupBatch.add") + patch.side_effect = Exception("Shouldn't call this") + with patch: + entity = key.get() + assert isinstance(entity, SomeKind) + assert entity.foo == 42 + assert entity.bar == "none" + assert entity.baz == "night" + + 
@pytest.mark.usefixtures("client_context") def test_retrieve_entity_not_found(ds_entity): entity_id = test_utils.system.unique_resource_id() @@ -316,6 +353,37 @@ class SomeKind(ndb.Model): dispose_of(key._key) +@pytest.mark.skipif(not USE_REDIS_CACHE, reason="Redis is not configured") +def test_insert_entity_with_redis_cache(dispose_of, client_context): + class SomeKind(ndb.Model): + foo = ndb.IntegerProperty() + bar = ndb.StringProperty() + + global_cache = global_cache_module.RedisCache.from_environment() + with client_context.new(global_cache=global_cache).use() as context: + context.set_global_cache_policy(None) # Use default + + entity = SomeKind(foo=42, bar="none") + key = entity.put() + cache_key = _cache.global_cache_key(key._key) + assert global_cache.redis.get(cache_key) is None + + retrieved = key.get() + assert retrieved.foo == 42 + assert retrieved.bar == "none" + + assert global_cache.redis.get(cache_key) is not None + + entity.foo = 43 + entity.put() + + # This is py27 behavior. I can see a case being made for caching the + # entity on write rather than waiting for a subsequent lookup. 
+ assert global_cache.redis.get(cache_key) is None + + dispose_of(key._key) + + @pytest.mark.usefixtures("client_context") def test_update_entity(ds_entity): entity_id = test_utils.system.unique_resource_id() @@ -453,6 +521,31 @@ class SomeKind(ndb.Model): assert cache_dict[cache_key][0] == b"0" +@pytest.mark.skipif(not USE_REDIS_CACHE, reason="Redis is not configured") +def test_delete_entity_with_redis_cache(ds_entity, client_context): + entity_id = test_utils.system.unique_resource_id() + ds_entity(KIND, entity_id, foo=42) + + class SomeKind(ndb.Model): + foo = ndb.IntegerProperty() + + key = ndb.Key(KIND, entity_id) + cache_key = _cache.global_cache_key(key._key) + global_cache = global_cache_module.RedisCache.from_environment() + + with client_context.new(global_cache=global_cache).use(): + assert key.get().foo == 42 + assert global_cache.redis.get(cache_key) is not None + + assert key.delete() is None + assert global_cache.redis.get(cache_key) is None + + # This is py27 behavior. Not entirely sold on leaving _LOCKED value for + # Datastore misses. 
+ assert key.get() is None + assert global_cache.redis.get(cache_key) == b"0" + + @pytest.mark.usefixtures("client_context") def test_delete_entity_in_transaction(ds_entity): entity_id = test_utils.system.unique_resource_id() diff --git a/tests/unit/test_global_cache.py b/tests/unit/test_global_cache.py index ffd6409a..53b1535e 100644 --- a/tests/unit/test_global_cache.py +++ b/tests/unit/test_global_cache.py @@ -15,6 +15,7 @@ from unittest import mock import pytest +import redis as redis_module from google.cloud.ndb import global_cache @@ -144,3 +145,157 @@ def test_watch_compare_and_swap_with_expires(time): result = cache.get([b"one", b"two", b"three"]) assert result == [None, b"hamburgers", None] + + +class TestRedisCache: + @staticmethod + def test_constructor(): + redis = object() + cache = global_cache.RedisCache(redis) + assert cache.redis is redis + + @staticmethod + @mock.patch("google.cloud.ndb.global_cache.redis_module") + def test_from_environment(redis_module): + redis = redis_module.Redis.from_url.return_value + with mock.patch.dict("os.environ", {"REDIS_CACHE_URL": "some://url"}): + cache = global_cache.RedisCache.from_environment() + assert cache.redis is redis + redis_module.Redis.from_url.assert_called_once_with("some://url") + + @staticmethod + def test_from_environment_not_configured(): + with mock.patch.dict("os.environ", {"REDIS_CACHE_URL": ""}): + cache = global_cache.RedisCache.from_environment() + assert cache is None + + @staticmethod + def test_get(): + redis = mock.Mock(spec=("mget",)) + cache_keys = [object(), object()] + cache_value = redis.mget.return_value + cache = global_cache.RedisCache(redis) + assert cache.get(cache_keys) is cache_value + redis.mget.assert_called_once_with(cache_keys) + + @staticmethod + def test_set(): + redis = mock.Mock(spec=("mset",)) + cache_items = {"a": "foo", "b": "bar"} + cache = global_cache.RedisCache(redis) + cache.set(cache_items) + redis.mset.assert_called_once_with(cache_items) + + @staticmethod + 
def test_set_w_expires(): + expired = {} + + def mock_expire(key, expires): + expired[key] = expires + + redis = mock.Mock(expire=mock_expire, spec=("mset", "expire")) + cache_items = {"a": "foo", "b": "bar"} + cache = global_cache.RedisCache(redis) + cache.set(cache_items, expires=32) + redis.mset.assert_called_once_with(cache_items) + assert expired == {"a": 32, "b": 32} + + @staticmethod + def test_delete(): + redis = mock.Mock(spec=("delete",)) + cache_keys = [object(), object()] + cache = global_cache.RedisCache(redis) + cache.delete(cache_keys) + redis.delete.assert_called_once_with(*cache_keys) + + @staticmethod + @mock.patch("google.cloud.ndb.global_cache.uuid") + def test_watch(uuid): + uuid.uuid4.return_value = "abc123" + redis = mock.Mock( + pipeline=mock.Mock(spec=("watch",)), spec=("pipeline",) + ) + pipe = redis.pipeline.return_value + keys = ["foo", "bar"] + cache = global_cache.RedisCache(redis) + cache.watch(keys) + + pipe.watch.assert_called_once_with("foo", "bar") + assert cache.pipes == { + "foo": global_cache._Pipeline(pipe, "abc123"), + "bar": global_cache._Pipeline(pipe, "abc123"), + } + + @staticmethod + def test_compare_and_swap(): + redis = mock.Mock(spec=()) + cache = global_cache.RedisCache(redis) + pipe1 = mock.Mock(spec=("multi", "mset", "execute", "reset")) + pipe2 = mock.Mock(spec=("multi", "mset", "execute", "reset")) + cache.pipes = { + "ay": global_cache._Pipeline(pipe1, "abc123"), + "be": global_cache._Pipeline(pipe1, "abc123"), + "see": global_cache._Pipeline(pipe2, "def456"), + "dee": global_cache._Pipeline(pipe2, "def456"), + "whatevs": global_cache._Pipeline(None, "himom!"), + } + pipe2.execute.side_effect = redis_module.exceptions.WatchError + + items = {"ay": "foo", "be": "bar", "see": "baz", "wut": "huh?"} + cache.compare_and_swap(items) + + pipe1.multi.assert_called_once_with() + pipe2.multi.assert_called_once_with() + pipe1.mset.assert_called_once_with({"ay": "foo", "be": "bar"}) + 
pipe2.mset.assert_called_once_with({"see": "baz"}) + pipe1.execute.assert_called_once_with() + pipe2.execute.assert_called_once_with() + pipe1.reset.assert_called_once_with() + pipe2.reset.assert_called_once_with() + + assert cache.pipes == { + "whatevs": global_cache._Pipeline(None, "himom!") + } + + @staticmethod + def test_compare_and_swap_w_expires(): + expired = {} + + def mock_expire(key, expires): + expired[key] = expires + + redis = mock.Mock(spec=()) + cache = global_cache.RedisCache(redis) + pipe1 = mock.Mock( + expire=mock_expire, + spec=("multi", "mset", "execute", "expire", "reset"), + ) + pipe2 = mock.Mock( + expire=mock_expire, + spec=("multi", "mset", "execute", "expire", "reset"), + ) + cache.pipes = { + "ay": global_cache._Pipeline(pipe1, "abc123"), + "be": global_cache._Pipeline(pipe1, "abc123"), + "see": global_cache._Pipeline(pipe2, "def456"), + "dee": global_cache._Pipeline(pipe2, "def456"), + "whatevs": global_cache._Pipeline(None, "himom!"), + } + pipe2.execute.side_effect = redis_module.exceptions.WatchError + + items = {"ay": "foo", "be": "bar", "see": "baz", "wut": "huh?"} + cache.compare_and_swap(items, expires=32) + + pipe1.multi.assert_called_once_with() + pipe2.multi.assert_called_once_with() + pipe1.mset.assert_called_once_with({"ay": "foo", "be": "bar"}) + pipe2.mset.assert_called_once_with({"see": "baz"}) + pipe1.execute.assert_called_once_with() + pipe2.execute.assert_called_once_with() + pipe1.reset.assert_called_once_with() + pipe2.reset.assert_called_once_with() + + assert cache.pipes == { + "whatevs": global_cache._Pipeline(None, "himom!") + } + assert expired == {"ay": 32, "be": 32, "see": 32}