From 92cdb8c282b9a7b83007940343410ad50b27b1b8 Mon Sep 17 00:00:00 2001
From: John Bodley <4567245+john-bodley@users.noreply.github.com>
Date: Sat, 21 Jan 2023 10:17:56 +1300
Subject: [PATCH] chore: Add explicit bidirectional performant relationships for SQLA model (#22413)

---
 superset/connectors/sqla/models.py             | 37 ++++++++++++++-----
 superset/datasets/api.py                       |  3 ++
 superset/examples/birth_names.py               |  4 +-
 superset/models/dashboard.py                   |  5 ---
 superset/models/helpers.py                     |  5 ++-
 tests/integration_tests/charts/api_tests.py    |  6 ++-
 .../integration_tests/databases/api_tests.py   |  4 +-
 tests/integration_tests/datasets/api_tests.py  |  4 +-
 8 files changed, 44 insertions(+), 24 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c5fd025f4ee2f..cffff7363055d 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -68,7 +68,14 @@
 )
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
+from sqlalchemy.orm import (
+    backref,
+    Mapped,
+    Query,
+    relationship,
+    RelationshipProperty,
+    Session,
+)
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import column, ColumnElement, literal_column, table
@@ -224,10 +231,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
     __tablename__ = "table_columns"
     __table_args__ = (UniqueConstraint("table_id", "column_name"),)
     table_id = Column(Integer, ForeignKey("tables.id"))
-    table: SqlaTable = relationship(
+    table: Mapped["SqlaTable"] = relationship(
         "SqlaTable",
-        backref=backref("columns", cascade="all, delete-orphan"),
-        foreign_keys=[table_id],
+        back_populates="columns",
+        lazy="joined",  # Eager loading for efficient parent referencing with selectin.
     )
     is_dttm = Column(Boolean, default=False)
     expression = Column(MediumText())
@@ -439,10 +446,10 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
     __tablename__ = "sql_metrics"
     __table_args__ = (UniqueConstraint("table_id", "metric_name"),)
     table_id = Column(Integer, ForeignKey("tables.id"))
-    table = relationship(
+    table: Mapped["SqlaTable"] = relationship(
         "SqlaTable",
-        backref=backref("metrics", cascade="all, delete-orphan"),
-        foreign_keys=[table_id],
+        back_populates="metrics",
+        lazy="joined",  # Eager loading for efficient parent referencing with selectin.
     )
     expression = Column(MediumText(), nullable=False)
     extra = Column(Text)
@@ -535,13 +542,23 @@ def _process_sql_expression(
 
 
 class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
-    """An ORM object for SqlAlchemy table references"""
+    """An ORM object for SqlAlchemy table references."""
 
     type = "table"
     query_language = "sql"
     is_rls_supported = True
-    columns: List[TableColumn] = []
-    metrics: List[SqlMetric] = []
+    columns: Mapped[List[TableColumn]] = relationship(
+        TableColumn,
+        back_populates="table",
+        cascade="all, delete-orphan",
+        lazy="selectin",  # Only non-eager loading that works with bidirectional joined.
+    )
+    metrics: Mapped[List[SqlMetric]] = relationship(
+        SqlMetric,
+        back_populates="table",
+        cascade="all, delete-orphan",
+        lazy="selectin",  # Only non-eager loading that works with bidirectional joined.
+    )
     metric_class = SqlMetric
     column_class = TableColumn
     owner_class = security_manager.user_model
diff --git a/superset/datasets/api.py b/superset/datasets/api.py
index 40efe08c69d8b..925c3c7cb8c71 100644
--- a/superset/datasets/api.py
+++ b/superset/datasets/api.py
@@ -242,6 +242,9 @@ class DatasetRestApi(BaseSupersetModelRestApi):
         DatasetDuplicateSchema,
     )
 
+    list_outer_default_load = True
+    show_outer_default_load = True
+
     @expose("/", methods=["POST"])
     @protect()
     @safe
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index daea2d6b0e641..8da041550e92a 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -144,8 +144,8 @@ def _add_table_metrics(datasource: SqlaTable) -> None:
         metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))
 
     for col in columns:
-        if col.column_name == "ds":
-            col.is_dttm = True
+        if col.column_name == "ds":  # type: ignore
+            col.is_dttm = True  # type: ignore
             break
 
     datasource.columns = columns
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index ae6bae4b733ff..106e92fe46298 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -423,11 +423,6 @@ def export_dashboards(  # pylint: disable=too-many-locals
                 remote_id=eager_datasource.id,
                 database_name=eager_datasource.database.name,
             )
-            datasource_class = copied_datasource.__class__
-            for field_name in datasource_class.export_children:
-                field_val = getattr(eager_datasource, field_name).copy()
-                # set children without creating ORM relations
-                copied_datasource.__dict__[field_name] = field_val
             eager_datasources.append(copied_datasource)
 
         return json.dumps(
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index fd0a1eff5ca7e..7ffd6083d190c 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -327,7 +327,10 @@ def import_from_dict(
         # Recursively create children
         if recursive:
             for child in cls.export_children:
-                child_class = cls.__mapper__.relationships[child].argument.class_
+                argument = cls.__mapper__.relationships[child].argument
+                child_class = (
+                    argument.class_ if hasattr(argument, "class_") else argument
+                )
                 added = []
                 for c_obj in new_children.get(child, []):
                     added.append(
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index 63f530ffa98a8..3d8a4695f4eb1 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -1315,9 +1315,10 @@ def test_import_chart(self):
 
         chart.owners = []
         dataset.owners = []
-        database.owners = []
         db.session.delete(chart)
+        db.session.commit()
         db.session.delete(dataset)
+        db.session.commit()
         db.session.delete(database)
         db.session.commit()
 
@@ -1387,9 +1388,10 @@ def test_import_chart_overwrite(self):
 
         chart.owners = []
         dataset.owners = []
-        database.owners = []
         db.session.delete(chart)
+        db.session.commit()
         db.session.delete(dataset)
+        db.session.commit()
         db.session.delete(database)
         db.session.commit()
 
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index d4e5fb4349456..ae01ccdaf9689 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -1987,8 +1987,8 @@ def test_import_database(self):
         assert str(dataset.uuid) == dataset_config["uuid"]
 
         dataset.owners = []
-        database.owners = []
         db.session.delete(dataset)
+        db.session.commit()
         db.session.delete(database)
         db.session.commit()
 
@@ -2058,8 +2058,8 @@ def test_import_database_overwrite(self):
         )
         dataset = database.tables[0]
         dataset.owners = []
-        database.owners = []
         db.session.delete(dataset)
+        db.session.commit()
         db.session.delete(database)
         db.session.commit()
 
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 4e566fc80dade..ff8206354c0d5 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -1988,8 +1988,8 @@ def test_import_dataset(self):
         assert str(dataset.uuid) == dataset_config["uuid"]
 
         dataset.owners = []
-        database.owners = []
         db.session.delete(dataset)
+        db.session.commit()
         db.session.delete(database)
         db.session.commit()
 
@@ -2090,8 +2090,8 @@ def test_import_dataset_overwrite(self):
 
         dataset = database.tables[0]
         dataset.owners = []
-        database.owners = []
         db.session.delete(dataset)
+        db.session.commit()
         db.session.delete(database)
         db.session.commit()
 
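Note (editorial, not part of the patch): the superset/connectors/sqla/models.py hunks above replace implicit backref relationships with explicit bidirectional ones, pairing lazy="joined" on the child-to-parent side with lazy="selectin" on the parent-to-children side. The sketch below is a minimal, self-contained illustration of that pattern; Parent and Child are hypothetical stand-ins for SqlaTable and TableColumn/SqlMetric, and it assumes SQLAlchemy 1.4.18 or later, where Mapped is importable from sqlalchemy.orm (as the patch itself does).

    # Hypothetical Parent/Child models illustrating the pattern; these names
    # and this schema are illustrative only and do not exist in Superset.
    from typing import List

    from sqlalchemy import Column, create_engine, ForeignKey, Integer, String
    from sqlalchemy.orm import declarative_base, Mapped, relationship, Session

    Base = declarative_base()


    class Parent(Base):
        __tablename__ = "parents"
        id = Column(Integer, primary_key=True)
        name = Column(String(250))
        children: Mapped[List["Child"]] = relationship(
            "Child",
            back_populates="parent",
            cascade="all, delete-orphan",
            lazy="selectin",  # children loaded up front via a second SELECT ... IN query
        )


    class Child(Base):
        __tablename__ = "children"
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey("parents.id"))
        parent: Mapped["Parent"] = relationship(
            "Parent",
            back_populates="children",
            lazy="joined",  # parent loaded eagerly via a LEFT OUTER JOIN
        )


    if __name__ == "__main__":
        engine = create_engine("sqlite://")
        Base.metadata.create_all(engine)
        with Session(engine) as session:
            session.add(Parent(name="example", children=[Child(), Child()]))
            session.commit()
            parent = session.query(Parent).one()  # children arrive via selectin
            # back_populates keeps both directions in sync within the session.
            assert all(child.parent is parent for child in parent.children)

With this pairing, querying a parent emits one extra SELECT ... IN statement for its children, and any load of a child joins its parent in the same statement, so neither direction falls back to per-row lazy loads.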