diff --git a/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py
index 35419e0066cd4..dc759693293ae 100644
--- a/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py
+++ b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py
@@ -30,11 +30,12 @@ import sqlalchemy as sa
 from alembic import op
-from sqlalchemy import and_, inspect, or_
+from sqlalchemy import inspect
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import backref, relationship, Session
 from sqlalchemy.schema import UniqueConstraint
+from sqlalchemy.sql.expression import and_, desc, or_
 from sqlalchemy_utils import UUIDType
 
 from superset import app, db
@@ -118,16 +119,12 @@ class SqlaTable(Base):
     __tablename__ = "tables"
     __table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),)
 
-    def fetch_columns_and_metrics(self, session: Session) -> None:
-        self.columns = session.query(TableColumn).filter(
-            TableColumn.table_id == self.id
-        )
-        self.metrics = session.query(SqlMetric).filter(TableColumn.table_id == self.id)
-
     id = sa.Column(sa.Integer, primary_key=True)
-    columns: List[TableColumn] = []
+    # Eager load related columns/metrics with `selectin`. Don't use `joined`
+    # as it will bloat the memory very fast for large datasets.
+    columns: List[TableColumn] = relationship(TableColumn, lazy="selectin")
+    metrics: List[SqlMetric] = relationship(SqlMetric, lazy="selectin")
     column_class = TableColumn
-    metrics: List[SqlMetric] = []
     metric_class = SqlMetric
 
     database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False)
@@ -232,13 +229,14 @@ def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
     """
     Copy old datasets to the new models.
""" - session = inspect(target).session + session: Session = inspect(target).session + database_id = target.database_id + is_physical_table = not target.sql # get DB-specific conditional quoter for expressions that point to columns or # table names database = ( - target.database - or session.query(Database).filter_by(id=target.database_id).first() + target.database or session.query(Database).filter_by(id=database_id).first() ) if not database: return @@ -262,18 +260,21 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals columns.append( NewColumn( name=column.column_name, - type=column.type or "Unknown", - expression=column.expression or conditional_quote(column.column_name), description=column.description, - is_temporal=column.is_dttm, + external_url=target.external_url, + expression=column.expression or conditional_quote(column.column_name), + extra_json=json.dumps(extra_json) if extra_json else None, is_aggregation=False, - is_physical=column.expression is None or column.expression == "", - is_spatial=False, - is_partition=False, is_increase_desired=True, - extra_json=json.dumps(extra_json) if extra_json else None, is_managed_externally=target.is_managed_externally, - external_url=target.external_url, + is_temporal=column.is_dttm, + is_spatial=False, + is_partition=False, + is_physical=( + is_physical_table + and (column.expression is None or column.expression == "") + ), + type=column.type or "Unknown", ), ) @@ -292,53 +293,45 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals columns.append( NewColumn( name=metric.metric_name, - type="Unknown", # figuring this out would require a type inferrer - expression=metric.expression, - warning_text=metric.warning_text, description=metric.description, + expression=metric.expression, + external_url=target.external_url, + extra_json=json.dumps(extra_json) if extra_json else None, is_aggregation=True, is_additive=is_additive, - is_physical=False, - is_spatial=False, - is_partition=False, is_increase_desired=True, - extra_json=json.dumps(extra_json) if extra_json else None, is_managed_externally=target.is_managed_externally, - external_url=target.external_url, + is_partition=False, + is_physical=False, + is_spatial=False, + is_temporal=False, + type="Unknown", # figuring this out would require a type inferrer + warning_text=metric.warning_text, ), ) - # physical dataset - tables = [] - if target.sql is None: - physical_columns = [column for column in columns if column.is_physical] - - # create table + if is_physical_table: + # create physical sl_table table = NewTable( name=target.table_name, schema=target.schema, catalog=None, # currently not supported - database_id=target.database_id, - columns=physical_columns, + database_id=database_id, + # only save physical columns + columns=[column for column in columns if column.is_physical], is_managed_externally=target.is_managed_externally, external_url=target.external_url, ) - tables.append(table) - - # virtual dataset + tables = [table] + expression = conditional_quote(target.table_name) else: - # mark all columns as virtual (not physical) - for column in columns: - column.is_physical = False - - # find referenced tables + # find referenced tables and link to dataset parsed = ParsedQuery(target.sql) referenced_tables = parsed.tables - - # predicate for finding the referenced tables predicate = or_( *[ and_( + NewTable.database_id == database_id, NewTable.schema == (table.schema or target.schema), NewTable.name == table.table, ) @@ -346,24 
             ]
         )
         tables = session.query(NewTable).filter(predicate).all()
+        expression = target.sql
 
     # create the new dataset
     dataset = NewDataset(
         sqlatable_id=target.id,
         name=target.table_name,
-        expression=target.sql or conditional_quote(target.table_name),
+        expression=expression,
         tables=tables,
         columns=columns,
-        is_physical=target.sql is None,
+        is_physical=is_physical_table,
         is_managed_externally=target.is_managed_externally,
         external_url=target.external_url,
     )
     session.add(dataset)
 
 
-def upgrade():
-    # Create tables for the new models.
-    op.create_table(
+new_tables = [
+    (
         "sl_columns",
         # AuditMixinNullable
         sa.Column("created_on", sa.DateTime(), nullable=True),
@@ -373,52 +366,26 @@ def upgrade():
         # ExtraJSONMixin
         sa.Column("extra_json", sa.Text(), nullable=True),
         # ImportExportMixin
-        sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4),
+        sa.Column(
+            "uuid", UUIDType(binary=True), primary_key=False, default=uuid4, unique=True
+        ),
         # Column
-        sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
+        sa.Column(
+            "id", sa.INTEGER(), primary_key=True, autoincrement=True, nullable=False
+        ),
         sa.Column("name", sa.TEXT(), nullable=False),
         sa.Column("type", sa.TEXT(), nullable=False),
         sa.Column("expression", sa.TEXT(), nullable=False),
-        sa.Column(
-            "is_physical",
-            sa.BOOLEAN(),
-            nullable=False,
-            default=True,
-        ),
+        sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=True),
         sa.Column("description", sa.TEXT(), nullable=True),
         sa.Column("warning_text", sa.TEXT(), nullable=True),
         sa.Column("unit", sa.TEXT(), nullable=True),
         sa.Column("is_temporal", sa.BOOLEAN(), nullable=False),
-        sa.Column(
-            "is_spatial",
-            sa.BOOLEAN(),
-            nullable=False,
-            default=False,
-        ),
-        sa.Column(
-            "is_partition",
-            sa.BOOLEAN(),
-            nullable=False,
-            default=False,
-        ),
-        sa.Column(
-            "is_aggregation",
-            sa.BOOLEAN(),
-            nullable=False,
-            default=False,
-        ),
-        sa.Column(
-            "is_additive",
-            sa.BOOLEAN(),
-            nullable=False,
-            default=False,
-        ),
-        sa.Column(
-            "is_increase_desired",
-            sa.BOOLEAN(),
-            nullable=False,
-            default=True,
-        ),
+        sa.Column("is_spatial", sa.BOOLEAN(), nullable=False, default=False),
+        sa.Column("is_partition", sa.BOOLEAN(), nullable=False, default=False),
+        sa.Column("is_aggregation", sa.BOOLEAN(), nullable=False, default=False),
+        sa.Column("is_additive", sa.BOOLEAN(), nullable=False, default=False),
+        sa.Column("is_increase_desired", sa.BOOLEAN(), nullable=False, default=True),
         sa.Column(
             "is_managed_externally",
             sa.Boolean(),
@@ -426,12 +393,8 @@ def upgrade():
             server_default=sa.false(),
         ),
         sa.Column("external_url", sa.Text(), nullable=True),
-        sa.PrimaryKeyConstraint("id"),
-    )
-    with op.batch_alter_table("sl_columns") as batch_op:
-        batch_op.create_unique_constraint("uq_sl_columns_uuid", ["uuid"])
-
-    op.create_table(
+    ),
+    (
         "sl_tables",
         # AuditMixinNullable
         sa.Column("created_on", sa.DateTime(), nullable=True),
@@ -441,10 +404,20 @@ def upgrade():
         # ExtraJSONMixin
         sa.Column("extra_json", sa.Text(), nullable=True),
         # ImportExportMixin
-        sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4),
+        sa.Column(
+            "uuid", UUIDType(binary=True), primary_key=False, default=uuid4, unique=True
+        ),
         # Table
-        sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
-        sa.Column("database_id", sa.INTEGER(), autoincrement=False, nullable=False),
+        sa.Column(
+            "id", sa.INTEGER(), primary_key=True, autoincrement=True, nullable=False
+        ),
+        sa.Column(
+            "database_id",
+            sa.INTEGER(),
+            sa.ForeignKey("dbs.id"),
+            autoincrement=False,
+            nullable=False,
+        ),
         sa.Column("catalog", sa.TEXT(), nullable=True),
         sa.Column("schema", sa.TEXT(), nullable=True),
         sa.Column("name", sa.TEXT(), nullable=False),
@@ -455,25 +428,8 @@ def upgrade():
             server_default=sa.false(),
         ),
         sa.Column("external_url", sa.Text(), nullable=True),
-        sa.ForeignKeyConstraint(["database_id"], ["dbs.id"], name="sl_tables_ibfk_1"),
-        sa.PrimaryKeyConstraint("id"),
-    )
-    with op.batch_alter_table("sl_tables") as batch_op:
-        batch_op.create_unique_constraint("uq_sl_tables_uuid", ["uuid"])
-
-    op.create_table(
-        "sl_table_columns",
-        sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False),
-        sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False),
-        sa.ForeignKeyConstraint(
-            ["column_id"], ["sl_columns.id"], name="sl_table_columns_ibfk_2"
-        ),
-        sa.ForeignKeyConstraint(
-            ["table_id"], ["sl_tables.id"], name="sl_table_columns_ibfk_1"
-        ),
-    )
-
-    op.create_table(
+    ),
+    (
         "sl_datasets",
         # AuditMixinNullable
         sa.Column("created_on", sa.DateTime(), nullable=True),
@@ -483,18 +439,23 @@ def upgrade():
         # ExtraJSONMixin
         sa.Column("extra_json", sa.Text(), nullable=True),
         # ImportExportMixin
-        sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4),
+        sa.Column(
+            "uuid", UUIDType(binary=True), primary_key=False, default=uuid4, unique=True
+        ),
         # Dataset
-        sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
-        sa.Column("sqlatable_id", sa.INTEGER(), nullable=True),
-        sa.Column("name", sa.TEXT(), nullable=False),
-        sa.Column("expression", sa.TEXT(), nullable=False),
         sa.Column(
-            "is_physical",
-            sa.BOOLEAN(),
-            nullable=False,
-            default=False,
+            "id", sa.INTEGER(), primary_key=True, autoincrement=True, nullable=False
+        ),
+        sa.Column(
+            "sqlatable_id",
+            sa.INTEGER(),
+            sa.ForeignKey("tables.id"),
+            unique=True,
+            nullable=True,
         ),
+        sa.Column("name", sa.TEXT(), nullable=False),
+        sa.Column("expression", sa.TEXT(), nullable=False),
+        sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=False),
         sa.Column(
             "is_managed_externally",
             sa.Boolean(),
@@ -502,46 +463,95 @@ def upgrade():
             server_default=sa.false(),
         ),
         sa.Column("external_url", sa.Text(), nullable=True),
-        sa.PrimaryKeyConstraint("id"),
-    )
-    with op.batch_alter_table("sl_datasets") as batch_op:
-        batch_op.create_unique_constraint("uq_sl_datasets_uuid", ["uuid"])
-        batch_op.create_unique_constraint(
-            "uq_sl_datasets_sqlatable_id", ["sqlatable_id"]
-        )
-
-    op.create_table(
+    ),
+    # Relationships...
+    (
+        "sl_table_columns",
+        sa.Column(
+            "table_id",
+            sa.INTEGER(),
+            sa.ForeignKey("sl_tables.id"),
+            autoincrement=False,
+            nullable=False,
+        ),
+        sa.Column(
+            "column_id",
+            sa.INTEGER(),
+            sa.ForeignKey("sl_columns.id"),
+            autoincrement=False,
+            nullable=False,
+        ),
+    ),
+    (
         "sl_dataset_columns",
-        sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False),
-        sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False),
-        sa.ForeignKeyConstraint(
-            ["column_id"], ["sl_columns.id"], name="sl_dataset_columns_ibfk_2"
+        sa.Column(
+            "dataset_id",
+            sa.INTEGER(),
+            sa.ForeignKey("sl_datasets.id"),
+            autoincrement=False,
+            nullable=False,
         ),
-        sa.ForeignKeyConstraint(
-            ["dataset_id"], ["sl_datasets.id"], name="sl_dataset_columns_ibfk_1"
+        sa.Column(
+            "column_id",
+            sa.INTEGER(),
+            sa.ForeignKey("sl_columns.id"),
+            autoincrement=False,
+            nullable=False,
         ),
-    )
-
-    op.create_table(
+    ),
+    (
         "sl_dataset_tables",
-        sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False),
-        sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False),
-        sa.ForeignKeyConstraint(
-            ["dataset_id"], ["sl_datasets.id"], name="sl_dataset_tables_ibfk_1"
+        sa.Column(
+            "dataset_id",
+            sa.INTEGER(),
+            sa.ForeignKey("sl_datasets.id"),
+            autoincrement=False,
+            nullable=False,
         ),
-        sa.ForeignKeyConstraint(
-            ["table_id"], ["sl_tables.id"], name="sl_dataset_tables_ibfk_2"
+        sa.Column(
+            "table_id",
+            sa.INTEGER(),
+            sa.ForeignKey("sl_tables.id"),
+            autoincrement=False,
+            nullable=False,
         ),
-    )
+    ),
+]
 
-    # migrate existing datasets to the new models
-    bind = op.get_bind()
-    session = db.Session(bind=bind)  # pylint: disable=no-member
-    datasets = session.query(SqlaTable).all()
-    for dataset in datasets:
-        dataset.fetch_columns_and_metrics(session)
-        after_insert(target=dataset)
+
+def upgrade():
+    # Create tables for the new models.
+    for (table_name, *cols) in new_tables:
+        op.create_table(table_name, *cols)
+
+    # copy existing datasets to the new models
+    session: Session = db.Session(bind=op.get_bind())
+    total = session.query(SqlaTable.id).count()
+    offset = 0
+    limit = 20  # smaller page size is actually faster
+
+    if total > 1000:
+        print(f"Hang tight, we need to migrate {total} datasets today.")
+
+    while offset < total:
+        for i, dataset in enumerate(
+            session.query(SqlaTable)
+            .order_by(desc(SqlaTable.sql.is_(None)))
+            .order_by(SqlaTable.id)
+            .offset(offset)
+            .limit(limit)
+        ):
+            table_name = dataset.table_name or ""
+            short_table_name = table_name[:50]
+            print(
+                f"Copying #{offset + i + 1} [ID: {dataset.id}] "
+                f"{short_table_name}{'...' if short_table_name != table_name else ''}"
+                + " " * 50,  # extra space to override previous lines
                end="\r",
+            )
+            with op.get_context().autocommit_block():
+                after_insert(target=dataset)
        offset += limit
 
 
 def downgrade():
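Note on the `lazy="selectin"` relationships added to the shadow `SqlaTable` model above: instead of joining every column and metric row into the parent query (what `lazy="joined"` would do, multiplying parent rows), selectin loading issues one extra `SELECT ... WHERE table_id IN (...)` scoped to the parents actually fetched, which is what keeps memory flat once the copy loop pages through datasets 20 at a time. A minimal, self-contained sketch of the pattern, using hypothetical `Parent`/`Child` models that are not part of this migration:

    from typing import List

    import sqlalchemy as sa
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship, Session

    Base = declarative_base()


    class Child(Base):
        __tablename__ = "children"
        id = sa.Column(sa.Integer, primary_key=True)
        parent_id = sa.Column(sa.Integer, sa.ForeignKey("parents.id"))


    class Parent(Base):
        __tablename__ = "parents"
        id = sa.Column(sa.Integer, primary_key=True)
        # "selectin" emits a second SELECT ... WHERE parent_id IN (...) limited to
        # the parents that were actually loaded, instead of a row-multiplying JOIN.
        children: List[Child] = relationship(Child, lazy="selectin")


    engine = sa.create_engine("sqlite://")  # illustrative in-memory engine only
    Base.metadata.create_all(engine)
    session = Session(bind=engine)

    # Two statements total: one for this page of parents, one IN-list for children.
    for parent in session.query(Parent).limit(20):
        print(parent.id, len(parent.children))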
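For virtual datasets, `after_insert` now turns the tables referenced in the dataset's SQL into a single OR-of-ANDs filter against `sl_tables`, scoped to the same `database_id`. A standalone sketch of that predicate shape, with a pared-down stand-in for the migration's `NewTable` model and made-up rows; names and data here are illustrative, not taken from the migration:

    from sqlalchemy import Column, Integer, Text, and_, create_engine, or_
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()


    class NewTable(Base):  # pared-down stand-in for sl_tables
        __tablename__ = "sl_tables"
        id = Column(Integer, primary_key=True)
        database_id = Column(Integer, nullable=False)
        schema = Column(Text, nullable=True)
        name = Column(Text, nullable=False)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(bind=engine)
    session.add_all(
        [
            NewTable(database_id=1, schema="public", name="orders"),
            NewTable(database_id=1, schema="main", name="customers"),
            NewTable(database_id=2, schema="public", name="orders"),  # other database
        ]
    )
    session.commit()

    # tables referenced by the virtual dataset's SQL, as (schema, name) pairs;
    # a missing schema falls back to the dataset's own schema, as in the migration
    referenced = [("public", "orders"), (None, "customers")]
    database_id, default_schema = 1, "main"

    predicate = or_(
        *[
            and_(
                NewTable.database_id == database_id,
                NewTable.schema == (schema or default_schema),
                NewTable.name == name,
            )
            for schema, name in referenced
        ]
    )
    matches = session.query(NewTable).filter(predicate).order_by(NewTable.id).all()
    print([(t.schema, t.name) for t in matches])  # [('public', 'orders'), ('main', 'customers')]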
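The paginated copy loop in the new `upgrade()` orders by `desc(SqlaTable.sql.is_(None))` before `id`. The apparent effect (an inference, not spelled out in the diff) is that physical datasets, whose `sql` is NULL, are copied before virtual ones, so the `sl_tables` rows they create already exist when a virtual dataset's predicate looks up its referenced tables. A small sketch of just the boolean-expression ordering, with a hypothetical `Dataset` stand-in rather than the real `SqlaTable` model:

    from sqlalchemy import Column, Integer, Text, create_engine, desc
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()


    class Dataset(Base):  # hypothetical stand-in for SqlaTable
        __tablename__ = "datasets"
        id = Column(Integer, primary_key=True)
        sql = Column(Text, nullable=True)  # NULL => physical, non-NULL => virtual


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(bind=engine)
    session.add_all(
        [
            Dataset(id=1, sql="SELECT 1"),  # virtual
            Dataset(id=2, sql=None),        # physical
            Dataset(id=3, sql=None),        # physical
        ]
    )
    session.commit()

    ordered = (
        session.query(Dataset)
        .order_by(desc(Dataset.sql.is_(None)))  # rows where sql IS NULL sort first
        .order_by(Dataset.id)
        .all()
    )
    print([d.id for d in ordered])  # [2, 3, 1]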