Skip to content

Commit

Permalink
chore: Add explicit bidirectional performant relationships for SQLA m…
Browse files Browse the repository at this point in the history
…odel (#22413)
  • Loading branch information
john-bodley authored Jan 20, 2023
1 parent 858c6e1 commit 92cdb8c
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 24 deletions.
37 changes: 27 additions & 10 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@
)
from sqlalchemy.engine.base import Connection
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm import (
backref,
Mapped,
Query,
relationship,
RelationshipProperty,
Session,
)
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
Expand Down Expand Up @@ -224,10 +231,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table: SqlaTable = relationship(
table: Mapped["SqlaTable"] = relationship(
"SqlaTable",
backref=backref("columns", cascade="all, delete-orphan"),
foreign_keys=[table_id],
back_populates="columns",
lazy="joined", # Eager loading for efficient parent referencing with selectin.
)
is_dttm = Column(Boolean, default=False)
expression = Column(MediumText())
Expand Down Expand Up @@ -439,10 +446,10 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
__tablename__ = "sql_metrics"
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table = relationship(
table: Mapped["SqlaTable"] = relationship(
"SqlaTable",
backref=backref("metrics", cascade="all, delete-orphan"),
foreign_keys=[table_id],
back_populates="metrics",
lazy="joined", # Eager loading for efficient parent referencing with selectin.
)
expression = Column(MediumText(), nullable=False)
extra = Column(Text)
Expand Down Expand Up @@ -535,13 +542,23 @@ def _process_sql_expression(


class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods
"""An ORM object for SqlAlchemy table references"""
"""An ORM object for SqlAlchemy table references."""

type = "table"
query_language = "sql"
is_rls_supported = True
columns: List[TableColumn] = []
metrics: List[SqlMetric] = []
columns: Mapped[List[TableColumn]] = relationship(
TableColumn,
back_populates="table",
cascade="all, delete-orphan",
lazy="selectin", # Only non-eager loading that works with bidirectional joined.
)
metrics: Mapped[List[SqlMetric]] = relationship(
SqlMetric,
back_populates="table",
cascade="all, delete-orphan",
lazy="selectin", # Only non-eager loading that works with bidirectional joined.
)
metric_class = SqlMetric
column_class = TableColumn
owner_class = security_manager.user_model
Expand Down
3 changes: 3 additions & 0 deletions superset/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ class DatasetRestApi(BaseSupersetModelRestApi):
DatasetDuplicateSchema,
)

list_outer_default_load = True
show_outer_default_load = True

@expose("/", methods=["POST"])
@protect()
@safe
Expand Down
4 changes: 2 additions & 2 deletions superset/examples/birth_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def _add_table_metrics(datasource: SqlaTable) -> None:
metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))

for col in columns:
if col.column_name == "ds":
col.is_dttm = True
if col.column_name == "ds": # type: ignore
col.is_dttm = True # type: ignore
break

datasource.columns = columns
Expand Down
5 changes: 0 additions & 5 deletions superset/models/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,6 @@ def export_dashboards( # pylint: disable=too-many-locals
remote_id=eager_datasource.id,
database_name=eager_datasource.database.name,
)
datasource_class = copied_datasource.__class__
for field_name in datasource_class.export_children:
field_val = getattr(eager_datasource, field_name).copy()
# set children without creating ORM relations
copied_datasource.__dict__[field_name] = field_val
eager_datasources.append(copied_datasource)

return json.dumps(
Expand Down
5 changes: 4 additions & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,10 @@ def import_from_dict(
# Recursively create children
if recursive:
for child in cls.export_children:
child_class = cls.__mapper__.relationships[child].argument.class_
argument = cls.__mapper__.relationships[child].argument
child_class = (
argument.class_ if hasattr(argument, "class_") else argument
)
added = []
for c_obj in new_children.get(child, []):
added.append(
Expand Down
6 changes: 4 additions & 2 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,9 +1315,10 @@ def test_import_chart(self):

chart.owners = []
dataset.owners = []
database.owners = []
db.session.delete(chart)
db.session.commit()
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()

Expand Down Expand Up @@ -1387,9 +1388,10 @@ def test_import_chart_overwrite(self):

chart.owners = []
dataset.owners = []
database.owners = []
db.session.delete(chart)
db.session.commit()
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()

Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1987,8 +1987,8 @@ def test_import_database(self):
assert str(dataset.uuid) == dataset_config["uuid"]

dataset.owners = []
database.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()

Expand Down Expand Up @@ -2058,8 +2058,8 @@ def test_import_database_overwrite(self):
)
dataset = database.tables[0]
dataset.owners = []
database.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()

Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,8 +1988,8 @@ def test_import_dataset(self):
assert str(dataset.uuid) == dataset_config["uuid"]

dataset.owners = []
database.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()

Expand Down Expand Up @@ -2090,8 +2090,8 @@ def test_import_dataset_overwrite(self):
dataset = database.tables[0]

dataset.owners = []
database.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()

Expand Down

0 comments on commit 92cdb8c

Please sign in to comment.