Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Return an empty infra object from sql registry when it doesn't exist #3022

Merged
merged 4 commits into from
Aug 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions sdk/python/feast/infra/registry_stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Set, Union
from typing import Any, Callable, List, Optional, Set, Union

from sqlalchemy import ( # type: ignore
BigInteger,
Expand Down Expand Up @@ -560,7 +560,7 @@ def update_infra(self, infra: Infra, project: str, commit: bool = True):
)

def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
return self._get_object(
infra_object = self._get_object(
managed_infra,
"infra_obj",
project,
Expand All @@ -570,6 +570,8 @@ def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
"infra_proto",
None,
)
infra_object = infra_object or InfraProto()
return Infra.from_proto(infra_object)

def apply_user_metadata(
self,
Expand Down Expand Up @@ -683,11 +685,18 @@ def commit(self):
pass

def _apply_object(
self, table, project: str, id_field_name, obj, proto_field_name, name=None
self,
table: Table,
project: str,
id_field_name,
obj,
proto_field_name,
name=None,
):
self._maybe_init_project_metadata(project)

name = name or obj.name
name = name or obj.name if hasattr(obj, "name") else None
assert name, f"name needs to be provided for {obj}"
with self.engine.connect() as conn:
update_datetime = datetime.utcnow()
update_time = int(update_datetime.timestamp())
Expand Down Expand Up @@ -749,7 +758,14 @@ def _maybe_init_project_metadata(self, project):
conn.execute(insert_stmt)
usage.set_current_project_uuid(new_project_uuid)

def _delete_object(self, table, name, project, id_field_name, not_found_exception):
def _delete_object(
self,
table: Table,
name: str,
project: str,
id_field_name: str,
not_found_exception: Optional[Callable],
):
with self.engine.connect() as conn:
stmt = delete(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
Expand All @@ -763,14 +779,14 @@ def _delete_object(self, table, name, project, id_field_name, not_found_exceptio

def _get_object(
self,
table,
name,
project,
proto_class,
python_class,
id_field_name,
proto_field_name,
not_found_exception,
table: Table,
name: str,
project: str,
proto_class: Any,
python_class: Any,
id_field_name: str,
proto_field_name: str,
not_found_exception: Optional[Callable],
):
self._maybe_init_project_metadata(project)

Expand All @@ -782,10 +798,18 @@ def _get_object(
if row:
_proto = proto_class.FromString(row[proto_field_name])
return python_class.from_proto(_proto)
raise not_found_exception(name, project)
if not_found_exception:
raise not_found_exception(name, project)
else:
return None

def _list_objects(
self, table, project, proto_class, python_class, proto_field_name
self,
table: Table,
project: str,
proto_class: Any,
python_class: Any,
proto_field_name: str,
):
self._maybe_init_project_metadata(project)
with self.engine.connect() as conn:
Expand Down