diff --git a/alchemiscale/base/client.py b/alchemiscale/base/client.py index 1c2bbca9..be5179af 100644 --- a/alchemiscale/base/client.py +++ b/alchemiscale/base/client.py @@ -8,19 +8,21 @@ import time import random from itertools import islice -from typing import List import json from urllib.parse import urljoin from functools import wraps import gzip +from pathlib import Path +import os +from typing import Union, Optional +from diskcache import Cache import requests import httpx from gufe.tokenization import GufeTokenizable, JSON_HANDLER -from ..models import Scope, ScopedKey -from ..storage.models import TaskHub, Task +from ..models import ScopedKey def json_to_gufe(jsondata): @@ -61,6 +63,8 @@ def __init__( api_url: str, identifier: str, key: str, + cache_directory: Optional[Union[Path, str]] = None, + cache_size_limit: int = 1073741824, max_retries: int = 5, retry_base_seconds: float = 2.0, retry_max_seconds: float = 60.0, @@ -76,6 +80,13 @@ def __init__( Identifier for the identity used for authentication. key Credential for the identity used for authentication. + cache_directory + Location of the cache directory as either a `pathlib.Path` or `str`. + If `None` is provided then the directory will be determined via the + `XDG_CACHE_HOME` environment variable or default to + `${HOME}/.cache/alchemiscale`. Defaults to `None`. + cache_size_limit + Maximum size of the client cache. Defaults to 1 GB. max_retries Maximum number of times to retry a request. In the case the API service is unresponsive an exponential backoff is applied with @@ -111,9 +122,39 @@ def __init__( self._session = None self._lock = None + if cache_size_limit < 0: + raise ValueError( + "`cache_size_limit` must be greater than or equal to zero." + ) + + self._cache = Cache( + self._determine_cache_dir(cache_directory), + size_limit=cache_size_limit, + eviction_policy="least-recently-used", + ) + + @staticmethod + def _determine_cache_dir(cache_directory: Optional[Union[Path, str]]): + if not (isinstance(cache_directory, (Path, str)) or cache_directory is None): + raise TypeError( + "`cache_directory` must be a `str`, `pathlib.Path`, or `None`." + ) + + if cache_directory is None: + default_dir = Path().home() / ".cache" + cache_directory = ( + Path(os.getenv("XDG_CACHE_HOME", default_dir)) / "alchemiscale" + ) + else: + cache_directory = Path(cache_directory) + + return cache_directory.absolute() + def _settings(self): return dict( api_url=self.api_url, + cache_directory=self._cache.directory, + cache_size_limit=self._cache.size_limit, identifier=self.identifier, key=self.key, max_retries=self.max_retries, @@ -361,7 +402,7 @@ def _get_resource(self, resource, params=None, compress=False): if not 200 <= resp.status_code < 300: try: detail = resp.json()["detail"] - except: + except Exception: detail = resp.text raise self._exception( f"Status Code {resp.status_code} : {resp.reason} : {detail}", @@ -396,7 +437,7 @@ async def _get_resource_async(self, resource, params=None, compress=False): if not 200 <= resp.status_code < 300: try: detail = resp.json()["detail"] - except: + except Exception: detail = resp.text raise self._exception( f"Status Code {resp.status_code} : {resp.reason_phrase} : {detail}", @@ -442,7 +483,7 @@ def _post(self, url, headers, data): if not 200 <= resp.status_code < 300: try: detail = resp.json()["detail"] - except: + except Exception: detail = resp.text raise self._exception( f"Status Code {resp.status_code} : {resp.reason} : {detail}", @@ -466,7 +507,7 @@ async def _post_resource_async(self, resource, data): if not 200 <= resp.status_code < 300: try: detail = resp.json()["detail"] - except: + except Exception: detail = resp.text raise self._exception( f"Status Code {resp.status_code} : {resp.reason_phrase} : {detail}", @@ -498,7 +539,6 @@ def _rich_waiting_columns(): @staticmethod def _rich_progress_columns(): from rich.progress import ( - Progress, SpinnerColumn, MofNCompleteColumn, TextColumn, diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index 83c06ff8..3d12cef5 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -493,6 +493,30 @@ def get_chemicalsystem_transformations( f"/chemicalsystems/{chemicalsystem}/transformations" ) + def _get_keyed_chain_resource(self, scopedkey: ScopedKey, get_content_function): + + content = None + + try: + cached_keyed_chain = self._cache.get(str(scopedkey), None).decode("utf-8") + content = json.loads(cached_keyed_chain, cls=JSON_HANDLER.decoder) + # JSON could not decode + except json.JSONDecodeError: + warn( + f"Error decoding cached {scopedkey.qualname} ({scopedkey}), deleting entry and retriving new content." + ) + self._cache.delete(str(scopedkey)) + # when trying to call the decode method with a None (i.e. cached entry not found) + except AttributeError: + pass + + if content is None: + content = get_content_function() + keyedchain_json = json.dumps(content, cls=JSON_HANDLER.encoder) + self._cache.add(str(scopedkey), keyedchain_json.encode("utf-8")) + + return KeyedChain(content).to_gufe() + @lru_cache(maxsize=100) def get_network( self, @@ -522,9 +546,12 @@ def get_network( """ + if isinstance(network, str): + network = ScopedKey.from_str(network) + def _get_network(): content = self._get_resource(f"/networks/{network}", compress=compress) - return KeyedChain(content).to_gufe() + return content if visualize: from rich.progress import Progress @@ -534,12 +561,12 @@ def _get_network(): f"Retrieving [bold]'{network}'[/bold]...", total=None ) - an = _get_network() + an = self._get_keyed_chain_resource(network, _get_network) progress.start_task(task) progress.update(task, total=1, completed=1) else: - an = _get_network() + an = self._get_keyed_chain_resource(network, _get_network) return an @lru_cache(maxsize=10000) @@ -571,11 +598,14 @@ def get_transformation( """ + if isinstance(transformation, str): + transformation = ScopedKey.from_str(transformation) + def _get_transformation(): content = self._get_resource( f"/transformations/{transformation}", compress=compress ) - return KeyedChain(content).to_gufe() + return content if visualize: from rich.progress import Progress @@ -585,11 +615,11 @@ def _get_transformation(): f"Retrieving [bold]'{transformation}'[/bold]...", total=None ) - tf = _get_transformation() + tf = self._get_keyed_chain_resource(transformation, _get_transformation) progress.start_task(task) progress.update(task, total=1, completed=1) else: - tf = _get_transformation() + tf = self._get_keyed_chain_resource(transformation, _get_transformation) return tf @@ -622,11 +652,14 @@ def get_chemicalsystem( """ + if isinstance(chemicalsystem, str): + chemicalsystem = ScopedKey.from_str(chemicalsystem) + def _get_chemicalsystem(): content = self._get_resource( f"/chemicalsystems/{chemicalsystem}", compress=compress ) - return KeyedChain(content).to_gufe() + return content if visualize: from rich.progress import Progress @@ -636,12 +669,12 @@ def _get_chemicalsystem(): f"Retrieving [bold]'{chemicalsystem}'[/bold]...", total=None ) - cs = _get_chemicalsystem() + cs = self._get_keyed_chain_resource(chemicalsystem, _get_chemicalsystem) progress.start_task(task) progress.update(task, total=1, completed=1) else: - cs = _get_chemicalsystem() + cs = self._get_keyed_chain_resource(chemicalsystem, _get_chemicalsystem) return cs @@ -1355,12 +1388,22 @@ def get_tasks_priority( async def _async_get_protocoldagresult( self, protocoldagresultref, transformation, route, compress ): - pdr_latin1_decoded = await self._get_resource_async( - f"/transformations/{transformation}/{route}/{protocoldagresultref}", - compress=compress, - ) + # check the disk cache for the PDR + if pdr_bytes := self._cache.get(str(protocoldagresultref)): + pass + else: + # query the alchemiscale server for the PDR + pdr_latin1_decoded = await self._get_resource_async( + f"/transformations/{transformation}/{route}/{protocoldagresultref}", + compress=compress, + ) + pdr_bytes = pdr_latin1_decoded[0].encode("latin-1") - pdr_bytes = pdr_latin1_decoded[0].encode("latin-1") + # add the resulting PDR to the cache + self._cache.add( + str(protocoldagresultref), + pdr_bytes, + ) try: # Attempt to decompress the ProtocolDAGResult object diff --git a/alchemiscale/tests/integration/interface/client/conftest.py b/alchemiscale/tests/integration/interface/client/conftest.py index 7364b5a6..d633855a 100644 --- a/alchemiscale/tests/integration/interface/client/conftest.py +++ b/alchemiscale/tests/integration/interface/client/conftest.py @@ -1,12 +1,10 @@ import pytest from copy import copy -from time import sleep import uvicorn -import requests from alchemiscale.settings import get_base_api_settings -from alchemiscale.base.api import get_n4js_depends, get_s3os_depends +from alchemiscale.base.api import get_s3os_depends from alchemiscale.interface import api, client from alchemiscale.tests.integration.interface.utils import get_user_settings_override @@ -47,13 +45,25 @@ def uvicorn_server(user_api): yield +@pytest.fixture(scope="session") +def cache_dir(tmp_path_factory): + cache_dir = tmp_path_factory.mktemp("alchemiscale-cache") + return cache_dir + + @pytest.fixture(scope="module") -def user_client(uvicorn_server, user_identity): - return client.AlchemiscaleClient( +def user_client(uvicorn_server, user_identity, cache_dir): + + test_client = client.AlchemiscaleClient( api_url="http://127.0.0.1:8000/", identifier=user_identity["identifier"], key=user_identity["key"], + cache_directory=cache_dir, + cache_size_limit=int(1073741824 / 4), ) + test_client._cache.stats(enable=True, reset=True) + + return test_client @pytest.fixture(scope="module") diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py index 09d1353b..9d13d4c7 100644 --- a/alchemiscale/tests/integration/interface/client/test_client.py +++ b/alchemiscale/tests/integration/interface/client/test_client.py @@ -27,6 +27,40 @@ class TestClient: + def test_cache_size_limit_negative( + self, user_client: client.AlchemiscaleBaseClient + ): + settings = user_client._settings() + settings["cache_size_limit"] = -1 + with pytest.raises( + ValueError, + match="`cache_size_limit` must be greater than or equal to zero.", + ): + client.AlchemiscaleClient(**settings) + + def test_cache_dir_not_path_str_none(self, user_client: client.AlchemiscaleClient): + settings = user_client._settings() + settings["cache_directory"] = 0 + with pytest.raises( + TypeError, + match="`cache_directory` must be a `str`, `pathlib.Path`, or `None`.", + ): + client.AlchemiscaleClient(**settings) + + # here we test the AlchemiscaleClient._determine_cache_dir + # so we don't create non-temporary files on the testing platform + def test_cache_dir_none(self): + # set custom XDG_CACHE_HOME + target_dir = Path().home() / ".other_cache" + os.environ["XDG_CACHE_HOME"] = str(target_dir) + cache_dir = client.AlchemiscaleClient._determine_cache_dir(None) + assert cache_dir == target_dir.absolute() / "alchemiscale" + + # remove the env variable to get the default directory location + os.environ.pop("XDG_CACHE_HOME", None) + cache_dir = client.AlchemiscaleClient._determine_cache_dir(None) + assert cache_dir == Path().home() / ".cache" / "alchemiscale" + def test_wrong_credential( self, scope_test, @@ -517,6 +551,49 @@ def test_get_network( assert an == network_tyk2 assert an is network_tyk2 + def test_cached_network( + self, + scope_test, + n4js_preloaded, + network_tyk2, + user_client: client.AlchemiscaleClient, + ): + # clear both the on-disk and in-memory cache + user_client._cache.clear() + user_client._cache.stats(reset=True) + user_client.get_network.cache_clear() + + an_sk = user_client.get_scoped_key(network_tyk2, scope_test) + + # reset stats of cache + assert user_client._cache.stats(enable=True, reset=True) == (0, 0) + + # expect a miss and entry in the cache + user_client.get_network(an_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + + # expect the in-memory lru cache to get the last result pulled + user_client.get_network(an_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + # clear in-memory cache + user_client.get_network.cache_clear() + + # expect a hit + user_client.get_network(an_sk) + assert user_client._cache.stats() == (1, 1) and len(user_client._cache) == 1 + + user_client.get_network.cache_clear() + + # manually invalidate the cached network so it won't deserialize + cached_bytes = user_client._cache.get(str(an_sk)) + corrupted_bytes = cached_bytes.replace(b":", b";") + user_client._cache.set(str(an_sk), corrupted_bytes) + with pytest.warns(UserWarning, match=f"Error decoding cached {an_sk.qualname}"): + user_client.get_network(an_sk) + + new_cached_bytes = user_client._cache.get(str(an_sk)) + assert new_cached_bytes != corrupted_bytes + def test_get_network_bad_network_key( self, scope_test: Scope, @@ -672,6 +749,49 @@ def test_get_transformation( assert tf == transformation assert tf is transformation + def test_cached_transformation( + self, + scope_test, + n4js_preloaded, + transformation, + user_client: client.AlchemiscaleClient, + ): + # clear both the on-disk and in-memory cache + user_client._cache.clear() + user_client._cache.stats(reset=True) + user_client.get_transformation.cache_clear() + + tf_sk = user_client.get_scoped_key(transformation, scope_test) + + # reset stats of cache + assert user_client._cache.stats(enable=True, reset=True) == (0, 0) + + # expect a miss and entry in the cache + user_client.get_transformation(tf_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + + # expect the in-memory lru cache to get the last result pulled + user_client.get_transformation(tf_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + # clear in-memory cache + user_client.get_transformation.cache_clear() + + # expect a hit + user_client.get_transformation(tf_sk) + assert user_client._cache.stats() == (1, 1) and len(user_client._cache) == 1 + + user_client.get_transformation.cache_clear() + + # manually invalidate the cached transformation so it won't deserialize + cached_bytes = user_client._cache.get(str(tf_sk)) + corrupted_bytes = cached_bytes.replace(b":", b";") + user_client._cache.set(str(tf_sk), corrupted_bytes) + with pytest.warns(UserWarning, match=f"Error decoding cached {tf_sk.qualname}"): + user_client.get_transformation(tf_sk) + + new_cached_bytes = user_client._cache.get(str(tf_sk)) + assert new_cached_bytes != corrupted_bytes + def test_get_transformation_bad_transformation_key( self, scope_test, n4js_preloaded, user_client ): @@ -698,6 +818,49 @@ def test_get_chemicalsystem( assert cs == chemicalsystem assert cs is chemicalsystem + def test_cached_chemicalsystem( + self, + scope_test, + n4js_preloaded, + chemicalsystem, + user_client: client.AlchemiscaleClient, + ): + # clear both the on-disk and in-memory cache + user_client._cache.clear() + user_client._cache.stats(reset=True) + user_client.get_chemicalsystem.cache_clear() + + cs_sk = user_client.get_scoped_key(chemicalsystem, scope_test) + + # reset stats of cache + assert user_client._cache.stats(enable=True, reset=True) == (0, 0) + + # expect a miss and entry in the cache + user_client.get_chemicalsystem(cs_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + + # expect the in-memory lru cache to get the last result pulled + user_client.get_chemicalsystem(cs_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + # clear in-memory cache + user_client.get_chemicalsystem.cache_clear() + + # expect a hit + user_client.get_chemicalsystem(cs_sk) + assert user_client._cache.stats() == (1, 1) and len(user_client._cache) == 1 + + user_client.get_chemicalsystem.cache_clear() + + # manually invalidate the cached ChemicalSystem so it won't deserialize + cached_bytes = user_client._cache.get(str(cs_sk)) + corrupted_bytes = cached_bytes.replace(b":", b";") + user_client._cache.set(str(cs_sk), corrupted_bytes) + with pytest.warns(UserWarning, match=f"Error decoding cached {cs_sk.qualname}"): + user_client.get_chemicalsystem(cs_sk) + + new_cached_bytes = user_client._cache.get(str(cs_sk)) + assert new_cached_bytes != corrupted_bytes + def test_get_chemicalsystem_bad_chemicalsystem_key( self, scope_test, n4js_preloaded, user_client ): @@ -1067,7 +1230,7 @@ def test_get_scope_status( other_scope = Scope("other_org", "other_campaign", "other_project") n4js_preloaded.assemble_network(network_tyk2, other_scope) other_tf_sk = n4js_preloaded.query_transformations(scope=other_scope)[0] - _ = n4js_preloaded.create_task(other_tf_sk) + n4js_preloaded.create_task(other_tf_sk) # ask for the scope that we don't have access to status_counts = user_client.get_scope_status(other_scope) @@ -1877,6 +2040,52 @@ def _execute_tasks(tasks, n4js, s3os_server): return protocoldagresults + def test_cached_pdr( + self, scope_test, n4js_preloaded, s3os_server, user_client, network_tyk2, tmpdir + ): + + user_client._cache.clear() + user_client._cache.stats(reset=True) + user_client._async_get_protocoldagresult.cache_clear() + + network_sk = user_client.get_scoped_key(network_tyk2, scope_test) + + transformation = list(t for t in network_tyk2.edges if "_solvent" in t.name)[0] + transformation_sk = user_client.get_scoped_key(transformation, scope_test) + + user_client.create_tasks(transformation_sk, count=3) + + all_tasks = user_client.get_transformation_tasks(transformation_sk) + actioned_tasks = user_client.action_tasks(all_tasks, network_sk) + + # execute the actioned tasks and push results directly using statestore and object store + with tmpdir.as_cwd(): + self._execute_tasks(actioned_tasks, n4js_preloaded, s3os_server) + + # make sure that we have reset all stats tracking before the intial pull + assert user_client._cache.stats(reset=True) == (0, 0) + + user_client.get_transformation_results(transformation_sk) + + # we expect four misses, but now the cache has length 4 + # this is because the cache also captures the transformation, not just the PDRs + assert user_client._cache.stats() == (0, 4) and len(user_client._cache) == 4 + + # clear the in-memory lru cache, to ensure we check the on-disk cache + user_client._async_get_protocoldagresult.cache_clear() + user_client.get_transformation.cache_clear() + + # running again should now pull results from the on-disk cache + user_client.get_transformation_results(transformation_sk) + + assert user_client._cache.stats() == (4, 4) and len(user_client._cache) == 4 + + # when the alru is not cleared, we should not see misses or hits on the disk cache + # since the alru should populate from the results found on disk + user_client.get_transformation_results(transformation_sk) + + assert user_client._cache.stats() == (4, 4) and len(user_client._cache) == 4 + @staticmethod def _push_result(task_scoped_key, protocoldagresult, n4js, s3os_server): transformation_sk, _ = n4js.get_task_transformation( diff --git a/devtools/conda-envs/alchemiscale-client.yml b/devtools/conda-envs/alchemiscale-client.yml index 1743ba49..2d2caa74 100644 --- a/devtools/conda-envs/alchemiscale-client.yml +++ b/devtools/conda-envs/alchemiscale-client.yml @@ -13,6 +13,7 @@ dependencies: - requests - click - httpx + - diskcache - pydantic >2 - pydantic-settings - async-lru diff --git a/devtools/conda-envs/test.yml b/devtools/conda-envs/test.yml index f1b21162..9304cc0c 100644 --- a/devtools/conda-envs/test.yml +++ b/devtools/conda-envs/test.yml @@ -12,6 +12,7 @@ dependencies: - pydantic >2 - pydantic-settings - async-lru + - diskcache - zstandard ## state store