Skip to content

Commit

Permalink
feat: Adds CLI commands to execute viz migrations (apache#25304)
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-s-molina authored Sep 19, 2023
1 parent 308743e commit 1928b12
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 20 deletions.
91 changes: 91 additions & 0 deletions superset/cli/viz_migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from enum import Enum

import click
from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup
from flask.cli import with_appcontext

from superset import db


class VizType(str, Enum):
TREEMAP = "treemap"
DUAL_LINE = "dual_line"
AREA = "area"
PIVOT_TABLE = "pivot_table"


@click.group()
def migrate_viz() -> None:
"""
Migrate a viz from one type to another.
"""


@migrate_viz.command()
@with_appcontext
@optgroup.group(
"Grouped options",
cls=RequiredMutuallyExclusiveOptionGroup,
)
@optgroup.option(
"--viz_type",
"-t",
help=f"The viz type to migrate: {', '.join(list(VizType))}",
)
def upgrade(viz_type: str) -> None:
"""Upgrade a viz to the latest version."""
migrate(VizType(viz_type))


@migrate_viz.command()
@with_appcontext
@optgroup.group(
"Grouped options",
cls=RequiredMutuallyExclusiveOptionGroup,
)
@optgroup.option(
"--viz_type",
"-t",
help=f"The viz type to migrate: {', '.join(list(VizType))}",
)
def downgrade(viz_type: str) -> None:
"""Downgrade a viz to the previous version."""
migrate(VizType(viz_type), is_downgrade=True)


def migrate(viz_type: VizType, is_downgrade: bool = False) -> None:
"""Migrate a viz from one type to another."""
# pylint: disable=import-outside-toplevel
from superset.migrations.shared.migrate_viz.processors import (
MigrateAreaChart,
MigrateDualLine,
MigratePivotTable,
MigrateTreeMap,
)

migrations = {
VizType.TREEMAP: MigrateTreeMap,
VizType.DUAL_LINE: MigrateDualLine,
VizType.AREA: MigrateAreaChart,
VizType.PIVOT_TABLE: MigratePivotTable,
}
if is_downgrade:
migrations[viz_type].downgrade(db.session)
else:
migrations[viz_type].upgrade(db.session)
12 changes: 4 additions & 8 deletions superset/migrations/shared/migrate_viz/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import json
from typing import Any

from alembic import op
from sqlalchemy import and_, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

from superset import conf, db, is_feature_enabled
from superset import conf, is_feature_enabled
from superset.constants import TimeGrain
from superset.migrations.shared.utils import paginated_update, try_load_json

Expand Down Expand Up @@ -156,9 +156,7 @@ def downgrade_slice(cls, slc: Slice) -> Slice:
return slc

@classmethod
def upgrade(cls) -> None:
bind = op.get_bind()
session = db.Session(bind=bind)
def upgrade(cls, session: Session) -> None:
slices = session.query(Slice).filter(Slice.viz_type == cls.source_viz_type)
for slc in paginated_update(
slices,
Expand All @@ -170,9 +168,7 @@ def upgrade(cls) -> None:
session.merge(new_viz)

@classmethod
def downgrade(cls) -> None:
bind = op.get_bind()
session = db.Session(bind=bind)
def downgrade(cls, session: Session) -> None:
slices = session.query(Slice).filter(
and_(
Slice.viz_type == cls.target_viz_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from alembic import op
from sqlalchemy.dialects.mysql.base import MySQLDialect

from superset import db
from superset.migrations.shared.migrate_viz import MigrateTreeMap

# revision identifiers, used by Alembic.
Expand All @@ -32,16 +33,21 @@


def upgrade():
bind = op.get_bind()

# Ensure `slice.params` and `slice.query_context`` in MySQL is MEDIUMTEXT
# before migration, as the migration will save a duplicate form_data backup
# which may significantly increase the size of these fields.
if isinstance(op.get_bind().dialect, MySQLDialect):
if isinstance(bind.dialect, MySQLDialect):
# If the columns are already MEDIUMTEXT, this is a no-op
op.execute("ALTER TABLE slices MODIFY params MEDIUMTEXT")
op.execute("ALTER TABLE slices MODIFY query_context MEDIUMTEXT")

MigrateTreeMap.upgrade()
session = db.Session(bind=bind)
MigrateTreeMap.upgrade(session)


def downgrade():
MigrateTreeMap.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateTreeMap.downgrade(session)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
Create Date: 2022-06-13 14:17:51.872706
"""
from alembic import op

from superset import db
from superset.migrations.shared.migrate_viz import MigrateAreaChart

# revision identifiers, used by Alembic.
Expand All @@ -29,8 +32,12 @@


def upgrade():
MigrateAreaChart.upgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateAreaChart.upgrade(session)


def downgrade():
MigrateAreaChart.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateAreaChart.downgrade(session)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
Create Date: 2023-08-06 09:02:10.148992
"""
from alembic import op

from superset import db
from superset.migrations.shared.migrate_viz import MigratePivotTable

# revision identifiers, used by Alembic.
Expand All @@ -29,8 +32,12 @@


def upgrade():
MigratePivotTable.upgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigratePivotTable.upgrade(session)


def downgrade():
MigratePivotTable.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigratePivotTable.downgrade(session)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from alembic import op
from sqlalchemy.dialects.mysql.base import MySQLDialect

from superset import db
from superset.migrations.shared.migrate_viz import MigrateTreeMap

# revision identifiers, used by Alembic.
Expand All @@ -32,16 +33,21 @@


def upgrade():
bind = op.get_bind()

# Ensure `slice.params` and `slice.query_context`` in MySQL is MEDIUMTEXT
# before migration, as the migration will save a duplicate form_data backup
# which may significantly increase the size of these fields.
if isinstance(op.get_bind().dialect, MySQLDialect):
if isinstance(bind.dialect, MySQLDialect):
# If the columns are already MEDIUMTEXT, this is a no-op
op.execute("ALTER TABLE slices MODIFY params MEDIUMTEXT")
op.execute("ALTER TABLE slices MODIFY query_context MEDIUMTEXT")

MigrateTreeMap.upgrade()
session = db.Session(bind=bind)
MigrateTreeMap.upgrade(session)


def downgrade():
MigrateTreeMap.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateTreeMap.downgrade(session)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
Create Date: 2023-06-08 11:34:36.241939
"""
from alembic import op

from superset import db

# revision identifiers, used by Alembic.
revision = "ae58e1e58e5c"
Expand All @@ -30,8 +33,12 @@


def upgrade():
MigrateDualLine.upgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateDualLine.upgrade(session)


def downgrade():
MigrateDualLine.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateDualLine.downgrade(session)

0 comments on commit 1928b12

Please sign in to comment.