From 5e55e09e3ea41569ecb24cd22e352ab1c274aab9 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Sat, 4 Apr 2020 12:45:14 -0700 Subject: [PATCH] [mypy] Enforcing typing for some modules (#9416) Co-authored-by: John Bodley --- setup.cfg | 2 +- superset/commands/base.py | 5 ++++- superset/commands/exceptions.py | 16 ++++++++-------- superset/common/query_context.py | 4 ++-- superset/common/query_object.py | 2 +- superset/common/tags.py | 9 +++++---- superset/dao/base.py | 2 +- superset/db_engines/hive.py | 11 +++++++++-- superset/stats_logger.py | 12 ++++++------ superset/utils/core.py | 3 ++- 10 files changed, 39 insertions(+), 27 deletions(-) diff --git a/setup.cfg b/setup.cfg index d835c761c05e3..9fd8ef228f3c8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ order_by_type = false ignore_missing_imports = true no_implicit_optional = true -[mypy-superset.charts.*,superset.db_engine_specs.*] +[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*] check_untyped_defs = true disallow_untyped_calls = true disallow_untyped_defs = true diff --git a/superset/commands/base.py b/superset/commands/base.py index 44f46eb742215..9c6de0c3bad0a 100644 --- a/superset/commands/base.py +++ b/superset/commands/base.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. from abc import ABC, abstractmethod +from typing import Optional + +from flask_appbuilder.models.sqla import Model class BaseCommand(ABC): @@ -23,7 +26,7 @@ class BaseCommand(ABC): """ @abstractmethod - def run(self): + def run(self) -> Optional[Model]: """ Run executes the command. Can raise command exceptions :raises: CommandException diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index 03eef1f9d16c8..cf67ea9227159 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -25,10 +25,10 @@ class CommandException(SupersetException): """ Common base class for Command exceptions. """ - def __repr__(self): + def __repr__(self) -> str: if self._exception: - return self._exception - return self + return repr(self._exception) + return repr(self) class CommandInvalidError(CommandException): @@ -36,14 +36,14 @@ class CommandInvalidError(CommandException): status = 422 - def __init__(self, message="") -> None: + def __init__(self, message: str = "") -> None: self._invalid_exceptions: List[ValidationError] = [] super().__init__(self.message) - def add(self, exception: ValidationError): + def add(self, exception: ValidationError) -> None: self._invalid_exceptions.append(exception) - def add_list(self, exceptions: List[ValidationError]): + def add_list(self, exceptions: List[ValidationError]) -> None: self._invalid_exceptions.extend(exceptions) def normalized_messages(self) -> Dict[Any, Any]: @@ -76,12 +76,12 @@ class ForbiddenError(CommandException): class OwnersNotFoundValidationError(ValidationError): status = 422 - def __init__(self): + def __init__(self) -> None: super().__init__(_("Owners are invalid"), field_names=["owners"]) class DatasourceNotFoundValidationError(ValidationError): status = 404 - def __init__(self): + def __init__(self) -> None: super().__init__(_("Datasource does not exist"), field_names=["datasource_id"]) diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 7e1254068e2bd..525d8292ff0df 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -157,7 +157,7 @@ def cache_timeout(self) -> int: return self.datasource.database.cache_timeout return config["CACHE_DEFAULT_TIMEOUT"] - def cache_key(self, query_obj: QueryObject, **kwargs) -> Optional[str]: + def cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]: extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict()) cache_key = ( query_obj.cache_key( @@ -173,7 +173,7 @@ def cache_key(self, query_obj: QueryObject, **kwargs) -> Optional[str]: return cache_key def get_df_payload( # pylint: disable=too-many-locals,too-many-statements - self, query_obj: QueryObject, **kwargs + self, query_obj: QueryObject, **kwargs: Any ) -> Dict[str, Any]: """Handles caching around the df payload retrieval""" cache_key = self.cache_key(query_obj, **kwargs) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index e0681be7971d4..f5133857583f6 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -122,7 +122,7 @@ def to_dict(self) -> Dict[str, Any]: } return query_object_dict - def cache_key(self, **extra) -> str: + def cache_key(self, **extra: Any) -> str: """ The cache key is made out of the key/values from to_dict(), plus any other key/values in `extra` diff --git a/superset/common/tags.py b/superset/common/tags.py index 657611c602b8b..74c882cf92f69 100644 --- a/superset/common/tags.py +++ b/superset/common/tags.py @@ -14,14 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +from sqlalchemy import Metadata +from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import and_, func, functions, join, literal, select from superset.models.tags import ObjectTypes, TagTypes -def add_types(engine, metadata): +def add_types(engine: Engine, metadata: Metadata) -> None: """ Tag every object according to its type: @@ -163,7 +164,7 @@ def add_types(engine, metadata): engine.execute(query) -def add_owners(engine, metadata): +def add_owners(engine: Engine, metadata: Metadata) -> None: """ Tag every object according to its owner: @@ -319,7 +320,7 @@ def add_owners(engine, metadata): engine.execute(query) -def add_favorites(engine, metadata): +def add_favorites(engine: Engine, metadata: Metadata) -> None: """ Tag every object that was favorited: diff --git a/superset/dao/base.py b/superset/dao/base.py index 8d6152a9d0783..7158643e219e9 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -112,7 +112,7 @@ def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model: return model @classmethod - def delete(cls, model: Model, commit=True): + def delete(cls, model: Model, commit: bool = True) -> Model: """ Generic delete a model :raises: DAOCreateFailedError diff --git a/superset/db_engines/hive.py b/superset/db_engines/hive.py index 093b5ebb05bb3..25f71b4cd28c1 100644 --- a/superset/db_engines/hive.py +++ b/superset/db_engines/hive.py @@ -14,12 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from pyhive.hive import Cursor # pylint: disable=unused-import + from TCLIService.ttypes import TFetchOrientation # pylint: disable=unused-import # pylint: disable=protected-access # TODO: contribute back to pyhive. def fetch_logs( - self, max_rows=1024, orientation=None -): # pylint: disable=unused-argument + self: "Cursor", + max_rows: int = 1024, # pylint: disable=unused-argument + orientation: Optional["TFetchOrientation"] = None, +) -> str: # pylint: disable=unused-argument """Mocked. Retrieve the logs produced by the execution of the query. Can be called multiple times to fetch the logs produced after the previous call. diff --git a/superset/stats_logger.py b/superset/stats_logger.py index 758208a6040f2..37fe3d39d6f05 100644 --- a/superset/stats_logger.py +++ b/superset/stats_logger.py @@ -24,26 +24,26 @@ class BaseStatsLogger: """Base class for logging realtime events""" - def __init__(self, prefix="superset"): + def __init__(self, prefix: str = "superset") -> None: self.prefix = prefix - def key(self, key): + def key(self, key: str) -> str: if self.prefix: return self.prefix + key return key - def incr(self, key): + def incr(self, key: str) -> None: """Increment a counter""" raise NotImplementedError() - def decr(self, key): + def decr(self, key: str) -> None: """Decrement a counter""" raise NotImplementedError() - def timing(self, key, value): + def timing(self, key, value: float) -> None: raise NotImplementedError() - def gauge(self, key): + def gauge(self, key: str) -> None: """Setup a gauge""" raise NotImplementedError() diff --git a/superset/utils/core.py b/superset/utils/core.py index 999c4d4784c69..95f103298c5c1 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1224,9 +1224,10 @@ class DatasourceName(NamedTuple): schema: str -def get_stacktrace(): +def get_stacktrace() -> Optional[str]: if current_app.config["SHOW_STACKTRACE"]: return traceback.format_exc() + return None def split(