Skip to content

Commit

Permalink
feat: Add Async refresh to Sql Registry (#4251)
Browse files Browse the repository at this point in the history
* Add sql registry async refresh

Signed-off-by: Stanley Opara <[email protected]>

* make refresh code a daemon thread

Signed-off-by: Stanley Opara <[email protected]>

* Change RegistryConfig to cacheMode

Signed-off-by: Stanley Opara <[email protected]>

* Only run async when ttl > 0

Signed-off-by: Stanley Opara <[email protected]>

* make refresh async run in a loop

Signed-off-by: Stanley Opara <[email protected]>

* make refresh async run in a loop

Signed-off-by: Stanley Opara <[email protected]>

* Reorder async refresh call

Signed-off-by: Stanley Opara <[email protected]>

* Add documentation

Signed-off-by: Stanley Opara <[email protected]>

* Update test_universal_registry.py

Signed-off-by: Stanley Opara <[email protected]>

* Force rerun of tests

Signed-off-by: Stanley Opara <[email protected]>

* Force rerun of tests

Signed-off-by: Stanley Opara <[email protected]>

* Format repo config file

Signed-off-by: Stanley Opara <[email protected]>

---------

Signed-off-by: Stanley Opara <[email protected]>
Co-authored-by: Stanley Opara <[email protected]>
  • Loading branch information
stanconia and Stanley Opara authored Jul 9, 2024
1 parent cea52e9 commit f569786
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 32 deletions.
57 changes: 36 additions & 21 deletions sdk/python/feast/infra/registry/caching_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import atexit
import logging
import threading
from abc import abstractmethod
from datetime import timedelta
from threading import Lock
Expand All @@ -21,18 +23,18 @@


class CachingRegistry(BaseRegistry):
def __init__(
self,
project: str,
cache_ttl_seconds: int,
):
def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str):
self.cached_registry_proto = self.proto()
proto_registry_utils.init_project_metadata(self.cached_registry_proto, project)
self.cached_registry_proto_created = _utc_now()
self._refresh_lock = Lock()
self.cached_registry_proto_ttl = timedelta(
seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0
)
self.cache_mode = cache_mode
if cache_mode == "thread":
self._start_thread_async_refresh(cache_ttl_seconds)
atexit.register(self._exit_handler)

@abstractmethod
def _get_data_source(self, name: str, project: str) -> DataSource:
Expand Down Expand Up @@ -322,22 +324,35 @@ def refresh(self, project: Optional[str] = None):
self.cached_registry_proto_created = _utc_now()

def _refresh_cached_registry_if_necessary(self):
with self._refresh_lock:
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
_utc_now()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
if self.cache_mode == "sync":
with self._refresh_lock:
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
_utc_now()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
)
)
)
)
if expired:
logger.info("Registry cache expired, so refreshing")
self.refresh()

def _start_thread_async_refresh(self, cache_ttl_seconds):
self.refresh()
if cache_ttl_seconds <= 0:
return
self.registry_refresh_thread = threading.Timer(
cache_ttl_seconds, self._start_thread_async_refresh, [cache_ttl_seconds]
)
self.registry_refresh_thread.setDaemon(True)
self.registry_refresh_thread.start()

if expired:
logger.info("Registry cache expired, so refreshing")
self.refresh()
def _exit_handler(self):
self.registry_refresh_thread.cancel()
4 changes: 3 additions & 1 deletion sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def __init__(
)
metadata.create_all(self.engine)
super().__init__(
project=project, cache_ttl_seconds=registry_config.cache_ttl_seconds
project=project,
cache_ttl_seconds=registry_config.cache_ttl_seconds,
cache_mode=registry_config.cache_mode,
)

def teardown(self):
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/feast/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class RegistryConfig(FeastBaseModel):
sqlalchemy_config_kwargs: Dict[str, Any] = {}
""" Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """

cache_mode: StrictStr = "sync"
""" str: Cache mode type, Possible options are sync and thread(asynchronous caching using threading library)"""

@field_validator("path")
def validate_path(cls, path: str, values: ValidationInfo) -> str:
if values.data.get("registry_type") == "sql":
Expand Down
113 changes: 103 additions & 10 deletions sdk/python/tests/integration/registration/test_universal_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def minio_registry() -> Registry:
logger = logging.getLogger(__name__)


@pytest.fixture(scope="session")
@pytest.fixture(scope="function")
def pg_registry():
container = (
DockerContainer("postgres:latest")
Expand All @@ -137,6 +137,35 @@ def pg_registry():

container.start()

registry_config = _given_registry_config_for_pg_sql(container)

yield SqlRegistry(registry_config, "project", None)

container.stop()


@pytest.fixture(scope="function")
def pg_registry_async():
container = (
DockerContainer("postgres:latest")
.with_exposed_ports(5432)
.with_env("POSTGRES_USER", POSTGRES_USER)
.with_env("POSTGRES_PASSWORD", POSTGRES_PASSWORD)
.with_env("POSTGRES_DB", POSTGRES_DB)
)

container.start()

registry_config = _given_registry_config_for_pg_sql(container, 2, "thread")

yield SqlRegistry(registry_config, "project", None)

container.stop()


def _given_registry_config_for_pg_sql(
container, cache_ttl_seconds=2, cache_mode="sync"
):
log_string_to_wait_for = "database system is ready to accept connections"
waited = wait_for_logs(
container=container,
Expand All @@ -148,42 +177,57 @@ def pg_registry():
container_port = container.get_exposed_port(5432)
container_host = container.get_container_host_ip()

registry_config = RegistryConfig(
return RegistryConfig(
registry_type="sql",
cache_ttl_seconds=cache_ttl_seconds,
cache_mode=cache_mode,
# The `path` must include `+psycopg` in order for `sqlalchemy.create_engine()`
# to understand that we are using psycopg3.
path=f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{container_host}:{container_port}/{POSTGRES_DB}",
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
)


@pytest.fixture(scope="function")
def mysql_registry():
container = MySqlContainer("mysql:latest")
container.start()

registry_config = _given_registry_config_for_mysql(container)

yield SqlRegistry(registry_config, "project", None)

container.stop()


@pytest.fixture(scope="session")
def mysql_registry():
@pytest.fixture(scope="function")
def mysql_registry_async():
container = MySqlContainer("mysql:latest")
container.start()

# testing for the database to exist and ready to connect and start testing.
registry_config = _given_registry_config_for_mysql(container, 2, "thread")

yield SqlRegistry(registry_config, "project", None)

container.stop()


def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode="sync"):
import sqlalchemy

engine = sqlalchemy.create_engine(
container.get_connection_url(), pool_pre_ping=True
)
engine.connect()

registry_config = RegistryConfig(
return RegistryConfig(
registry_type="sql",
path=container.get_connection_url(),
cache_ttl_seconds=cache_ttl_seconds,
cache_mode=cache_mode,
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
)

yield SqlRegistry(registry_config, "project", None)

container.stop()


@pytest.fixture(scope="session")
def sqlite_registry():
Expand Down Expand Up @@ -269,6 +313,17 @@ def mock_remote_registry():
lazy_fixture("sqlite_registry"),
]

async_sql_fixtures = [
pytest.param(
lazy_fixture("pg_registry_async"),
marks=pytest.mark.xdist_group(name="pg_registry_async"),
),
pytest.param(
lazy_fixture("mysql_registry_async"),
marks=pytest.mark.xdist_group(name="mysql_registry_async"),
),
]


@pytest.mark.integration
@pytest.mark.parametrize("test_registry", all_fixtures)
Expand Down Expand Up @@ -999,6 +1054,44 @@ def test_registry_cache(test_registry):
test_registry.teardown()


@pytest.mark.integration
@pytest.mark.parametrize(
"test_registry",
async_sql_fixtures,
)
def test_registry_cache_thread_async(test_registry):
# Create Feature View
batch_source = FileSource(
name="test_source",
file_format=ParquetFormat(),
path="file://feast/*",
timestamp_field="ts_col",
created_timestamp_column="timestamp",
)

project = "project"

# Register data source
test_registry.apply_data_source(batch_source, project)
registry_data_sources_cached = test_registry.list_data_sources(
project, allow_cache=True
)
# async ttl yet to expire, so there will be a cache miss
assert len(registry_data_sources_cached) == 0

# Wait for cache to be refreshed
time.sleep(4)
# Now objects exist
registry_data_sources_cached = test_registry.list_data_sources(
project, allow_cache=True
)
assert len(registry_data_sources_cached) == 1
registry_data_source = registry_data_sources_cached[0]
assert registry_data_source == batch_source

test_registry.teardown()


@pytest.mark.integration
@pytest.mark.parametrize(
"test_registry",
Expand Down

0 comments on commit f569786

Please sign in to comment.