Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: use contextlib.surpress instead of passing on error #24896

Merged
merged 8 commits into from
Aug 29, 2023
12 changes: 3 additions & 9 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import contextlib
import json
import logging
from typing import Any, TYPE_CHECKING
Expand Down Expand Up @@ -223,11 +224,8 @@ def data(self) -> Response:
json_body = request.json
elif request.form.get("form_data"):
# CSV export submits regular form data
try:
with contextlib.suppress(TypeError, json.JSONDecodeError):
json_body = json.loads(request.form["form_data"])
except (TypeError, json.JSONDecodeError):
pass

if json_body is None:
return self.response_400(message=_("Request is not JSON"))

Expand Down Expand Up @@ -324,14 +322,10 @@ def _run_async(
Execute command as an async query.
"""
# First, look for the chart query results in the cache.
result = None
try:
with contextlib.suppress(ChartDataCacheLoadError):
result = command.run(force_cached=True)
if result is not None:
return self._send_chart_response(result)
except ChartDataCacheLoadError:
pass

# Otherwise, kick off a background job to run the chart query.
# Clients will either poll or be notified of query completion,
# at which point they will call the /data/<cache_key> endpoint
Expand Down
27 changes: 6 additions & 21 deletions superset/common/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
from typing import Any

from sqlalchemy import MetaData
Expand Down Expand Up @@ -221,14 +222,8 @@ def add_types(metadata: MetaData) -> None:
# add a tag for each object type
insert = tag.insert()
for type_ in ObjectTypes.__members__:
try:
db.session.execute(
insert,
name=f"type:{type_}",
type=TagTypes.type,
)
except IntegrityError:
pass # already exists
with contextlib.suppress(IntegrityError): # already exists
db.session.execute(insert, name=f"type:{type_}", type=TagTypes.type)

add_types_to_charts(metadata, tag, tagged_object, columns)
add_types_to_dashboards(metadata, tag, tagged_object, columns)
Expand Down Expand Up @@ -448,11 +443,8 @@ def add_owners(metadata: MetaData) -> None:
ids = select([users.c.id])
insert = tag.insert()
for (id_,) in db.session.execute(ids):
try:
with contextlib.suppress(IntegrityError): # already exists
db.session.execute(insert, name=f"owner:{id_}", type=TagTypes.owner)
except IntegrityError:
pass # already exists

add_owners_to_charts(metadata, tag, tagged_object, columns)
add_owners_to_dashboards(metadata, tag, tagged_object, columns)
add_owners_to_saved_queries(metadata, tag, tagged_object, columns)
Expand Down Expand Up @@ -489,15 +481,8 @@ def add_favorites(metadata: MetaData) -> None:
ids = select([users.c.id])
insert = tag.insert()
for (id_,) in db.session.execute(ids):
try:
db.session.execute(
insert,
name=f"favorited_by:{id_}",
type=TagTypes.type,
)
except IntegrityError:
pass # already exists

with contextlib.suppress(IntegrityError): # already exists
db.session.execute(insert, name=f"favorited_by:{id_}", type=TagTypes.type)
favstars = (
select(
[
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import json
import re
import urllib
Expand Down Expand Up @@ -557,11 +558,8 @@ def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]:
except (json.JSONDecodeError, TypeError):
return encrypted_extra

try:
with contextlib.suppress(KeyError):
config["credentials_info"]["private_key"] = PASSWORD_MASK
except KeyError:
pass

return json.dumps(config)

@classmethod
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import contextlib
import json
import logging
import re
Expand Down Expand Up @@ -167,11 +168,8 @@ def mask_encrypted_extra(cls, encrypted_extra: str | None) -> str | None:
except (TypeError, json.JSONDecodeError):
return encrypted_extra

try:
with contextlib.suppress(KeyError):
config["service_account_info"]["private_key"] = PASSWORD_MASK
except KeyError:
pass

return json.dumps(config)

@classmethod
Expand Down
5 changes: 2 additions & 3 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import re
from datetime import datetime
from re import Pattern
Expand Down Expand Up @@ -258,11 +259,9 @@ def epoch_to_dttm(cls) -> str:
def _extract_error_message(cls, ex: Exception) -> str:
"""Extract error message for queries"""
message = str(ex)
try:
with contextlib.suppress(AttributeError, KeyError):
if isinstance(ex.args, tuple) and len(ex.args) > 1:
message = ex.args[1]
except (AttributeError, KeyError):
pass
return message

@classmethod
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/ocient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.

import contextlib
import re
import threading
from re import Pattern
Expand All @@ -24,8 +25,7 @@
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import Session

# Need to try-catch here because pyocient may not be installed
try:
with contextlib.suppress(ImportError, RuntimeError): # pyocient may not be installed
# Ensure pyocient inherits Superset's logging level
import geojson
import pyocient
Expand All @@ -35,8 +35,6 @@

superset_log_level = app.config["LOG_LEVEL"]
pyocient.logger.setLevel(superset_log_level)
except (ImportError, RuntimeError):
pass

from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec
Expand Down
10 changes: 3 additions & 7 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=too-many-lines
from __future__ import annotations

import contextlib
import logging
import re
import time
Expand Down Expand Up @@ -67,11 +68,8 @@
# prevent circular imports
from superset.models.core import Database

# need try/catch because pyhive may not be installed
try:
with contextlib.suppress(ImportError): # pyhive may not be installed
from pyhive.presto import Cursor
except ImportError:
pass

COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
"line (?P<location>.+?): .*Column '(?P<column_name>.+?)' cannot be resolved"
Expand Down Expand Up @@ -1274,12 +1272,10 @@ def get_create_view(

@classmethod
def get_tracking_url(cls, cursor: Cursor) -> str | None:
try:
with contextlib.suppress(AttributeError):
if cursor.last_query_id:
# pylint: disable=protected-access, line-too-long
return f"{cursor._protocol}://{cursor._host}:{cursor._port}/ui/query.html?{cursor.last_query_id}"
except AttributeError:
pass
return None

@classmethod
Expand Down
9 changes: 3 additions & 6 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import contextlib
import logging
from typing import Any, TYPE_CHECKING

Expand All @@ -35,10 +36,8 @@
if TYPE_CHECKING:
from superset.models.core import Database

try:
with contextlib.suppress(ImportError): # trino may not be installed
from trino.dbapi import Cursor
except ImportError:
pass

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -140,12 +139,10 @@ def get_tracking_url(cls, cursor: Cursor) -> str | None:
try:
return cursor.info_uri
except AttributeError:
try:
with contextlib.suppress(AttributeError):
conn = cursor.connection
# pylint: disable=protected-access, line-too-long
return f"{conn.http_scheme}://{conn.host}:{conn.port}/ui/query.html?{cursor._query.query_id}"
except AttributeError:
pass
return None

@classmethod
Expand Down
7 changes: 3 additions & 4 deletions superset/explore/commands/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import logging
from abc import ABC
from typing import Any, cast, Optional
Expand Down Expand Up @@ -107,17 +108,15 @@ def run(self) -> Optional[dict[str, Any]]:
)
except SupersetException:
self._datasource_id = None
# fallback unkonw datasource to table type
# fallback unknown datasource to table type
self._datasource_type = SqlaTable.type

datasource: Optional[BaseDatasource] = None
if self._datasource_id is not None:
try:
with contextlib.suppress(DatasourceNotFound):
datasource = DatasourceDAO.get_datasource(
db.session, cast(str, self._datasource_type), self._datasource_id
)
except DatasourceNotFound:
pass
datasource_name = datasource.name if datasource else _("[Missing Dataset]")
viz_type = form_data.get("viz_type")
if not viz_type and datasource and datasource.default_endpoint:
Expand Down
8 changes: 3 additions & 5 deletions superset/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import contextlib
import logging
import os
import sys
Expand All @@ -25,7 +26,7 @@
from deprecation import deprecated
from flask import Flask, redirect
from flask_appbuilder import expose, IndexView
from flask_babel import gettext as __, lazy_gettext as _
from flask_babel import gettext as __
from flask_compress import Compress
from werkzeug.middleware.proxy_fix import ProxyFix

Expand Down Expand Up @@ -594,11 +595,8 @@ def __call__(
self.superset_app.wsgi_app = ChunkedEncodingFix(self.superset_app.wsgi_app)

if self.config["UPLOAD_FOLDER"]:
try:
with contextlib.suppress(OSError):
os.makedirs(self.config["UPLOAD_FOLDER"])
except OSError:
pass

for middleware in self.config["ADDITIONAL_MIDDLEWARE"]:
self.superset_app.wsgi_app = middleware(self.superset_app.wsgi_app)

Expand Down
18 changes: 5 additions & 13 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import logging
import textwrap
from ast import literal_eval
from contextlib import closing, contextmanager, nullcontext
from contextlib import closing, contextmanager, nullcontext, suppress
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
Expand Down Expand Up @@ -225,7 +225,6 @@ def allows_cost_estimate(self) -> bool:
@property
def allows_virtual_table_explore(self) -> bool:
extra = self.get_extra()

return bool(extra.get("allows_virtual_table_explore", True))

@property
Expand All @@ -235,9 +234,7 @@ def explore_database_id(self) -> int:
@property
def disable_data_preview(self) -> bool:
# this will prevent any 'trash value' strings from going through
if self.get_extra().get("disable_data_preview", False) is not True:
return False
return True
return self.get_extra().get("disable_data_preview", False) is True

@property
def data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -285,11 +282,8 @@ def parameters(self) -> dict[str, Any]:
masked_uri = make_url_safe(self.sqlalchemy_uri)
encrypted_config = {}
if (masked_encrypted_extra := self.masked_encrypted_extra) is not None:
try:
with suppress(TypeError, json.JSONDecodeError):
encrypted_config = json.loads(masked_encrypted_extra)
except (TypeError, json.JSONDecodeError):
pass

try:
# pylint: disable=useless-suppression
parameters = self.db_engine_spec.get_parameters_from_uri( # type: ignore
Expand Down Expand Up @@ -550,7 +544,7 @@ def get_default_schema_for_query(self, query: Query) -> str | None:

@property
def quote_identifier(self) -> Callable[[str], str]:
"""Add quotes to potential identifiter expressions if needed"""
"""Add quotes to potential identifier expressions if needed"""
return self.get_dialect().identifier_preparer.quote

def get_reserved_words(self) -> set[str]:
Expand Down Expand Up @@ -692,15 +686,14 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
"""
try:
with self.get_inspector_with_context() as inspector:
tables = {
return {
(table, schema)
for table in self.db_engine_spec.get_table_names(
database=self,
inspector=inspector,
schema=schema,
)
}
return tables
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

Expand Down Expand Up @@ -985,7 +978,6 @@ def make_sqla_column_compatible(


class Log(Model): # pylint: disable=too-few-public-methods

"""ORM object used to log Superset actions to the database"""

__tablename__ = "logs"
Expand Down
Loading