Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Async refresh to Sql Registry #4251

Merged
57 changes: 36 additions & 21 deletions sdk/python/feast/infra/registry/caching_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import atexit
import logging
import threading
from abc import abstractmethod
from datetime import timedelta
from threading import Lock
Expand All @@ -21,18 +23,18 @@


class CachingRegistry(BaseRegistry):
def __init__(
self,
project: str,
cache_ttl_seconds: int,
):
def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str):
self.cached_registry_proto = self.proto()
proto_registry_utils.init_project_metadata(self.cached_registry_proto, project)
self.cached_registry_proto_created = _utc_now()
self._refresh_lock = Lock()
self.cached_registry_proto_ttl = timedelta(
seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0
)
self.cache_mode = cache_mode
if cache_mode == "thread":
self._start_thread_async_refresh(cache_ttl_seconds)
atexit.register(self._exit_handler)

@abstractmethod
def _get_data_source(self, name: str, project: str) -> DataSource:
Expand Down Expand Up @@ -322,22 +324,35 @@ def refresh(self, project: Optional[str] = None):
self.cached_registry_proto_created = _utc_now()

def _refresh_cached_registry_if_necessary(self):
with self._refresh_lock:
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
_utc_now()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
if self.cache_mode == "sync":
with self._refresh_lock:
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
_utc_now()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
)
)
)
)
if expired:
logger.info("Registry cache expired, so refreshing")
self.refresh()

def _start_thread_async_refresh(self, cache_ttl_seconds):
self.refresh()
if cache_ttl_seconds <= 0:
return
self.registry_refresh_thread = threading.Timer(
stanconia marked this conversation as resolved.
Show resolved Hide resolved
cache_ttl_seconds, self._start_thread_async_refresh, [cache_ttl_seconds]
)
self.registry_refresh_thread.setDaemon(True)
self.registry_refresh_thread.start()

if expired:
logger.info("Registry cache expired, so refreshing")
self.refresh()
def _exit_handler(self):
self.registry_refresh_thread.cancel()
4 changes: 3 additions & 1 deletion sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def __init__(
)
metadata.create_all(self.engine)
super().__init__(
project=project, cache_ttl_seconds=registry_config.cache_ttl_seconds
project=project,
cache_ttl_seconds=registry_config.cache_ttl_seconds,
cache_mode=registry_config.cache_mode,
)

def teardown(self):
Expand Down
4 changes: 3 additions & 1 deletion sdk/python/feast/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class RegistryConfig(FeastBaseModel):
sqlalchemy_config_kwargs: Dict[str, Any] = {}
""" Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """

cache_mode: StrictStr = "sync"
stanconia marked this conversation as resolved.
Show resolved Hide resolved
""" str: Cache mode type, Possible options are sync and thread(asynchronous caching using threading library)"""

@field_validator("path")
def validate_path(cls, path: str, values: ValidationInfo) -> str:
if values.data.get("registry_type") == "sql":
Expand All @@ -139,7 +142,6 @@ def validate_path(cls, path: str, values: ValidationInfo) -> str:
return path.replace("postgresql://", "postgresql+psycopg://")
return path


class RepoConfig(FeastBaseModel):
"""Repo config. Typically loaded from `feature_store.yaml`"""

Expand Down
113 changes: 103 additions & 10 deletions sdk/python/tests/integration/registration/test_universal_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def minio_registry() -> Registry:
logger = logging.getLogger(__name__)


@pytest.fixture(scope="session")
@pytest.fixture(scope="function")
def pg_registry():
container = (
DockerContainer("postgres:latest")
Expand All @@ -137,6 +137,35 @@ def pg_registry():

container.start()

registry_config = _given_registry_config_for_pg_sql(container)

yield SqlRegistry(registry_config, "project", None)

container.stop()


@pytest.fixture(scope="function")
def pg_registry_async():
container = (
DockerContainer("postgres:latest")
.with_exposed_ports(5432)
.with_env("POSTGRES_USER", POSTGRES_USER)
.with_env("POSTGRES_PASSWORD", POSTGRES_PASSWORD)
.with_env("POSTGRES_DB", POSTGRES_DB)
)

container.start()

registry_config = _given_registry_config_for_pg_sql(container, 2, "thread")

yield SqlRegistry(registry_config, "project", None)

container.stop()


def _given_registry_config_for_pg_sql(
container, cache_ttl_seconds=2, cache_mode="sync"
):
log_string_to_wait_for = "database system is ready to accept connections"
waited = wait_for_logs(
container=container,
Expand All @@ -148,42 +177,57 @@ def pg_registry():
container_port = container.get_exposed_port(5432)
container_host = container.get_container_host_ip()

registry_config = RegistryConfig(
return RegistryConfig(
registry_type="sql",
cache_ttl_seconds=cache_ttl_seconds,
cache_mode=cache_mode,
# The `path` must include `+psycopg` in order for `sqlalchemy.create_engine()`
# to understand that we are using psycopg3.
path=f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{container_host}:{container_port}/{POSTGRES_DB}",
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
)


@pytest.fixture(scope="function")
def mysql_registry():
container = MySqlContainer("mysql:latest")
container.start()

registry_config = _given_registry_config_for_mysql(container)

yield SqlRegistry(registry_config, "project", None)

container.stop()


@pytest.fixture(scope="session")
def mysql_registry():
@pytest.fixture(scope="function")
def mysql_registry_async():
container = MySqlContainer("mysql:latest")
container.start()

# testing for the database to exist and ready to connect and start testing.
registry_config = _given_registry_config_for_mysql(container, 2, "thread")

yield SqlRegistry(registry_config, "project", None)

container.stop()


def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode="sync"):
import sqlalchemy

engine = sqlalchemy.create_engine(
container.get_connection_url(), pool_pre_ping=True
)
engine.connect()

registry_config = RegistryConfig(
return RegistryConfig(
registry_type="sql",
path=container.get_connection_url(),
cache_ttl_seconds=cache_ttl_seconds,
cache_mode=cache_mode,
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
)

yield SqlRegistry(registry_config, "project", None)

container.stop()


@pytest.fixture(scope="session")
def sqlite_registry():
Expand Down Expand Up @@ -269,6 +313,17 @@ def mock_remote_registry():
lazy_fixture("sqlite_registry"),
]

async_sql_fixtures = [
pytest.param(
lazy_fixture("pg_registry_async"),
marks=pytest.mark.xdist_group(name="pg_registry_async"),
),
pytest.param(
lazy_fixture("mysql_registry_async"),
marks=pytest.mark.xdist_group(name="mysql_registry_async"),
),
]


@pytest.mark.integration
@pytest.mark.parametrize("test_registry", all_fixtures)
Expand Down Expand Up @@ -999,6 +1054,44 @@ def test_registry_cache(test_registry):
test_registry.teardown()


@pytest.mark.integration
@pytest.mark.parametrize(
"test_registry",
async_sql_fixtures,
)
def test_registry_cache_thread_async(test_registry):
# Create Feature View
batch_source = FileSource(
name="test_source",
file_format=ParquetFormat(),
path="file://feast/*",
timestamp_field="ts_col",
created_timestamp_column="timestamp",
)

project = "project"

# Register data source
test_registry.apply_data_source(batch_source, project)
registry_data_sources_cached = test_registry.list_data_sources(
project, allow_cache=True
)
# async ttl yet to expire, so cache miss
assert len(registry_data_sources_cached) == 0

# Wait for cache to be refreshed
time.sleep(4)
# Now objects exist
registry_data_sources_cached = test_registry.list_data_sources(
project, allow_cache=True
)
assert len(registry_data_sources_cached) == 1
registry_data_source = registry_data_sources_cached[0]
assert registry_data_source == batch_source

test_registry.teardown()


@pytest.mark.integration
@pytest.mark.parametrize(
"test_registry",
Expand Down
Loading