Skip to content

Commit

Permalink
feat: Added support for reading from Reader Endpoints for AWS Aurora …
Browse files Browse the repository at this point in the history
…use cases (#4494)

fix: Resovled merge conflicts associated to new changes

Signed-off-by: Bhargav Dodla <[email protected]>
Co-authored-by: Bhargav Dodla <[email protected]>
  • Loading branch information
EXPEbdodla and Bhargav Dodla authored Sep 7, 2024
1 parent 87e7ca4 commit d793c77
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 20 deletions.
51 changes: 31 additions & 20 deletions sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ class SqlRegistryConfig(RegistryConfig):
""" str: Path to metadata store.
If registry_type is 'sql', then this is a database URL as expected by SQLAlchemy """

read_path: Optional[StrictStr] = None
""" str: Read Path to metadata store if different from path.
If registry_type is 'sql', then this is a Read Endpoint for database URL. If not set, path will be used for read and write. """

sqlalchemy_config_kwargs: Dict[str, Any] = {"echo": False}
""" Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """

Expand All @@ -223,13 +227,20 @@ def __init__(
registry_config, SqlRegistryConfig
), "SqlRegistry needs a valid registry_config"

self.engine: Engine = create_engine(
self.write_engine: Engine = create_engine(
registry_config.path, **registry_config.sqlalchemy_config_kwargs
)
if registry_config.read_path:
self.read_engine: Engine = create_engine(
registry_config.read_path,
**registry_config.sqlalchemy_config_kwargs,
)
else:
self.read_engine = self.write_engine
metadata.create_all(self.write_engine)
self.thread_pool_executor_worker_count = (
registry_config.thread_pool_executor_worker_count
)
metadata.create_all(self.engine)
self.purge_feast_metadata = registry_config.purge_feast_metadata
# Sync feast_metadata to projects table
# when purge_feast_metadata is set to True, Delete data from
Expand All @@ -246,7 +257,7 @@ def __init__(
def _sync_feast_metadata_to_projects_table(self):
feast_metadata_projects: set = []
projects_set: set = []
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value
)
Expand All @@ -255,7 +266,7 @@ def _sync_feast_metadata_to_projects_table(self):
feast_metadata_projects.append(row._mapping["project_id"])

if len(feast_metadata_projects) > 0:
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = select(projects)
rows = conn.execute(stmt).all()
for row in rows:
Expand All @@ -267,7 +278,7 @@ def _sync_feast_metadata_to_projects_table(self):
self.apply_project(Project(name=project_name), commit=True)

if self.purge_feast_metadata:
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
for project_name in feast_metadata_projects:
stmt = delete(feast_metadata).where(
feast_metadata.c.project_id == project_name
Expand All @@ -285,7 +296,7 @@ def teardown(self):
validation_references,
permissions,
}:
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = delete(t)
conn.execute(stmt)

Expand Down Expand Up @@ -549,7 +560,7 @@ def apply_feature_service(
)

def delete_data_source(self, name: str, project: str, commit: bool = True):
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = delete(data_sources).where(
data_sources.c.data_source_name == name,
data_sources.c.project_id == project,
Expand Down Expand Up @@ -607,7 +618,7 @@ def _list_on_demand_feature_views(
)

def _list_project_metadata(self, project: str) -> List[ProjectMetadata]:
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.project_id == project,
)
Expand Down Expand Up @@ -726,7 +737,7 @@ def apply_user_metadata(
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = select(table).where(
getattr(table.c, "feature_view_name") == name,
table.c.project_id == project,
Expand Down Expand Up @@ -781,7 +792,7 @@ def get_user_metadata(
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(table).where(getattr(table.c, "feature_view_name") == name)
row = conn.execute(stmt).first()
if row:
Expand Down Expand Up @@ -885,7 +896,7 @@ def _apply_object(
name = name or (obj.name if hasattr(obj, "name") else None)
assert name, f"name needs to be provided for {obj}"

with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
update_datetime = _utc_now()
update_time = int(update_datetime.timestamp())
stmt = select(table).where(
Expand Down Expand Up @@ -961,7 +972,7 @@ def _apply_object(

def _maybe_init_project_metadata(self, project):
# Initialize project metadata if needed
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
update_datetime = _utc_now()
update_time = int(update_datetime.timestamp())
stmt = select(feast_metadata).where(
Expand All @@ -988,7 +999,7 @@ def _delete_object(
id_field_name: str,
not_found_exception: Optional[Callable],
):
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = delete(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
Expand All @@ -1014,7 +1025,7 @@ def _get_object(
proto_field_name: str,
not_found_exception: Optional[Callable],
):
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
Expand All @@ -1036,7 +1047,7 @@ def _list_objects(
proto_field_name: str,
tags: Optional[dict[str, str]] = None,
):
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(table).where(table.c.project_id == project)
rows = conn.execute(stmt).all()
if rows:
Expand All @@ -1051,7 +1062,7 @@ def _list_objects(
return []

def _set_last_updated_metadata(self, last_updated: datetime, project: str):
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
Expand Down Expand Up @@ -1085,7 +1096,7 @@ def _set_last_updated_metadata(self, last_updated: datetime, project: str):
conn.execute(insert_stmt)

def _get_last_updated_metadata(self, project: str):
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
Expand Down Expand Up @@ -1130,7 +1141,7 @@ def apply_permission(
)

def delete_permission(self, name: str, project: str, commit: bool = True):
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = delete(permissions).where(
permissions.c.permission_name == name,
permissions.c.project_id == project,
Expand All @@ -1143,7 +1154,7 @@ def _list_projects(
self,
tags: Optional[dict[str, str]],
) -> List[Project]:
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(projects)
rows = conn.execute(stmt).all()
if rows:
Expand Down Expand Up @@ -1188,7 +1199,7 @@ def delete_project(
):
project = self.get_project(name, allow_cache=False)
if project:
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
for t in {
managed_infra,
saved_datasets,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,43 @@ def minio_registry(minio_server):
yield Registry("project", registry_config, None)


POSTGRES_READONLY_USER = "read_only_user"
POSTGRES_READONLY_PASSWORD = "readonly_password"

logger = logging.getLogger(__name__)


def add_pg_read_only_user(
container_host, container_port, db_name, postgres_user, postgres_password
):
# Connect to PostgreSQL as an admin
import psycopg

conn_string = f"dbname={db_name} user={postgres_user} password={postgres_password} host={container_host} port={container_port}"

with psycopg.connect(conn_string) as conn:
user_exists = conn.execute(
f"SELECT 1 FROM pg_catalog.pg_user WHERE usename = '{POSTGRES_READONLY_USER}'"
).fetchone()
if not user_exists:
conn.execute(
f"CREATE USER {POSTGRES_READONLY_USER} WITH PASSWORD '{POSTGRES_READONLY_PASSWORD}';"
)

conn.execute(
f"REVOKE ALL PRIVILEGES ON DATABASE {db_name} FROM {POSTGRES_READONLY_USER};"
)
conn.execute(
f"GRANT CONNECT ON DATABASE {db_name} TO {POSTGRES_READONLY_USER};"
)
conn.execute(
f"GRANT SELECT ON ALL TABLES IN SCHEMA public TO {POSTGRES_READONLY_USER};"
)
conn.execute(
f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO {POSTGRES_READONLY_USER};"
)


@pytest.fixture(scope="function")
def pg_registry(postgres_server):
db_name = "".join(random.choices(string.ascii_lowercase, k=10))
Expand All @@ -130,13 +164,22 @@ def pg_registry(postgres_server):
container_port = postgres_server.get_exposed_port(5432)
container_host = postgres_server.get_container_host_ip()

add_pg_read_only_user(
container_host,
container_port,
db_name,
postgres_server.username,
postgres_server.password,
)

registry_config = SqlRegistryConfig(
registry_type="sql",
cache_ttl_seconds=2,
cache_mode="sync",
# The `path` must include `+psycopg` in order for `sqlalchemy.create_engine()`
# to understand that we are using psycopg3.
path=f"postgresql+psycopg://{postgres_server.username}:{postgres_server.password}@{container_host}:{container_port}/{db_name}",
read_path=f"postgresql+psycopg://{POSTGRES_READONLY_USER}:{POSTGRES_READONLY_PASSWORD}@{container_host}:{container_port}/{db_name}",
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
thread_pool_executor_worker_count=0,
purge_feast_metadata=False,
Expand Down

0 comments on commit d793c77

Please sign in to comment.