diff --git a/setup.cfg b/setup.cfg index eda501d7b2acc..d835c761c05e3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ order_by_type = false ignore_missing_imports = true no_implicit_optional = true -[mypy-superset.db_engine_specs.*] +[mypy-superset.charts.*,superset.db_engine_specs.*] check_untyped_defs = true disallow_untyped_calls = true disallow_untyped_defs = true diff --git a/superset/charts/api.py b/superset/charts/api.py index 882e6596ac0c1..feb6e2230da98 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from typing import Any from flask import g, request, Response from flask_appbuilder.api import expose, protect, rison, safe @@ -287,7 +288,9 @@ def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ @protect() @safe @rison(get_delete_ids_schema) - def bulk_delete(self, **kwargs) -> Response: # pylint: disable=arguments-differ + def bulk_delete( + self, **kwargs: Any + ) -> Response: # pylint: disable=arguments-differ """Delete bulk Charts --- delete: diff --git a/superset/charts/commands/bulk_delete.py b/superset/charts/commands/bulk_delete.py index f8ded64fd9fc9..5e715610b879e 100644 --- a/superset/charts/commands/bulk_delete.py +++ b/superset/charts/commands/bulk_delete.py @@ -40,7 +40,7 @@ def __init__(self, user: User, model_ids: List[int]): self._model_ids = model_ids self._models: Optional[List[Slice]] = None - def run(self): + def run(self) -> None: self.validate() try: ChartDAO.bulk_delete(self._models) diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py index b86fdcfab613c..6efcfa17b3379 100644 --- a/superset/charts/commands/create.py +++ b/superset/charts/commands/create.py @@ -17,6 +17,7 @@ import logging from typing import Dict, List, Optional +from flask_appbuilder.models.sqla import Model from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError @@ -39,7 +40,7 @@ def __init__(self, user: User, data: Dict): self._actor = user self._properties = data.copy() - def run(self): + def run(self) -> Model: self.validate() try: chart = ChartDAO.create(self._properties) diff --git a/superset/charts/commands/delete.py b/superset/charts/commands/delete.py index 51b4c5f65a083..bf85d9ea19dfb 100644 --- a/superset/charts/commands/delete.py +++ b/superset/charts/commands/delete.py @@ -17,6 +17,7 @@ import logging from typing import Optional +from flask_appbuilder.models.sqla import Model from flask_appbuilder.security.sqla.models import User from superset.charts.commands.exceptions import ( @@ -40,7 +41,7 @@ def __init__(self, user: User, model_id: int): self._model_id = model_id self._model: Optional[SqlaTable] = None - def run(self): + def run(self) -> Model: self.validate() try: chart = ChartDAO.delete(self._model) diff --git a/superset/charts/commands/exceptions.py b/superset/charts/commands/exceptions.py index 9e0f79df81dc9..2308d62a77216 100644 --- a/superset/charts/commands/exceptions.py +++ b/superset/charts/commands/exceptions.py @@ -32,7 +32,7 @@ class DatabaseNotFoundValidationError(ValidationError): Marshmallow validation error for database does not exist """ - def __init__(self): + def __init__(self) -> None: super().__init__(_("Database does not exist"), field_names=["database"]) @@ -41,7 +41,7 @@ class DashboardsNotFoundValidationError(ValidationError): Marshmallow validation error for dashboards don't exist """ - def __init__(self): + def __init__(self) -> None: super().__init__(_("Dashboards do not exist"), field_names=["dashboards"]) @@ -50,7 +50,7 @@ class DatasourceTypeUpdateRequiredValidationError(ValidationError): Marshmallow validation error for dashboards don't exist """ - def __init__(self): + def __init__(self) -> None: super().__init__( _("Datasource type is required when datasource_id is given"), field_names=["datasource_type"], diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index 1698c9f798261..5d9ee2a3e2723 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -17,6 +17,7 @@ import logging from typing import Dict, List, Optional +from flask_appbuilder.models.sqla import Model from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError @@ -47,7 +48,7 @@ def __init__(self, user: User, model_id: int, data: Dict): self._properties = data.copy() self._model: Optional[SqlaTable] = None - def run(self): + def run(self) -> Model: self.validate() try: chart = ChartDAO.update(self._model, self._properties) diff --git a/superset/charts/dao.py b/superset/charts/dao.py index ec56b5d0b3147..80ea3d6f47b6d 100644 --- a/superset/charts/dao.py +++ b/superset/charts/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List +from typing import List, Optional from sqlalchemy.exc import SQLAlchemyError @@ -32,13 +32,14 @@ class ChartDAO(BaseDAO): base_filter = ChartFilter @staticmethod - def bulk_delete(models: List[Slice], commit=True): - item_ids = [model.id for model in models] + def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None: + item_ids = [model.id for model in models] if models else [] # bulk delete, first delete related data - for model in models: - model.owners = [] - model.dashboards = [] - db.session.merge(model) + if models: + for model in models: + model.owners = [] + model.dashboards = [] + db.session.merge(model) # bulk delete itself try: db.session.query(Slice).filter(Slice.id.in_(item_ids)).delete( diff --git a/superset/charts/filters.py b/superset/charts/filters.py index 77c0020d672cf..a35ba2912b073 100644 --- a/superset/charts/filters.py +++ b/superset/charts/filters.py @@ -14,14 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any + from sqlalchemy import or_ +from sqlalchemy.orm.query import Query from superset import security_manager from superset.views.base import BaseFilter class ChartFilter(BaseFilter): # pylint: disable=too-few-public-methods - def apply(self, query, value): + def apply(self, query: Query, value: Any) -> Query: if security_manager.all_datasource_access(): return query perms = security_manager.user_view_menu_names("datasource_access") diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 9a965a43e5986..96dc3d36066a8 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Union from marshmallow import fields, Schema, ValidationError from marshmallow.validate import Length @@ -24,7 +25,7 @@ get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} -def validate_json(value): +def validate_json(value: Union[bytes, bytearray, str]) -> None: try: utils.validate_json(value) except SupersetException: diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index 64b9f62414960..03eef1f9d16c8 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List +from typing import Any, Dict, List from flask_babel import lazy_gettext as _ from marshmallow import ValidationError @@ -36,8 +36,8 @@ class CommandInvalidError(CommandException): status = 422 - def __init__(self, message=""): - self._invalid_exceptions = list() + def __init__(self, message="") -> None: + self._invalid_exceptions: List[ValidationError] = [] super().__init__(self.message) def add(self, exception: ValidationError): @@ -46,8 +46,8 @@ def add(self, exception: ValidationError): def add_list(self, exceptions: List[ValidationError]): self._invalid_exceptions.extend(exceptions) - def normalized_messages(self): - errors = {} + def normalized_messages(self) -> Dict[Any, Any]: + errors: Dict[Any, Any] = {} for exception in self._invalid_exceptions: errors.update(exception.normalized_messages()) return errors diff --git a/superset/dao/base.py b/superset/dao/base.py index f0133515cb210..8d6152a9d0783 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -75,7 +75,7 @@ def find_by_ids(cls, model_ids: List[int]) -> List[Model]: return query.all() @classmethod - def create(cls, properties: Dict, commit=True) -> Optional[Model]: + def create(cls, properties: Dict, commit: bool = True) -> Model: """ Generic for creating models :raises: DAOCreateFailedError @@ -95,7 +95,7 @@ def create(cls, properties: Dict, commit=True) -> Optional[Model]: return model @classmethod - def update(cls, model: Model, properties: Dict, commit=True) -> Optional[Model]: + def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model: """ Generic update a model :raises: DAOCreateFailedError diff --git a/superset/utils/core.py b/superset/utils/core.py index 84fbe46d9226a..999c4d4784c69 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -547,7 +547,7 @@ def get_datasource_full_name(database_name, datasource_name, schema=None): return "[{}].[{}].[{}]".format(database_name, schema, datasource_name) -def validate_json(obj): +def validate_json(obj: Union[bytes, bytearray, str]) -> None: if obj: try: json.loads(obj) diff --git a/superset/views/base.py b/superset/views/base.py index 75d6d54d3d130..595cdbde64170 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -18,7 +18,7 @@ import logging import traceback from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import simplejson as json import yaml @@ -27,6 +27,7 @@ from flask_appbuilder.actions import action from flask_appbuilder.forms import DynamicForm from flask_appbuilder.models.sqla.filters import BaseFilter +from flask_appbuilder.security.sqla.models import User from flask_appbuilder.widgets import ListWidget from flask_babel import get_locale, gettext as __, lazy_gettext as _ from flask_wtf.form import FlaskForm @@ -365,7 +366,7 @@ class CsvResponse(Response): # pylint: disable=too-many-ancestors charset = conf["CSV_EXPORT"].get("encoding", "utf-8") -def check_ownership(obj, raise_if_false=True): +def check_ownership(obj: Any, raise_if_false: bool = True) -> bool: """Meant to be used in `pre_update` hooks on models to enforce ownership Admin have all access, and other users need to be referenced on either @@ -392,7 +393,7 @@ def check_ownership(obj, raise_if_false=True): orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first() # Making a list of owners that works across ORM models - owners = [] + owners: List[User] = [] if hasattr(orig_obj, "owners"): owners += orig_obj.owners if hasattr(orig_obj, "owner"):