chore: bump black to 19.10b0 and mypy to 0.770 (#9378)
* Bump black to 19.10b0

* Upgrade mypy to 0.770

* Update how inline type is defined
ktmud authored Apr 4, 2020
1 parent 5e55e09 commit 801e2f1
Showing 13 changed files with 36 additions and 37 deletions.
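
The bulk of the diff below follows one pattern: `# type:` comments are replaced with PEP 526 inline annotations, which black 19.10b0 formats and mypy 0.770 checks natively. A minimal sketch of the two styles (the variable names here are illustrative, not taken from Superset):

from typing import Dict, List, Optional

# Old style: the type lives in a trailing comment and can drift out of sync.
legacy_labels = []  # type: List[str]

# New style, used throughout this commit: PEP 526 inline annotations.
labels: List[str] = []
timeouts: Dict[str, float] = {}
cache_timeout: Optional[int] = None  # a None default must be declared Optional
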
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -16,7 +16,7 @@
 #
 repos:
   - repo: https://github.com/ambv/black
-    rev: 19.3b0
+    rev: 19.10b0
     hooks:
       - id: black
         language_version: python3
4 changes: 2 additions & 2 deletions requirements-dev.txt
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-black==19.3b0
+black==19.10b0
 coverage==4.5.3
 flask-cors==3.0.7
 flask-testing==0.7.1
 ipdb==0.12
 isort==4.3.21
-mypy==0.670
+mypy==0.770
 nose==1.3.7
 pip-tools==4.5.1
 pre-commit==1.17.0
19 changes: 10 additions & 9 deletions superset/connectors/druid/models.py
@@ -810,7 +810,7 @@ def granularity(
             "year": "P1Y",
         }

-        granularity = {"type": "period"}
+        granularity: Dict[str, Union[str, float]] = {"type": "period"}
         if timezone:
             granularity["timeZone"] = timezone

@@ -831,7 +831,7 @@ def granularity(
             granularity["period"] = period_name
         else:
             granularity["type"] = "duration"
-            granularity["duration"] = (  # type: ignore
+            granularity["duration"] = (
                 utils.parse_human_timedelta(period_name).total_seconds() * 1000
             )
         return granularity
@@ -941,23 +941,24 @@ def metrics_and_post_aggs(
         adhoc_agg_configs = []
         postagg_names = []
         for metric in metrics:
-            if utils.is_adhoc_metric(metric):
+            if isinstance(metric, dict) and utils.is_adhoc_metric(metric):
                 adhoc_agg_configs.append(metric)
-            elif metrics_dict[metric].metric_type != POST_AGG_TYPE:  # type: ignore
-                saved_agg_names.add(metric)
-            else:
-                postagg_names.append(metric)
+            elif isinstance(metric, str):
+                if metrics_dict[metric].metric_type != POST_AGG_TYPE:
+                    saved_agg_names.add(metric)
+                else:
+                    postagg_names.append(metric)
         # Create the post aggregations, maintain order since postaggs
         # may depend on previous ones
         post_aggs: "OrderedDict[str, Postaggregator]" = OrderedDict()
         visited_postaggs = set()
         for postagg_name in postagg_names:
-            postagg = metrics_dict[postagg_name]  # type: ignore
+            postagg = metrics_dict[postagg_name]
             visited_postaggs.add(postagg_name)
             DruidDatasource.resolve_postagg(
                 postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict
             )
-        aggs = DruidDatasource.get_aggregations(  # type: ignore
+        aggs = DruidDatasource.get_aggregations(
             metrics_dict, saved_agg_names, adhoc_agg_configs
         )
         return aggs, post_aggs
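
The loop in metrics_and_post_aggs now narrows `metric` with `isinstance` checks instead of silencing mypy: once `metric` is known to be a `str`, indexing `metrics_dict` type-checks without `# type: ignore`. A self-contained sketch of the same narrowing idea (types and names are simplified stand-ins, not the Superset API):

from typing import Dict, List, Union

AdhocMetric = Dict[str, str]      # simplified stand-in for an ad-hoc metric dict
Metric = Union[str, AdhocMetric]  # metrics arrive either as names or ad-hoc dicts

def split_metrics(metrics: List[Metric], saved: Dict[str, str]) -> List[str]:
    """Return the names of saved (string) metrics, skipping ad-hoc dicts."""
    names: List[str] = []
    for metric in metrics:
        if isinstance(metric, dict):
            continue  # ad-hoc metric: handled separately in the real code
        if isinstance(metric, str) and metric in saved:
            # mypy narrows `metric` to str here, so no "# type: ignore" is needed
            names.append(metric)
    return names

print(split_metrics(["count", {"label": "adhoc"}], {"count": "count(*)"}))
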
2 changes: 1 addition & 1 deletion superset/forms.py
@@ -23,7 +23,7 @@

 class CommaSeparatedListField(Field):
     widget = BS3TextFieldWidget()
-    data = []  # type: List[str]
+    data: List[str] = []

     def _value(self):
         if self.data:
4 changes: 2 additions & 2 deletions superset/models/core.py
@@ -460,7 +460,7 @@ def get_all_table_names_in_schema(
         self,
         schema: str,
         cache: bool = False,
-        cache_timeout: int = None,
+        cache_timeout: Optional[int] = None,
         force: bool = False,
     ) -> List[utils.DatasourceName]:
         """Parameters need to be passed as keyword arguments.
@@ -492,7 +492,7 @@ def get_all_view_names_in_schema(
         self,
         schema: str,
         cache: bool = False,
-        cache_timeout: int = None,
+        cache_timeout: Optional[int] = None,
         force: bool = False,
     ) -> List[utils.DatasourceName]:
         """Parameters need to be passed as keyword arguments.
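
`cache_timeout: int = None` only passed because older mypy treated a `None` default as implicitly Optional; newer mypy moves toward requiring the explicit `Optional[int]` spelling shown above. A small sketch of the explicit form (the function and default value are illustrative, not Superset's):

from typing import Optional

def get_table_names(schema: str, cache_timeout: Optional[int] = None) -> str:
    # With the annotation explicit, callers and mypy both know None is allowed.
    timeout = cache_timeout if cache_timeout is not None else 300
    return f"{schema}: cached for {timeout}s"

print(get_table_names("public"))
print(get_table_names("public", cache_timeout=60))
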
12 changes: 4 additions & 8 deletions superset/utils/core.py
@@ -51,7 +51,7 @@
 from cryptography.hazmat.backends.openssl.x509 import _Certificate
 from dateutil.parser import parse
 from dateutil.relativedelta import relativedelta
-from flask import current_app, flash, Flask, g, Markup, render_template
+from flask import current_app, flash, g, Markup, render_template
 from flask_appbuilder import SQLA
 from flask_appbuilder.security.sqla.models import User
 from flask_babel import gettext as __, lazy_gettext as _
@@ -1057,15 +1057,11 @@ def get_since_until(
         else:
             rel, num, grain = time_range.split()
             if rel == "Last":
-                since = relative_start - relativedelta(  # type: ignore
-                    **{grain: int(num)}
-                )
+                since = relative_start - relativedelta(**{grain: int(num)})  # type: ignore
                 until = relative_end
             else:  # rel == 'Next'
                 since = relative_start
-                until = relative_end + relativedelta(  # type: ignore
-                    **{grain: int(num)}
-                )
+                until = relative_end + relativedelta(**{grain: int(num)})  # type: ignore
     else:
         since = since or ""
         if since:
@@ -1184,7 +1180,7 @@ def parse_ssl_cert(certificate: str) -> _Certificate:
         return x509.load_pem_x509_certificate(
             certificate.encode("utf-8"), default_backend()
         )
-    except ValueError as e:
+    except ValueError:
         raise CertificateException("Invalid certificate")

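
The surviving `# type: ignore` comments above cover `relativedelta(**{grain: int(num)})`: the keyword name is only known at runtime, so mypy cannot match it against relativedelta's signature, and collapsing the call onto one line keeps the ignore attached to the exact line mypy flags. A minimal reproduction of the pattern (the sample values are made up):

from datetime import datetime
from dateutil.relativedelta import relativedelta

grain, num = "days", "7"  # e.g. parsed from a "Last 7 days" time range
# The keyword is unpacked from a runtime-built dict, which static checking
# cannot verify against relativedelta's signature, hence the inline ignore.
since = datetime(2020, 4, 4) - relativedelta(**{grain: int(num)})  # type: ignore
print(since)
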
5 changes: 4 additions & 1 deletion superset/views/core.py
@@ -188,7 +188,10 @@ def check_datasource_perms(
     except SupersetException as e:
         raise SupersetSecurityException(str(e))

-    viz_obj = get_viz(  # type: ignore
+    if datasource_type is None:
+        raise SupersetSecurityException("Could not determine datasource type")
+
+    viz_obj = get_viz(
         datasource_type=datasource_type,
         datasource_id=datasource_id,
         form_data=form_data,
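
Instead of `# type: ignore`, check_datasource_perms now raises as soon as `datasource_type` is still `None`, which narrows the variable to `str` for the `get_viz` call. The same guard in isolation (the exception and helper here are placeholders, not Superset's):

from typing import Optional

class SecurityError(Exception):
    """Placeholder for SupersetSecurityException."""

def build_viz(datasource_type: str, datasource_id: int) -> str:
    return f"{datasource_type}:{datasource_id}"

def check_perms(datasource_type: Optional[str], datasource_id: int) -> str:
    if datasource_type is None:
        # Raising here lets mypy treat datasource_type as str below.
        raise SecurityError("Could not determine datasource type")
    return build_viz(datasource_type=datasource_type, datasource_id=datasource_id)

print(check_perms("table", 42))
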
2 changes: 1 addition & 1 deletion superset/views/schedules.py
@@ -48,7 +48,7 @@ class EmailScheduleView(
 ):  # pylint: disable=too-many-ancestors
     include_route_methods = RouteMethod.CRUD_SET
     _extra_data = {"test_email": False, "test_email_recipients": None}
-    schedule_type: Optional[Type] = None
+    schedule_type: Optional[str] = None
     schedule_type_model: Optional[Type] = None

     page_size = 20
5 changes: 1 addition & 4 deletions tests/access_tests.py
@@ -567,10 +567,7 @@ def test_request_access(self):
         self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_4_id, "go"))
         access_request4 = self.get_access_requests("gamma", "druid", druid_ds_4_id)

-        self.assertEqual(
-            access_request4.roles_with_datasource,
-            "<ul></ul>".format(access_request4.id),
-        )
+        self.assertEqual(access_request4.roles_with_datasource, "<ul></ul>")

         # Case 5. Roles exist that contains the druid datasource.
         # add druid ds to the existing roles
10 changes: 5 additions & 5 deletions tests/db_engine_specs/hive_tests.py
@@ -58,7 +58,7 @@ def test_job_1_launched_stage_1(self):
         self.assertEqual(0, HiveEngineSpec.progress(log))

     def test_job_1_launched_stage_1_map_40_progress(
-        self
+        self,
     ):  # pylint: disable=invalid-name
         log = """
             17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
@@ -71,7 +71,7 @@ def test_job_1_launched_stage_1_map_40_progress(
         self.assertEqual(10, HiveEngineSpec.progress(log))

     def test_job_1_launched_stage_1_map_80_reduce_40_progress(
-        self
+        self,
     ):  # pylint: disable=invalid-name
         log = """
             17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
@@ -85,7 +85,7 @@ def test_job_1_launched_stage_1_map_80_reduce_40_progress(
         self.assertEqual(30, HiveEngineSpec.progress(log))

     def test_job_1_launched_stage_2_stages_progress(
-        self
+        self,
     ):  # pylint: disable=invalid-name
         log = """
             17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
@@ -101,7 +101,7 @@ def test_job_1_launched_stage_2_stages_progress(
         self.assertEqual(12, HiveEngineSpec.progress(log))

     def test_job_2_launched_stage_2_stages_progress(
-        self
+        self,
     ):  # pylint: disable=invalid-name
         log = """
             17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
@@ -145,7 +145,7 @@ def test_hive_error_msg(self):
         )

     def test_hive_get_view_names_return_empty_list(
-        self
+        self,
     ):  # pylint: disable=invalid-name
         self.assertEqual(
             [], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
2 changes: 1 addition & 1 deletion tests/db_engine_specs/presto_tests.py
@@ -32,7 +32,7 @@ def test_get_datatype_presto(self):
         self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string"))

     def test_presto_get_view_names_return_empty_list(
-        self
+        self,
     ):  # pylint: disable=invalid-name
         self.assertEqual(
             [], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
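
The only change in these engine-spec tests is a trailing comma after the lone `self` argument: the bump to black 19.10b0 reformats parameter lists that stay split across lines to end with a trailing comma. A toy illustration of the output it produces (the class and method names are arbitrary):

class TrailingCommaDemo:
    # black 19.10b0 leaves the long signature split across lines and appends
    # a trailing comma after the last parameter, which keeps later signature
    # changes to a one-line diff.
    def test_with_a_deliberately_long_descriptive_name(
        self,
    ):
        return "ok"

print(TrailingCommaDemo().test_with_a_deliberately_long_descriptive_name())
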
3 changes: 2 additions & 1 deletion tests/superset_test_config.py
@@ -14,9 +14,10 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# type: ignore
 from copy import copy

-from superset.config import *  # type: ignore
+from superset.config import *

 AUTH_USER_REGISTRATION_ROLE = "alpha"
 SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
3 changes: 2 additions & 1 deletion tests/superset_test_config_sqllab_backend_persist.py
@@ -15,10 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 # flake8: noqa
+# type: ignore
 import os
 from copy import copy

-from superset.config import *  # type: ignore
+from superset.config import *

 AUTH_USER_REGISTRATION_ROLE = "alpha"
 SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
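
Both test configs replace the per-import suppression with a file-level `# type: ignore`, which mypy treats as a module-wide suppression when it appears in the initial comment block, so the star import no longer needs its own ignore. A stripped-down sketch of the same layout (the imported module here is just an example, not superset.config):

# type: ignore
# A "type: ignore" in a file's first comment block silences mypy for the
# whole module, so the star import below needs no per-line suppression.
from os.path import *  # noqa: F401,F403

DB_PATH = join("/tmp", "unittests.db")
print(DB_PATH)
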
