Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Upgrade sqlalchemy from 1.x to 2.x regarding PVE-2022-51668. #4065

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 24 additions & 21 deletions sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def teardown(self):
saved_datasets,
validation_references,
}:
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = delete(t)
conn.execute(stmt)

Expand Down Expand Up @@ -399,7 +399,7 @@ def apply_feature_service(
)

def delete_data_source(self, name: str, project: str, commit: bool = True):
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = delete(data_sources).where(
data_sources.c.data_source_name == name,
data_sources.c.project_id == project,
Expand Down Expand Up @@ -441,16 +441,19 @@ def _list_on_demand_feature_views(self, project: str) -> List[OnDemandFeatureVie
)

def _list_project_metadata(self, project: str) -> List[ProjectMetadata]:
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.project_id == project,
)
rows = conn.execute(stmt).all()
if rows:
project_metadata = ProjectMetadata(project_name=project)
for row in rows:
if row["metadata_key"] == FeastMetadataKeys.PROJECT_UUID.value:
project_metadata.project_uuid = row["metadata_value"]
if (
row._mapping["metadata_key"]
== FeastMetadataKeys.PROJECT_UUID.value
):
project_metadata.project_uuid = row._mapping["metadata_value"]
break
# TODO(adchia): Add other project metadata in a structured way
return [project_metadata]
Expand Down Expand Up @@ -557,7 +560,7 @@ def apply_user_metadata(
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(table).where(
getattr(table.c, "feature_view_name") == name,
table.c.project_id == project,
Expand Down Expand Up @@ -612,11 +615,11 @@ def get_user_metadata(
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(table).where(getattr(table.c, "feature_view_name") == name)
row = conn.execute(stmt).first()
if row:
return row["user_metadata"]
return row._mapping["user_metadata"]
else:
raise FeatureViewNotFoundException(feature_view.name, project=project)

Expand Down Expand Up @@ -674,7 +677,7 @@ def _apply_object(
name = name or (obj.name if hasattr(obj, "name") else None)
assert name, f"name needs to be provided for {obj}"

with self.engine.connect() as conn:
with self.engine.begin() as conn:
update_datetime = datetime.utcnow()
update_time = int(update_datetime.timestamp())
stmt = select(table).where(
Expand Down Expand Up @@ -723,7 +726,7 @@ def _apply_object(

def _maybe_init_project_metadata(self, project):
# Initialize project metadata if needed
with self.engine.connect() as conn:
with self.engine.begin() as conn:
update_datetime = datetime.utcnow()
update_time = int(update_datetime.timestamp())
stmt = select(feast_metadata).where(
Expand All @@ -732,7 +735,7 @@ def _maybe_init_project_metadata(self, project):
)
row = conn.execute(stmt).first()
if row:
usage.set_current_project_uuid(row["metadata_value"])
usage.set_current_project_uuid(row._mapping["metadata_value"])
else:
new_project_uuid = f"{uuid.uuid4()}"
values = {
Expand All @@ -753,7 +756,7 @@ def _delete_object(
id_field_name: str,
not_found_exception: Optional[Callable],
):
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = delete(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
Expand All @@ -777,13 +780,13 @@ def _get_object(
):
self._maybe_init_project_metadata(project)

with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
row = conn.execute(stmt).first()
if row:
_proto = proto_class.FromString(row[proto_field_name])
_proto = proto_class.FromString(row._mapping[proto_field_name])
return python_class.from_proto(_proto)
if not_found_exception:
raise not_found_exception(name, project)
Expand All @@ -799,20 +802,20 @@ def _list_objects(
proto_field_name: str,
):
self._maybe_init_project_metadata(project)
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(table).where(table.c.project_id == project)
rows = conn.execute(stmt).all()
if rows:
return [
python_class.from_proto(
proto_class.FromString(row[proto_field_name])
proto_class.FromString(row._mapping[proto_field_name])
)
for row in rows
]
return []

def _set_last_updated_metadata(self, last_updated: datetime, project: str):
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
Expand Down Expand Up @@ -846,7 +849,7 @@ def _set_last_updated_metadata(self, last_updated: datetime, project: str):
conn.execute(insert_stmt)

def _get_last_updated_metadata(self, project: str):
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
Expand All @@ -855,13 +858,13 @@ def _get_last_updated_metadata(self, project: str):
row = conn.execute(stmt).first()
if not row:
return None
update_time = int(row["last_updated_timestamp"])
update_time = int(row._mapping["last_updated_timestamp"])

return datetime.utcfromtimestamp(update_time)

def _get_all_projects(self) -> Set[str]:
projects = set()
with self.engine.connect() as conn:
with self.engine.begin() as conn:
for table in {
entities,
data_sources,
Expand All @@ -872,6 +875,6 @@ def _get_all_projects(self) -> Set[str]:
stmt = select(table)
rows = conn.execute(stmt).all()
for row in rows:
projects.add(row["project_id"])
projects.add(row._mapping["project_id"])

return projects
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
"pygments>=2.12.0,<3",
"PyYAML>=5.4.0,<7",
"requests",
"SQLAlchemy[mypy]>1,<2",
"SQLAlchemy[mypy]>1",
"tabulate>=0.8.0,<1",
"tenacity>=7,<9",
"toml>=0.10.0,<1",
Expand Down
Loading