Skip to content

Commit

Permalink
chore(tags): Refactor logic to leverage Flask-SQLAlchemy extension (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored Oct 3, 2022
1 parent 8d1b7ec commit 31895f4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 38 deletions.
7 changes: 3 additions & 4 deletions superset/cli/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from flask_appbuilder.api.manager import resolver

import superset.utils.database as database_utils
from superset.extensions import db
from superset.utils.encrypt import SecretsMigrator

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -62,9 +61,9 @@ def sync_tags() -> None:
# pylint: disable=import-outside-toplevel
from superset.common.tags import add_favorites, add_owners, add_types

add_types(db.engine, metadata)
add_owners(db.engine, metadata)
add_favorites(db.engine, metadata)
add_types(metadata)
add_owners(metadata)
add_favorites(metadata)


@click.command()
Expand Down
76 changes: 42 additions & 34 deletions superset/common/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
from typing import Any, List

from sqlalchemy import MetaData
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import and_, func, join, literal, select

from superset.extensions import db
from superset.tags.models import ObjectTypes, TagTypes


def add_types_to_charts(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
slices = metadata.tables["slices"]

Expand Down Expand Up @@ -53,11 +53,11 @@ def add_types_to_charts(
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, charts)
engine.execute(query)
db.session.execute(query)


def add_types_to_dashboards(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]

Expand Down Expand Up @@ -85,11 +85,11 @@ def add_types_to_dashboards(
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, dashboards)
engine.execute(query)
db.session.execute(query)


def add_types_to_saved_queries(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
saved_query = metadata.tables["saved_query"]

Expand Down Expand Up @@ -117,11 +117,11 @@ def add_types_to_saved_queries(
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, saved_queries)
engine.execute(query)
db.session.execute(query)


def add_types_to_datasets(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
tables = metadata.tables["tables"]

Expand Down Expand Up @@ -149,10 +149,10 @@ def add_types_to_datasets(
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, datasets)
engine.execute(query)
db.session.execute(query)


def add_types(engine: Engine, metadata: MetaData) -> None:
def add_types(metadata: MetaData) -> None:
"""
Tag every object according to its type:
Expand Down Expand Up @@ -222,18 +222,22 @@ def add_types(engine: Engine, metadata: MetaData) -> None:
insert = tag.insert()
for type_ in ObjectTypes.__members__:
try:
engine.execute(insert, name=f"type:{type_}", type=TagTypes.type)
db.session.execute(
insert,
name=f"type:{type_}",
type=TagTypes.type,
)
except IntegrityError:
pass # already exists

add_types_to_charts(engine, metadata, tag, tagged_object, columns)
add_types_to_dashboards(engine, metadata, tag, tagged_object, columns)
add_types_to_saved_queries(engine, metadata, tag, tagged_object, columns)
add_types_to_datasets(engine, metadata, tag, tagged_object, columns)
add_types_to_charts(metadata, tag, tagged_object, columns)
add_types_to_dashboards(metadata, tag, tagged_object, columns)
add_types_to_saved_queries(metadata, tag, tagged_object, columns)
add_types_to_datasets(metadata, tag, tagged_object, columns)


def add_owners_to_charts(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
slices = metadata.tables["slices"]

Expand Down Expand Up @@ -265,11 +269,11 @@ def add_owners_to_charts(
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, charts)
engine.execute(query)
db.session.execute(query)


def add_owners_to_dashboards(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]

Expand Down Expand Up @@ -301,11 +305,11 @@ def add_owners_to_dashboards(
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, dashboards)
engine.execute(query)
db.session.execute(query)


def add_owners_to_saved_queries(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
saved_query = metadata.tables["saved_query"]

Expand Down Expand Up @@ -337,11 +341,11 @@ def add_owners_to_saved_queries(
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, saved_queries)
engine.execute(query)
db.session.execute(query)


def add_owners_to_datasets(
engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
) -> None:
tables = metadata.tables["tables"]

Expand Down Expand Up @@ -373,10 +377,10 @@ def add_owners_to_datasets(
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, datasets)
engine.execute(query)
db.session.execute(query)


def add_owners(engine: Engine, metadata: MetaData) -> None:
def add_owners(metadata: MetaData) -> None:
"""
Tag every object according to its owner:
Expand Down Expand Up @@ -443,19 +447,19 @@ def add_owners(engine: Engine, metadata: MetaData) -> None:
# create a custom tag for each user
ids = select([users.c.id])
insert = tag.insert()
for (id_,) in engine.execute(ids):
for (id_,) in db.session.execute(ids):
try:
engine.execute(insert, name=f"owner:{id_}", type=TagTypes.owner)
db.session.execute(insert, name=f"owner:{id_}", type=TagTypes.owner)
except IntegrityError:
pass # already exists

add_owners_to_charts(engine, metadata, tag, tagged_object, columns)
add_owners_to_dashboards(engine, metadata, tag, tagged_object, columns)
add_owners_to_saved_queries(engine, metadata, tag, tagged_object, columns)
add_owners_to_datasets(engine, metadata, tag, tagged_object, columns)
add_owners_to_charts(metadata, tag, tagged_object, columns)
add_owners_to_dashboards(metadata, tag, tagged_object, columns)
add_owners_to_saved_queries(metadata, tag, tagged_object, columns)
add_owners_to_datasets(metadata, tag, tagged_object, columns)


def add_favorites(engine: Engine, metadata: MetaData) -> None:
def add_favorites(metadata: MetaData) -> None:
"""
Tag every object that was favorited:
Expand Down Expand Up @@ -484,9 +488,13 @@ def add_favorites(engine: Engine, metadata: MetaData) -> None:
# create a custom tag for each user
ids = select([users.c.id])
insert = tag.insert()
for (id_,) in engine.execute(ids):
for (id_,) in db.session.execute(ids):
try:
engine.execute(insert, name=f"favorited_by:{id_}", type=TagTypes.type)
db.session.execute(
insert,
name=f"favorited_by:{id_}",
type=TagTypes.type,
)
except IntegrityError:
pass # already exists

Expand Down Expand Up @@ -518,4 +526,4 @@ def add_favorites(engine: Engine, metadata: MetaData) -> None:
.where(tagged_object.c.tag_id.is_(None))
)
query = tagged_object.insert().from_select(columns, favstars)
engine.execute(query)
db.session.execute(query)

0 comments on commit 31895f4

Please sign in to comment.