From d583ca9ef57d1c49ae84cda8cc888ee01dcf5601 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 19 May 2023 00:37:13 -0700 Subject: [PATCH 1/7] chore: Embrace the walrus operator (#24127) --- .pre-commit-config.yaml | 4 ++ superset/charts/api.py | 6 +-- superset/charts/commands/bulk_delete.py | 3 +- superset/charts/commands/delete.py | 3 +- superset/common/query_actions.py | 3 +- superset/common/query_context_factory.py | 3 +- superset/common/query_context_processor.py | 3 +- superset/common/utils/query_cache_manager.py | 3 +- superset/connectors/sqla/models.py | 3 +- superset/dashboards/commands/bulk_delete.py | 3 +- superset/dashboards/commands/delete.py | 3 +- superset/dashboards/dao.py | 3 +- superset/databases/api.py | 3 +- superset/databases/commands/delete.py | 3 +- .../databases/commands/test_connection.py | 3 +- superset/databases/commands/validate.py | 3 +- superset/datasets/api.py | 3 +- .../datasets/commands/importers/v1/utils.py | 3 +- superset/datasets/commands/update.py | 6 +-- superset/db_engine_specs/base.py | 6 +-- superset/db_engine_specs/bigquery.py | 3 +- superset/db_engine_specs/databricks.py | 3 +- superset/db_engine_specs/presto.py | 6 +-- superset/db_engine_specs/snowflake.py | 3 +- superset/db_engine_specs/trino.py | 6 +-- superset/errors.py | 3 +- superset/initialization/__init__.py | 3 +- superset/migrations/env.py | 5 +-- .../migrations/shared/migrate_viz/base.py | 6 +-- .../shared/migrate_viz/processors.py | 6 +-- ...f3fed1fe_convert_dashboard_v1_positions.py | 3 +- ...95_migrate_native_filters_to_new_schema.py | 6 +-- ...-25_31b2a1039d4a_drop_tables_constraint.py | 3 +- superset/models/core.py | 3 +- superset/models/helpers.py | 3 +- superset/models/slice.py | 3 +- superset/queries/saved_queries/api.py | 3 +- superset/reports/commands/execute.py | 3 +- superset/result_set.py | 6 +-- superset/security/manager.py | 12 ++--- superset/sql_lab.py | 3 +- superset/sql_validators/presto_db.py | 3 +- superset/utils/core.py | 3 +- superset/utils/screenshots.py | 3 +- superset/views/api.py | 3 +- superset/views/base_api.py | 6 +-- superset/views/core.py | 45 +++++++------------ superset/views/database/validators.py | 3 +- superset/views/log/api.py | 3 +- superset/views/utils.py | 7 ++- superset/viz.py | 26 +++++------ .../data_loading/pandas/pandas_data_loader.py | 3 +- tests/integration_tests/csv_upload_tests.py | 9 ++-- tests/integration_tests/datasets/api_tests.py | 9 ++-- 54 files changed, 100 insertions(+), 185 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2ff79da9e96f0..3f524b3658b57 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,6 +19,10 @@ repos: rev: 5.12.0 hooks: - id: isort + - repo: https://github.com/MarcoGorelli/auto-walrus + rev: v0.2.2 + hooks: + - id: auto-walrus - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.3.0 hooks: diff --git a/superset/charts/api.py b/superset/charts/api.py index 6a4bf04aa1a4b..2c50a8d163cb8 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -649,8 +649,7 @@ def screenshot(self, pk: int, digest: str) -> WerkzeugResponse: return self.response_404() # fetch the chart screenshot using the current user and cache if set - img = ChartScreenshot.get_from_cache_key(thumbnail_cache, digest) - if img: + if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest): return Response( FileWrapper(img), mimetype="image/png", direct_passthrough=True ) @@ -783,7 +782,6 @@ def export(self, **kwargs: Any) -> Response: 500: $ref: '#/components/responses/500' """ - token = request.args.get("token") requested_ids = kwargs["rison"] timestamp = datetime.now().strftime("%Y%m%dT%H%M%S") root = f"chart_export_{timestamp}" @@ -805,7 +803,7 @@ def export(self, **kwargs: Any) -> Response: as_attachment=True, download_name=filename, ) - if token: + if token := request.args.get("token"): response.set_cookie(token, "done", max_age=600) return response diff --git a/superset/charts/commands/bulk_delete.py b/superset/charts/commands/bulk_delete.py index caf8fe0399228..c252f0be4cc28 100644 --- a/superset/charts/commands/bulk_delete.py +++ b/superset/charts/commands/bulk_delete.py @@ -55,8 +55,7 @@ def validate(self) -> None: if not self._models or len(self._models) != len(self._model_ids): raise ChartNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_chart_ids(self._model_ids) - if reports: + if reports := ReportScheduleDAO.find_by_chart_ids(self._model_ids): report_names = [report.name for report in reports] raise ChartBulkDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/charts/commands/delete.py b/superset/charts/commands/delete.py index 4c636f0433a73..11f6e5925773d 100644 --- a/superset/charts/commands/delete.py +++ b/superset/charts/commands/delete.py @@ -64,8 +64,7 @@ def validate(self) -> None: if not self._model: raise ChartNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_chart_id(self._model_id) - if reports: + if reports := ReportScheduleDAO.find_by_chart_id(self._model_id): report_names = [report.name for report in reports] raise ChartDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index 38526475b9349..f6f5a5cd62cfb 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -221,8 +221,7 @@ def get_query_results( :raises QueryObjectValidationError: if an unsupported result type is requested :return: JSON serializable result payload """ - result_func = _result_type_functions.get(result_type) - if result_func: + if result_func := _result_type_functions.get(result_type): return result_func(query_context, query_obj, force_cached) raise QueryObjectValidationError( _("Invalid result type: %(result_type)s", result_type=result_type) diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index a42d1d4ba7316..84c0415722c99 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -125,10 +125,9 @@ def _apply_granularity( for column in datasource.columns if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm) } - granularity = query_object.granularity x_axis = form_data and form_data.get("x_axis") - if granularity: + if granularity := query_object.granularity: filter_to_remove = None if x_axis and x_axis in temporal_columns: filter_to_remove = x_axis diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 56f07dcb64b5e..85a2b5d97ae72 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -500,8 +500,7 @@ def get_payload( return return_value def get_cache_timeout(self) -> int: - cache_timeout_rv = self._query_context.get_cache_timeout() - if cache_timeout_rv: + if cache_timeout_rv := self._query_context.get_cache_timeout(): return cache_timeout_rv if ( data_cache_timeout := config["DATA_CACHE_CONFIG"].get( diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py index 7143fcc201a57..6c1b268f46534 100644 --- a/superset/common/utils/query_cache_manager.py +++ b/superset/common/utils/query_cache_manager.py @@ -148,8 +148,7 @@ def get( if not key or not _cache[region] or force_query: return query_cache - cache_value = _cache[region].get(key) - if cache_value: + if cache_value := _cache[region].get(key): logger.debug("Cache key: %s", key) stats_logger.incr("loading_from_cache") try: diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e339f7b1f4c9e..5f487f60f6aa0 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -993,11 +993,10 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals schema=self.schema, template_processor=template_processor, ) - col_in_metadata = self.get_column(expression) time_grain = col.get("timeGrain") has_timegrain = col.get("columnType") == "BASE_AXIS" and time_grain is_dttm = False - if col_in_metadata: + if col_in_metadata := self.get_column(expression): sqla_column = col_in_metadata.get_sqla_col( template_processor=template_processor ) diff --git a/superset/dashboards/commands/bulk_delete.py b/superset/dashboards/commands/bulk_delete.py index 52f5998438bcb..13541cd946ba0 100644 --- a/superset/dashboards/commands/bulk_delete.py +++ b/superset/dashboards/commands/bulk_delete.py @@ -56,8 +56,7 @@ def validate(self) -> None: if not self._models or len(self._models) != len(self._model_ids): raise DashboardNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_dashboard_ids(self._model_ids) - if reports: + if reports := ReportScheduleDAO.find_by_dashboard_ids(self._model_ids): report_names = [report.name for report in reports] raise DashboardBulkDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/dashboards/commands/delete.py b/superset/dashboards/commands/delete.py index 7af2fdf4ce03c..8ce7cb0cbf84f 100644 --- a/superset/dashboards/commands/delete.py +++ b/superset/dashboards/commands/delete.py @@ -57,8 +57,7 @@ def validate(self) -> None: if not self._model: raise DashboardNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_dashboard_id(self._model_id) - if reports: + if reports := ReportScheduleDAO.find_by_dashboard_id(self._model_id): report_names = [report.name for report in reports] raise DashboardDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py index 89fca4619a0bf..5355d602bec04 100644 --- a/superset/dashboards/dao.py +++ b/superset/dashboards/dao.py @@ -200,11 +200,10 @@ def set_dash_metadata( # pylint: disable=too-many-locals old_to_new_slice_ids: Optional[Dict[int, int]] = None, commit: bool = False, ) -> Dashboard: - positions = data.get("positions") new_filter_scopes = {} md = dashboard.params_dict - if positions is not None: + if (positions := data.get("positions")) is not None: # find slices in the position data slice_ids = [ value.get("meta", {}).get("chartId") diff --git a/superset/databases/api.py b/superset/databases/api.py index 4997edc0738f3..8e444a84d8d82 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -1036,7 +1036,6 @@ def export(self, **kwargs: Any) -> Response: 500: $ref: '#/components/responses/500' """ - token = request.args.get("token") requested_ids = kwargs["rison"] timestamp = datetime.now().strftime("%Y%m%dT%H%M%S") root = f"database_export_{timestamp}" @@ -1060,7 +1059,7 @@ def export(self, **kwargs: Any) -> Response: as_attachment=True, download_name=filename, ) - if token: + if token := request.args.get("token"): response.set_cookie(token, "done", max_age=600) return response diff --git a/superset/databases/commands/delete.py b/superset/databases/commands/delete.py index ebdd543570a11..825b12621811a 100644 --- a/superset/databases/commands/delete.py +++ b/superset/databases/commands/delete.py @@ -55,9 +55,8 @@ def validate(self) -> None: if not self._model: raise DatabaseNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_database_id(self._model_id) - if reports: + if reports := ReportScheduleDAO.find_by_database_id(self._model_id): report_names = [report.name for report in reports] raise DatabaseDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index cbc1240905a10..9809641d5cd6a 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -228,6 +228,5 @@ def ping(engine: Engine) -> bool: raise DatabaseTestConnectionUnexpectedError(errors) from ex def validate(self) -> None: - database_name = self._properties.get("database_name") - if database_name is not None: + if (database_name := self._properties.get("database_name")) is not None: self._model = DatabaseDAO.get_database_by_name(database_name) diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index 8c58ef5de0bfb..2a624e32c7abc 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -128,6 +128,5 @@ def run(self) -> None: ) def validate(self) -> None: - database_id = self._properties.get("id") - if database_id is not None: + if (database_id := self._properties.get("id")) is not None: self._model = DatabaseDAO.find_by_id(database_id) diff --git a/superset/datasets/api.py b/superset/datasets/api.py index d52e6227932b8..6568ba379367c 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -977,8 +977,7 @@ def get_or_create_dataset(self) -> Response: return self.response(400, message=ex.messages) table_name = body["table_name"] database_id = body["database_id"] - table = DatasetDAO.get_table_by_name(database_id, table_name) - if table: + if table := DatasetDAO.get_table_by_name(database_id, table_name): return self.response(200, result={"table_id": table.id}) body["database"] = database_id diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py index 2df85cdfa27a7..52f46829b5b35 100644 --- a/superset/datasets/commands/importers/v1/utils.py +++ b/superset/datasets/commands/importers/v1/utils.py @@ -62,8 +62,7 @@ def get_sqla_type(native_type: str) -> VisitableType: if native_type.upper() in type_map: return type_map[native_type.upper()] - match = VARCHAR.match(native_type) - if match: + if match := VARCHAR.match(native_type): size = int(match.group(1)) return String(size) diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py index a2e483ba93ddb..b6bf1256d1904 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -114,13 +114,11 @@ def validate(self) -> None: exceptions.append(DatasetEndpointUnsafeValidationError()) # Validate columns - columns = self._properties.get("columns") - if columns: + if columns := self._properties.get("columns"): self._validate_columns(columns, exceptions) # Validate metrics - metrics = self._properties.get("metrics") - if metrics: + if metrics := self._properties.get("metrics"): self._validate_metrics(metrics, exceptions) if exceptions: diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 98fb60c27565b..221872f544835 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1704,8 +1704,7 @@ def get_column_spec( # pylint: disable=unused-argument :param source: Type coming from the database table or cursor description :return: ColumnSpec object """ - col_types = cls.get_column_types(native_type) - if col_types: + if col_types := cls.get_column_types(native_type): column_type, generic_type = col_types is_dttm = generic_type == GenericDataType.TEMPORAL return ColumnSpec( @@ -1996,9 +1995,8 @@ def validate_parameters( required = {"host", "port", "username", "database"} parameters = properties.get("parameters", {}) present = {key for key in parameters if parameters.get(key, ())} - missing = sorted(required - present) - if missing: + if missing := sorted(required - present): errors.append( SupersetError( message=f'One or more parameters are missing: {", ".join(missing)}', diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 1f2ee51068a86..1f5068ad04bbd 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -384,9 +384,8 @@ def df_to_sql( } # Add credentials if they are set on the SQLAlchemy dialect. - creds = engine.dialect.credentials_info - if creds: + if creds := engine.dialect.credentials_info: to_gbq_kwargs[ "credentials" ] = service_account.Credentials.from_service_account_info(creds) diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index f39e43aa60b4f..5f12f3174d363 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -285,9 +285,8 @@ def validate_parameters( # type: ignore parameters["http_path"] = connect_args.get("http_path") present = {key for key in parameters if parameters.get(key, ())} - missing = sorted(required - present) - if missing: + if missing := sorted(required - present): errors.append( SupersetError( message=f'One or more parameters are missing: {", ".join(missing)}', diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index e047923f922ef..42f6ed9af6db0 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -1213,8 +1213,7 @@ def extra_table_metadata( ) -> Dict[str, Any]: metadata = {} - indexes = database.get_indexes(table_name, schema_name) - if indexes: + if indexes := database.get_indexes(table_name, schema_name): col_names, latest_parts = cls.latest_partition( table_name, schema_name, database, show_first=True ) @@ -1278,8 +1277,7 @@ def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]: @classmethod def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None: """Updates progress information""" - tracking_url = cls.get_tracking_url(cursor) - if tracking_url: + if tracking_url := cls.get_tracking_url(cursor): query.tracking_url = tracking_url session.commit() diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index c7049ae71d688..69ccf55931922 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -312,9 +312,8 @@ def validate_parameters( } parameters = properties.get("parameters", {}) present = {key for key in parameters if parameters.get(key, ())} - missing = sorted(required - present) - if missing: + if missing := sorted(required - present): errors.append( SupersetError( message=f'One or more parameters are missing: {", ".join(missing)}', diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 6cca83be06f09..0fa4d05cbce0d 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -57,8 +57,7 @@ def extra_table_metadata( ) -> Dict[str, Any]: metadata = {} - indexes = database.get_indexes(table_name, schema_name) - if indexes: + if indexes := database.get_indexes(table_name, schema_name): col_names, latest_parts = cls.latest_partition( table_name, schema_name, database, show_first=True ) @@ -150,8 +149,7 @@ def get_tracking_url(cls, cursor: Cursor) -> Optional[str]: @classmethod def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: - tracking_url = cls.get_tracking_url(cursor) - if tracking_url: + if tracking_url := cls.get_tracking_url(cursor): query.tracking_url = tracking_url # Adds the executed query id to the extra payload so the query can be cancelled diff --git a/superset/errors.py b/superset/errors.py index 2df0eb82b26e0..5261848687f2f 100644 --- a/superset/errors.py +++ b/superset/errors.py @@ -211,8 +211,7 @@ def __post_init__(self) -> None: Mutates the extra params with user facing error codes that map to backend errors. """ - issue_codes = ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type) - if issue_codes: + if issue_codes := ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type): self.extra = self.extra or {} self.extra.update( { diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 90c1653d0b833..c489cc323cb64 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -453,8 +453,7 @@ def init_app_in_ctx(self) -> None: # Hook that provides administrators a handle on the Flask APP # after initialization - flask_app_mutator = self.config["FLASK_APP_MUTATOR"] - if flask_app_mutator: + if flask_app_mutator := self.config["FLASK_APP_MUTATOR"]: flask_app_mutator(self.superset_app) if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): diff --git a/superset/migrations/env.py b/superset/migrations/env.py index 90561beea4706..e3779bb65bcc2 100755 --- a/superset/migrations/env.py +++ b/superset/migrations/env.py @@ -103,8 +103,7 @@ def process_revision_directives( # pylint: disable=redefined-outer-name, unused kwargs = {} if engine.name in ("sqlite", "mysql"): kwargs = {"transaction_per_migration": True, "transactional_ddl": True} - configure_args = current_app.extensions["migrate"].configure_args - if configure_args: + if configure_args := current_app.extensions["migrate"].configure_args: kwargs.update(configure_args) context.configure( @@ -112,7 +111,7 @@ def process_revision_directives( # pylint: disable=redefined-outer-name, unused target_metadata=target_metadata, # compare_type=True, process_revision_directives=process_revision_directives, - **kwargs + **kwargs, ) try: diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py index 19bb7cc2a957b..5ea23551ead57 100644 --- a/superset/migrations/shared/migrate_viz/base.py +++ b/superset/migrations/shared/migrate_viz/base.py @@ -130,8 +130,7 @@ def upgrade_slice(cls, slc: Slice) -> Slice: # only backup params slc.params = json.dumps({**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak}) - query_context = try_load_json(slc.query_context) - if "form_data" in query_context: + if "form_data" in (query_context := try_load_json(slc.query_context)): query_context["form_data"] = clz.data slc.query_context = json.dumps(query_context) return slc @@ -139,8 +138,7 @@ def upgrade_slice(cls, slc: Slice) -> Slice: @classmethod def downgrade_slice(cls, slc: Slice) -> Slice: form_data = try_load_json(slc.params) - form_data_bak = form_data.get(FORM_DATA_BAK_FIELD_NAME, {}) - if "viz_type" in form_data_bak: + if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})): slc.params = json.dumps(form_data_bak) slc.viz_type = form_data_bak.get("viz_type") query_context = try_load_json(slc.query_context) diff --git a/superset/migrations/shared/migrate_viz/processors.py b/superset/migrations/shared/migrate_viz/processors.py index 3584856beb72c..6d35a974dbafe 100644 --- a/superset/migrations/shared/migrate_viz/processors.py +++ b/superset/migrations/shared/migrate_viz/processors.py @@ -40,8 +40,7 @@ def _pre_action(self) -> None: if self.data.get("contribution"): self.data["contributionMode"] = "row" - stacked = self.data.get("stacked_style") - if stacked: + if stacked := self.data.get("stacked_style"): stacked_map = { "expand": "Expand", "stack": "Stack", @@ -49,7 +48,6 @@ def _pre_action(self) -> None: self.data["show_extra_controls"] = True self.data["stack"] = stacked_map.get(stacked) - x_axis_label = self.data.get("x_axis_label") - if x_axis_label: + if x_axis_label := self.data.get("x_axis_label"): self.data["x_axis_title"] = x_axis_label self.data["x_axis_title_margin"] = 30 diff --git a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py index 865a8e59a06e4..13c4e61718cc4 100644 --- a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py +++ b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py @@ -193,13 +193,12 @@ def get_chart_holder(position): size_y = position["size_y"] slice_id = position["slice_id"] slice_name = position.get("slice_name") - code = position.get("code") width = max(GRID_MIN_COLUMN_COUNT, int(round(size_x / GRID_RATIO))) height = max( GRID_MIN_ROW_UNITS, int(round(((size_y / GRID_RATIO) * 100) / ROW_HEIGHT)) ) - if code is not None: + if (code := position.get("code")) is not None: markdown_content = " " # white-space markdown if len(code): markdown_content = code diff --git a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py index 46b8e5f958670..ec8f8e1cc0566 100644 --- a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py +++ b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py @@ -80,8 +80,7 @@ def upgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]: changed_filters, changed_filter_sets = 0, 0 # upgrade native select filter metadata # upgrade native select filter metadata - native_filters = dashboard.get("native_filter_configuration") - if native_filters: + if native_filters := dashboard.get("native_filter_configuration"): changed_filters += upgrade_filters(native_filters) # upgrade filter sets @@ -123,8 +122,7 @@ def upgrade(): def downgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]: changed_filters, changed_filter_sets = 0, 0 # upgrade native select filter metadata - native_filters = dashboard.get("native_filter_configuration") - if native_filters: + if native_filters := dashboard.get("native_filter_configuration"): changed_filters += downgrade_filters(native_filters) # upgrade filter sets diff --git a/superset/migrations/versions/2021-07-27_08-25_31b2a1039d4a_drop_tables_constraint.py b/superset/migrations/versions/2021-07-27_08-25_31b2a1039d4a_drop_tables_constraint.py index 8f07ba1ae3d8c..9773851ae99d8 100644 --- a/superset/migrations/versions/2021-07-27_08-25_31b2a1039d4a_drop_tables_constraint.py +++ b/superset/migrations/versions/2021-07-27_08-25_31b2a1039d4a_drop_tables_constraint.py @@ -40,9 +40,8 @@ def upgrade(): insp = engine.reflection.Inspector.from_engine(bind) # Drop the uniqueness constraint if it exists. - constraint = generic_find_uq_constraint_name("tables", {"table_name"}, insp) - if constraint: + if constraint := generic_find_uq_constraint_name("tables", {"table_name"}, insp): with op.batch_alter_table("tables", naming_convention=conv) as batch_op: batch_op.drop_constraint(constraint, type_="unique") diff --git a/superset/models/core.py b/superset/models/core.py index 43d12900e613d..592207faba05a 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -277,9 +277,8 @@ def parameters(self) -> Dict[str, Any]: # When returning the parameters we should use the masked SQLAlchemy URI and the # masked ``encrypted_extra`` to prevent exposing sensitive credentials. masked_uri = make_url_safe(self.sqlalchemy_uri) - masked_encrypted_extra = self.masked_encrypted_extra encrypted_config = {} - if masked_encrypted_extra is not None: + if (masked_encrypted_extra := self.masked_encrypted_extra) is not None: try: encrypted_config = json.loads(masked_encrypted_extra) except (TypeError, json.JSONDecodeError): diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 4022bcbc1373b..558ad15fc9afe 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -880,8 +880,7 @@ def mutate_query_from_config(self, sql: str) -> str: """Apply config's SQL_QUERY_MUTATOR Typically adds comments to the query with context""" - sql_query_mutator = config["SQL_QUERY_MUTATOR"] - if sql_query_mutator: + if sql_query_mutator := config["SQL_QUERY_MUTATOR"]: sql = sql_query_mutator( sql, user_name=utils.get_username(), # TODO(john-bodley): Deprecate in 3.0. diff --git a/superset/models/slice.py b/superset/models/slice.py index d08e345d8240b..6835215338f49 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -378,8 +378,7 @@ def get(cls, id_: int) -> Slice: def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) -> None: src_class = target.cls_model - id_ = target.datasource_id - if id_: + if id_ := target.datasource_id: ds = db.session.query(src_class).filter_by(id=int(id_)).first() if ds: target.perm = ds.perm diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py index 0c28e31a5296a..c6e980c5defe8 100644 --- a/superset/queries/saved_queries/api.py +++ b/superset/queries/saved_queries/api.py @@ -262,7 +262,6 @@ def export(self, **kwargs: Any) -> Response: 500: $ref: '#/components/responses/500' """ - token = request.args.get("token") requested_ids = kwargs["rison"] timestamp = datetime.now().strftime("%Y%m%dT%H%M%S") root = f"saved_query_export_{timestamp}" @@ -286,7 +285,7 @@ def export(self, **kwargs: Any) -> Response: as_attachment=True, download_name=filename, ) - if token: + if token := request.args.get("token"): response.set_cookie(token, "done", max_age=600) return response diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index 78ac8aa3a88f8..61f72d4790ef4 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -176,8 +176,7 @@ def _get_url( ) # If we need to render dashboard in a specific state, use stateful permalink - dashboard_state = self._report_schedule.extra.get("dashboard") - if dashboard_state: + if dashboard_state := self._report_schedule.extra.get("dashboard"): permalink_key = CreateDashboardPermalinkCommand( dashboard_id=str(self._report_schedule.dashboard.uuid), state=dashboard_state, diff --git a/superset/result_set.py b/superset/result_set.py index 170de1869c830..9aa06bba09ca1 100644 --- a/superset/result_set.py +++ b/superset/result_set.py @@ -225,12 +225,10 @@ def is_temporal(self, db_type_str: Optional[str]) -> bool: def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]: """Given a pyarrow data type, Returns a generic database type""" - set_type = self._type_dict.get(col_name) - if set_type: + if set_type := self._type_dict.get(col_name): return set_type - mapped_type = self.convert_pa_dtype(pa_dtype) - if mapped_type: + if mapped_type := self.convert_pa_dtype(pa_dtype): return mapped_type return None diff --git a/superset/security/manager.py b/superset/security/manager.py index c54fdac87adc1..db6e631d918b0 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -589,8 +589,7 @@ def user_view_menu_names(self, permission_name: str) -> Set[str]: return {s.name for s in view_menu_names} # Properly treat anonymous user - public_role = self.get_public_role() - if public_role: + if public_role := self.get_public_role(): # filter by public role view_menu_names = ( base_query.filter(self.role_model.id == public_role.id).filter( @@ -639,8 +638,7 @@ def get_schemas_accessible_by_user( } # datasource_access - perms = self.user_view_menu_names("datasource_access") - if perms: + if perms := self.user_view_menu_names("datasource_access"): tables = ( self.get_session.query(SqlaTable.schema) .filter(SqlaTable.database_id == database.id) @@ -770,9 +768,8 @@ def clean_perms(self) -> None: == None, ) ) - deleted_count = pvms.delete() sesh.commit() - if deleted_count: + if deleted_count := pvms.delete(): logger.info("Deleted %i faulty permissions", deleted_count) def sync_role_definitions(self) -> None: @@ -1916,8 +1913,7 @@ def get_guest_rls_filters( :param dataset: The dataset to check against :return: A list of filters """ - guest_user = self.get_current_guest_user_if_guest() - if guest_user: + if guest_user := self.get_current_guest_user_if_guest(): return [ rule for rule in guest_user.rls diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 149feb163926c..0f373a3514e64 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -95,7 +95,6 @@ def handle_query_error( """Local method handling error while processing the SQL""" payload = payload or {} msg = f"{prefix_message} {str(ex)}".strip() - troubleshooting_link = config["TROUBLESHOOTING_LINK"] query.error_message = msg query.tmp_table_name = None query.status = QueryStatus.FAILED @@ -119,7 +118,7 @@ def handle_query_error( session.commit() payload.update({"status": query.status, "error": msg, "errors": errors_payload}) - if troubleshooting_link: + if troubleshooting_link := config["TROUBLESHOOTING_LINK"]: payload["link"] = troubleshooting_link return payload diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 5bc844751b83b..10ef1fc1e13c9 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -54,8 +54,7 @@ def validate_statement( sql = parsed_query.stripped() # Hook to allow environment-specific mutation (usually comments) to the SQL - sql_query_mutator = config["SQL_QUERY_MUTATOR"] - if sql_query_mutator: + if sql_query_mutator := config["SQL_QUERY_MUTATOR"]: sql = sql_query_mutator( sql, user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. diff --git a/superset/utils/core.py b/superset/utils/core.py index 8451eaaa6f457..c537abf459648 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1800,8 +1800,7 @@ def get_time_filter_status( } applied: List[Dict[str, str]] = [] rejected: List[Dict[str, str]] = [] - time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL) - if time_column: + if time_column := applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL): if time_column in temporal_columns: applied.append({"column": ExtraFiltersTimeColumnType.TIME_COL}) else: diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index a904b7dc4384a..88b97901b2893 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -121,8 +121,7 @@ def get_from_cache( @staticmethod def get_from_cache_key(cache: Cache, cache_key: str) -> Optional[BytesIO]: logger.info("Attempting to get from cache: %s", cache_key) - payload = cache.get(cache_key) - if payload: + if payload := cache.get(cache_key): return BytesIO(payload) logger.info("Failed at getting from cache: %s", cache_key) return None diff --git a/superset/views/api.py b/superset/views/api.py index 2884ac997f23a..84c27d2fac18e 100644 --- a/superset/views/api.py +++ b/superset/views/api.py @@ -79,8 +79,7 @@ def query_form_data(self) -> FlaskResponse: # pylint: disable=no-self-use params: slice_id: integer """ form_data = {} - slice_id = request.args.get("slice_id") - if slice_id: + if slice_id := request.args.get("slice_id"): slc = db.session.query(Slice).filter_by(id=slice_id).one_or_none() if slc: form_data = slc.form_data.copy() diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 2e069c196d44c..30d25382f37ee 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -380,8 +380,7 @@ def _get_related_filter( filter_field = cast(RelatedFieldFilter, filter_field) search_columns = [filter_field.field_name] if filter_field else None filters = datamodel.get_filters(search_columns) - base_filters = self.base_related_field_filters.get(column_name) - if base_filters: + if base_filters := self.base_related_field_filters.get(column_name): filters.add_filter_list(base_filters) if value and filter_field: filters.add_filter( @@ -588,8 +587,7 @@ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse: return self.response_404() page, page_size = self._sanitize_page_args(page, page_size) # handle ordering - order_field = self.order_rel_fields.get(column_name) - if order_field: + if order_field := self.order_rel_fields.get(column_name): order_column, order_direction = order_field else: order_column, order_direction = "", "" diff --git a/superset/views/core.py b/superset/views/core.py index b473172399cc5..24bc16c3106f1 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -765,8 +765,7 @@ def get_redirect_url() -> str: """ redirect_url = request.url.replace("/superset/explore", "/explore") form_data_key = None - request_form_data = request.args.get("form_data") - if request_form_data: + if request_form_data := request.args.get("form_data"): parsed_form_data = loads_request_json(request_form_data) slice_id = parsed_form_data.get( "slice_id", int(request.args.get("slice_id", 0)) @@ -1498,8 +1497,7 @@ def get_user_activity_access_error(user_id: int) -> Optional[FlaskResponse]: @deprecated(new_target="/api/v1/log/recent_activity//") def recent_activity(self, user_id: int) -> FlaskResponse: """Recent activity (actions) for a given user""" - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj limit = request.args.get("limit") @@ -1543,8 +1541,7 @@ def fave_dashboards_by_username(self, username: str) -> FlaskResponse: @expose("/fave_dashboards//", methods=("GET",)) @deprecated(new_target="api/v1/dashboard/favorite_status/") def fave_dashboards(self, user_id: int) -> FlaskResponse: - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj qry = ( db.session.query(Dashboard, FavStar.dttm) @@ -1580,8 +1577,7 @@ def fave_dashboards(self, user_id: int) -> FlaskResponse: @expose("/created_dashboards//", methods=("GET",)) @deprecated(new_target="api/v1/dashboard/") def created_dashboards(self, user_id: int) -> FlaskResponse: - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj qry = ( db.session.query(Dashboard) @@ -1615,8 +1611,7 @@ def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse: """List of slices a user owns, created, modified or faved""" if not user_id: user_id = cast(int, get_user_id()) - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj owner_ids_query = ( @@ -1669,8 +1664,7 @@ def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse: """List of slices created by this user""" if not user_id: user_id = cast(int, get_user_id()) - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj qry = ( db.session.query(Slice) @@ -1701,8 +1695,7 @@ def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse: """Favorite slices for a user""" if user_id is None: user_id = cast(int, get_user_id()) - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj qry = ( db.session.query(Slice, FavStar.dttm) @@ -1965,8 +1958,7 @@ def dashboard_permalink( # pylint: disable=no-self-use return json_error_response(_("permalink state not found"), status=404) dashboard_id, state = value["dashboardId"], value.get("state", {}) url = f"/superset/dashboard/{dashboard_id}?permalink_key={key}" - url_params = state.get("urlParams") - if url_params: + if url_params := state.get("urlParams"): params = parse.urlencode(url_params) url = f"{url}&{params}" hash_ = state.get("anchor", state.get("hash")) @@ -2125,8 +2117,7 @@ def estimate_query_cost( # pylint: disable=no-self-use mydb = db.session.query(Database).get(database_id) sql = json.loads(request.form.get("sql", '""')) - template_params = json.loads(request.form.get("templateParams") or "{}") - if template_params: + if template_params := json.loads(request.form.get("templateParams") or "{}"): template_processor = get_template_processor(mydb) sql = template_processor.process_template(sql, **template_params) @@ -2393,8 +2384,7 @@ def validate_sql_json( @expose("/sql_json/", methods=("POST",)) @deprecated(new_target="/api/v1/sqllab/execute/") def sql_json(self) -> FlaskResponse: - errors = SqlJsonPayloadSchema().validate(request.json) - if errors: + if errors := SqlJsonPayloadSchema().validate(request.json): return json_error_response(status=400, payload=errors) try: @@ -2621,10 +2611,7 @@ def search_queries(self) -> FlaskResponse: # pylint: disable=no-self-use search_user_id = get_user_id() database_id = request.args.get("database_id") search_text = request.args.get("search_text") - status = request.args.get("status") # From and To time stamp should be Epoch timestamp in seconds - from_time = request.args.get("from") - to_time = request.args.get("to") query = db.session.query(Query) if search_user_id: @@ -2635,7 +2622,7 @@ def search_queries(self) -> FlaskResponse: # pylint: disable=no-self-use # Filter on db Id query = query.filter(Query.database_id == database_id) - if status: + if status := request.args.get("status"): # Filter on status query = query.filter(Query.status == status) @@ -2643,10 +2630,10 @@ def search_queries(self) -> FlaskResponse: # pylint: disable=no-self-use # Filter on search text query = query.filter(Query.sql.like(f"%{search_text}%")) - if from_time: + if from_time := request.args.get("from"): query = query.filter(Query.start_time > int(from_time)) - if to_time: + if to_time := request.args.get("to"): query = query.filter(Query.start_time < int(to_time)) query_limit = config["QUERY_SEARCH_LIMIT"] @@ -2709,8 +2696,7 @@ def profile(self, username: str) -> FlaskResponse: user_id = -1 if not user else user.id # Prevent unauthorized access to other user's profiles, # unless configured to do so on with ENABLE_BROAD_ACTIVITY_ACCESS - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj payload = { @@ -2789,8 +2775,7 @@ def sqllab(self) -> FlaskResponse: **self._get_sqllab_tabs(get_user_id()), } - form_data = request.form.get("form_data") - if form_data: + if form_data := request.form.get("form_data"): try: payload["requested_query"] = json.loads(form_data) except json.JSONDecodeError: diff --git a/superset/views/database/validators.py b/superset/views/database/validators.py index 93723ac38b8f2..29d80611a2421 100644 --- a/superset/views/database/validators.py +++ b/superset/views/database/validators.py @@ -51,7 +51,6 @@ def sqlalchemy_uri_validator( def schema_allows_file_upload(database: Database, schema: Optional[str]) -> bool: if not database.allow_file_upload: return False - schemas = database.get_schema_access_for_file_upload() - if schemas: + if schemas := database.get_schema_access_for_file_upload(): return schema in schemas return security_manager.can_access_database(database) diff --git a/superset/views/log/api.py b/superset/views/log/api.py index b94af731c4f84..e218792c25970 100644 --- a/superset/views/log/api.py +++ b/superset/views/log/api.py @@ -125,8 +125,7 @@ def recent_activity(self, user_id: int, **kwargs: Any) -> FlaskResponse: 500: $ref: '#/components/responses/500' """ - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj args = kwargs["rison"] diff --git a/superset/views/utils.py b/superset/views/utils.py index 35a39fdc9c04d..a53e7500406f0 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -261,9 +261,8 @@ def get_datasource_info( :raises SupersetException: If the datasource no longer exists """ - datasource = form_data.get("datasource", "") - - if "__" in datasource: + # pylint: disable=superfluous-parens + if "__" in (datasource := form_data.get("datasource", "")): datasource_id, datasource_type = datasource.split("__") # The case where the datasource has been deleted if datasource_id == "None": @@ -462,7 +461,7 @@ def check_datasource_perms( _self: Any, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """ Check if user can access a cached response from explore_json. diff --git a/superset/viz.py b/superset/viz.py index d605b8b0068cd..8abb6038e8b53 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -365,10 +365,11 @@ def query_obj(self) -> QueryObjectDict: # pylint: disable=too-many-locals metrics = self.all_metrics or [] groupby = self.dedup_columns(self.groupby, self.form_data.get("columns")) - groupby_labels = get_column_names(groupby) is_timeseries = self.is_timeseries - if DTTM_ALIAS in groupby_labels: + + # pylint: disable=superfluous-parens + if DTTM_ALIAS in (groupby_labels := get_column_names(groupby)): del groupby[groupby_labels.index(DTTM_ALIAS)] is_timeseries = True @@ -959,8 +960,7 @@ def query_obj(self) -> QueryObjectDict: if len(deduped_cols) < (len(groupby) + len(columns)): raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap")) - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) @@ -1077,8 +1077,7 @@ class TreemapViz(BaseViz): @deprecated(deprecated_in="3.0") def query_obj(self) -> QueryObjectDict: query_obj = super().query_obj() - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) @@ -1880,8 +1879,7 @@ def query_obj(self) -> QueryObjectDict: if not self.form_data.get("groupby"): raise QueryObjectValidationError(_("Pick at least one field for [Series]")) - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) @@ -2310,8 +2308,7 @@ class ParallelCoordinatesViz(BaseViz): def query_obj(self) -> QueryObjectDict: query_obj = super().query_obj() query_obj["groupby"] = [self.form_data.get("series")] - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) @@ -2679,8 +2676,7 @@ def add_null_filters(self) -> None: if self.form_data.get("adhoc_filters") is None: self.form_data["adhoc_filters"] = [] - line_column = self.form_data.get("line_column") - if line_column: + if line_column := self.form_data.get("line_column"): spatial_columns.add(line_column) for column in sorted(spatial_columns): @@ -2706,13 +2702,12 @@ def query_obj(self) -> QueryObjectDict: if self.form_data.get("js_columns"): group_by += self.form_data.get("js_columns") or [] - metrics = self.get_metrics() # Ensure this value is sorted so that it does not # cause the cache key generation (which hashes the # query object) to generate different keys for values # that should be considered the same. group_by = sorted(set(group_by)) - if metrics: + if metrics := self.get_metrics(): query_obj["groupby"] = group_by query_obj["metrics"] = metrics query_obj["columns"] = [] @@ -3097,8 +3092,7 @@ class PairedTTestViz(BaseViz): @deprecated(deprecated_in="3.0") def query_obj(self) -> QueryObjectDict: query_obj = super().query_obj() - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) diff --git a/tests/example_data/data_loading/pandas/pandas_data_loader.py b/tests/example_data/data_loading/pandas/pandas_data_loader.py index 00f3f775cafb3..7f41602054e18 100644 --- a/tests/example_data/data_loading/pandas/pandas_data_loader.py +++ b/tests/example_data/data_loading/pandas/pandas_data_loader.py @@ -67,8 +67,7 @@ def _detect_schema_name(self) -> Optional[str]: return inspect(self._db_engine).default_schema_name def _take_data_types(self, table: Table) -> Optional[Dict[str, str]]: - metadata_table = table.table_metadata - if metadata_table: + if metadata_table := table.table_metadata: types = metadata_table.types if types: return types diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py index 70f984775db59..91a76f97cf298 100644 --- a/tests/integration_tests/csv_upload_tests.py +++ b/tests/integration_tests/csv_upload_tests.py @@ -136,7 +136,6 @@ def upload_csv( dtype: Union[str, None] = None, ): csv_upload_db_id = get_upload_db().id - schema = utils.get_example_default_schema() form_data = { "csv_file": open(filename, "rb"), "delimiter": ",", @@ -146,7 +145,7 @@ def upload_csv( "index_label": "test_label", "overwrite_duplicate": False, } - if schema: + if schema := utils.get_example_default_schema(): form_data["schema"] = schema if extra: form_data.update(extra) @@ -159,7 +158,6 @@ def upload_excel( filename: str, table_name: str, extra: Optional[Dict[str, str]] = None ): excel_upload_db_id = get_upload_db().id - schema = utils.get_example_default_schema() form_data = { "excel_file": open(filename, "rb"), "name": table_name, @@ -169,7 +167,7 @@ def upload_excel( "index_label": "test_label", "mangle_dupe_cols": False, } - if schema: + if schema := utils.get_example_default_schema(): form_data["schema"] = schema if extra: form_data.update(extra) @@ -180,7 +178,6 @@ def upload_columnar( filename: str, table_name: str, extra: Optional[Dict[str, str]] = None ): columnar_upload_db_id = get_upload_db().id - schema = utils.get_example_default_schema() form_data = { "columnar_file": open(filename, "rb"), "name": table_name, @@ -188,7 +185,7 @@ def upload_columnar( "if_exists": "fail", "index_label": "test_label", } - if schema: + if schema := utils.get_example_default_schema(): form_data["schema"] = schema if extra: form_data.update(extra) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 055cf4779eb84..f87bafcd376af 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -641,14 +641,13 @@ def test_create_dataset_validate_uniqueness(self): if backend() == "sqlite": return - schema = get_example_default_schema() energy_usage_ds = self.get_energy_usage_dataset() self.login(username="admin") table_data = { "database": energy_usage_ds.database_id, "table_name": energy_usage_ds.table_name, } - if schema: + if schema := get_example_default_schema(): table_data["schema"] = schema rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") assert rv.status_code == 422 @@ -665,7 +664,6 @@ def test_create_dataset_with_sql_validate_uniqueness(self): if backend() == "sqlite": return - schema = get_example_default_schema() energy_usage_ds = self.get_energy_usage_dataset() self.login(username="admin") table_data = { @@ -673,7 +671,7 @@ def test_create_dataset_with_sql_validate_uniqueness(self): "table_name": energy_usage_ds.table_name, "sql": "select * from energy_usage", } - if schema: + if schema := get_example_default_schema(): table_data["schema"] = schema rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") assert rv.status_code == 422 @@ -690,7 +688,6 @@ def test_create_dataset_with_sql(self): if backend() == "sqlite": return - schema = get_example_default_schema() energy_usage_ds = self.get_energy_usage_dataset() self.login(username="alpha") admin = self.get_user("admin") @@ -701,7 +698,7 @@ def test_create_dataset_with_sql(self): "sql": "select * from energy_usage", "owners": [admin.id], } - if schema: + if schema := get_example_default_schema(): table_data["schema"] = schema rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") assert rv.status_code == 201 From d0687d04eb0365da34e937c37f9c2cd079bed415 Mon Sep 17 00:00:00 2001 From: Vitali Logvin <82170648+vitoldi@users.noreply.github.com> Date: Fri, 19 May 2023 13:18:16 +0300 Subject: [PATCH 2/7] feat: dashboard page xlsx export (#24005) Co-authored-by: Vitali Logvin --- .../components/SliceHeader/index.tsx | 2 ++ .../SliceHeaderControls.test.tsx | 13 +++++++++ .../components/SliceHeaderControls/index.tsx | 12 +++++++++ .../components/gridComponents/Chart.jsx | 27 ++++++++++++++----- .../components/gridComponents/Chart.test.jsx | 17 ++++++++++++ superset-frontend/src/logger/LogUtils.ts | 2 ++ 6 files changed, 67 insertions(+), 6 deletions(-) diff --git a/superset-frontend/src/dashboard/components/SliceHeader/index.tsx b/superset-frontend/src/dashboard/components/SliceHeader/index.tsx index 3ee67e7f3868d..b4a706c70f98c 100644 --- a/superset-frontend/src/dashboard/components/SliceHeader/index.tsx +++ b/superset-frontend/src/dashboard/components/SliceHeader/index.tsx @@ -135,6 +135,7 @@ const SliceHeader: FC = ({ logExploreChart = () => ({}), logEvent, exportCSV = () => ({}), + exportXLSX = () => ({}), editMode = false, annotationQuery = {}, annotationError = {}, @@ -264,6 +265,7 @@ const SliceHeader: FC = ({ logEvent={logEvent} exportCSV={exportCSV} exportFullCSV={exportFullCSV} + exportXLSX={exportXLSX} supersetCanExplore={supersetCanExplore} supersetCanShare={supersetCanShare} supersetCanCSV={supersetCanCSV} diff --git a/superset-frontend/src/dashboard/components/SliceHeaderControls/SliceHeaderControls.test.tsx b/superset-frontend/src/dashboard/components/SliceHeaderControls/SliceHeaderControls.test.tsx index 26f3a6bbbf8e2..d1a33da259b07 100644 --- a/superset-frontend/src/dashboard/components/SliceHeaderControls/SliceHeaderControls.test.tsx +++ b/superset-frontend/src/dashboard/components/SliceHeaderControls/SliceHeaderControls.test.tsx @@ -44,6 +44,7 @@ const createProps = (viz_type = 'sunburst') => exploreChart: jest.fn(), exportCSV: jest.fn(), exportFullCSV: jest.fn(), + exportXLSX: jest.fn(), forceRefresh: jest.fn(), handleToggleFullSize: jest.fn(), toggleExpandSlice: jest.fn(), @@ -126,6 +127,8 @@ test('Should render default props', () => { // @ts-ignore delete props.exportCSV; // @ts-ignore + delete props.exportXLSX; + // @ts-ignore delete props.cachedDttm; // @ts-ignore delete props.updatedDttm; @@ -170,6 +173,16 @@ test('Should "export to CSV"', async () => { expect(props.exportCSV).toBeCalledWith(371); }); +test('Should "export to Excel"', async () => { + const props = createProps(); + renderWrapper(props); + expect(props.exportXLSX).toBeCalledTimes(0); + userEvent.hover(screen.getByText('Download')); + userEvent.click(await screen.findByText('Export to Excel')); + expect(props.exportXLSX).toBeCalledTimes(1); + expect(props.exportXLSX).toBeCalledWith(371); +}); + test('Should not show "Download" if slice is filter box', () => { const props = createProps('filter_box'); renderWrapper(props); diff --git a/superset-frontend/src/dashboard/components/SliceHeaderControls/index.tsx b/superset-frontend/src/dashboard/components/SliceHeaderControls/index.tsx index bba888710a54c..72b0fb1aa0788 100644 --- a/superset-frontend/src/dashboard/components/SliceHeaderControls/index.tsx +++ b/superset-frontend/src/dashboard/components/SliceHeaderControls/index.tsx @@ -64,6 +64,7 @@ const MENU_KEYS = { EXPLORE_CHART: 'explore_chart', EXPORT_CSV: 'export_csv', EXPORT_FULL_CSV: 'export_full_csv', + EXPORT_XLSX: 'export_xlsx', FORCE_REFRESH: 'force_refresh', FULLSCREEN: 'fullscreen', TOGGLE_CHART_DESCRIPTION: 'toggle_chart_description', @@ -144,6 +145,7 @@ export interface SliceHeaderControlsProps { toggleExpandSlice?: (sliceId: number) => void; exportCSV?: (sliceId: number) => void; exportFullCSV?: (sliceId: number) => void; + exportXLSX?: (sliceId: number) => void; handleToggleFullSize: () => void; addDangerToast: (message: string) => void; @@ -294,6 +296,10 @@ const SliceHeaderControls = (props: SliceHeaderControlsPropsWithRouter) => { // eslint-disable-next-line no-unused-expressions props.exportFullCSV?.(props.slice.slice_id); break; + case MENU_KEYS.EXPORT_XLSX: + // eslint-disable-next-line no-unused-expressions + props.exportXLSX?.(props.slice.slice_id); + break; case MENU_KEYS.DOWNLOAD_AS_IMAGE: { // menu closes with a delay, we need to hide it manually, // so that we don't capture it on the screenshot @@ -493,6 +499,12 @@ const SliceHeaderControls = (props: SliceHeaderControlsPropsWithRouter) => { )} + } + > + {t('Export to Excel')} + } diff --git a/superset-frontend/src/dashboard/components/gridComponents/Chart.jsx b/superset-frontend/src/dashboard/components/gridComponents/Chart.jsx index 0f151ff4f8d6a..b16bd63f188e1 100644 --- a/superset-frontend/src/dashboard/components/gridComponents/Chart.jsx +++ b/superset-frontend/src/dashboard/components/gridComponents/Chart.jsx @@ -29,6 +29,7 @@ import { LOG_ACTIONS_CHANGE_DASHBOARD_FILTER, LOG_ACTIONS_EXPLORE_DASHBOARD_CHART, LOG_ACTIONS_EXPORT_CSV_DASHBOARD_CHART, + LOG_ACTIONS_EXPORT_XLSX_DASHBOARD_CHART, LOG_ACTIONS_FORCE_REFRESH_CHART, } from 'src/logger/LogUtils'; import { areObjectsEqual } from 'src/reduxUtils'; @@ -139,6 +140,7 @@ class Chart extends React.Component { this.handleFilterMenuClose = this.handleFilterMenuClose.bind(this); this.exportCSV = this.exportCSV.bind(this); this.exportFullCSV = this.exportFullCSV.bind(this); + this.exportXLSX = this.exportXLSX.bind(this); this.forceRefresh = this.forceRefresh.bind(this); this.resize = this.resize.bind(this); this.setDescriptionRef = this.setDescriptionRef.bind(this); @@ -324,8 +326,24 @@ class Chart extends React.Component { } }; + exportFullCSV() { + this.exportCSV(true); + } + exportCSV(isFullCSV = false) { - this.props.logEvent(LOG_ACTIONS_EXPORT_CSV_DASHBOARD_CHART, { + this.exportTable('csv', isFullCSV); + } + + exportXLSX() { + this.exportTable('xlsx', false); + } + + exportTable(format, isFullCSV) { + const logAction = + format === 'csv' + ? LOG_ACTIONS_EXPORT_CSV_DASHBOARD_CHART + : LOG_ACTIONS_EXPORT_XLSX_DASHBOARD_CHART; + this.props.logEvent(logAction, { slice_id: this.props.slice.slice_id, is_cached: this.props.isCached, }); @@ -334,16 +352,12 @@ class Chart extends React.Component { ? { ...this.props.formData, row_limit: this.props.maxRows } : this.props.formData, resultType: 'full', - resultFormat: 'csv', + resultFormat: format, force: true, ownState: this.props.ownState, }); } - exportFullCSV() { - this.exportCSV(true); - } - forceRefresh() { this.props.logEvent(LOG_ACTIONS_FORCE_REFRESH_CHART, { slice_id: this.props.slice.slice_id, @@ -437,6 +451,7 @@ class Chart extends React.Component { logEvent={logEvent} onExploreChart={this.onExploreChart} exportCSV={this.exportCSV} + exportXLSX={this.exportXLSX} exportFullCSV={this.exportFullCSV} updateSliceName={updateSliceName} sliceName={sliceName} diff --git a/superset-frontend/src/dashboard/components/gridComponents/Chart.test.jsx b/superset-frontend/src/dashboard/components/gridComponents/Chart.test.jsx index a3851a73b3e94..c892f7fff58de 100644 --- a/superset-frontend/src/dashboard/components/gridComponents/Chart.test.jsx +++ b/superset-frontend/src/dashboard/components/gridComponents/Chart.test.jsx @@ -62,6 +62,7 @@ describe('Chart', () => { addDangerToast() {}, exportCSV() {}, exportFullCSV() {}, + exportXLSX() {}, componentId: 'test', dashboardId: 111, editMode: false, @@ -145,4 +146,20 @@ describe('Chart', () => { expect(stubbedExportCSV.lastCall.args[0].formData.row_limit).toEqual(666); exploreUtils.exportChart.restore(); }); + it('should call exportChart when exportXLSX is clicked', () => { + const stubbedExportXLSX = sinon + .stub(exploreUtils, 'exportChart') + .returns(() => {}); + const wrapper = setup(); + wrapper.instance().exportXLSX(props.slice.sliceId); + expect(stubbedExportXLSX.calledOnce).toBe(true); + expect(stubbedExportXLSX.lastCall.args[0]).toEqual( + expect.objectContaining({ + formData: expect.anything(), + resultType: 'full', + resultFormat: 'xlsx', + }), + ); + exploreUtils.exportChart.restore(); + }); }); diff --git a/superset-frontend/src/logger/LogUtils.ts b/superset-frontend/src/logger/LogUtils.ts index cf5580c7bdaa3..258b5dbb5eea9 100644 --- a/superset-frontend/src/logger/LogUtils.ts +++ b/superset-frontend/src/logger/LogUtils.ts @@ -34,6 +34,8 @@ export const LOG_ACTIONS_PERIODIC_RENDER_DASHBOARD = export const LOG_ACTIONS_EXPLORE_DASHBOARD_CHART = 'explore_dashboard_chart'; export const LOG_ACTIONS_EXPORT_CSV_DASHBOARD_CHART = 'export_csv_dashboard_chart'; +export const LOG_ACTIONS_EXPORT_XLSX_DASHBOARD_CHART = + 'export_csv_dashboard_chart'; export const LOG_ACTIONS_CHANGE_DASHBOARD_FILTER = 'change_dashboard_filter'; export const LOG_ACTIONS_DATASET_CREATION_EMPTY_CANCELLATION = 'dataset_creation_empty_cancellation'; From f817c10422a74edb49858150ea5dae48499d5ef7 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Fri, 19 May 2023 16:29:11 +0300 Subject: [PATCH 3/7] fix(plugin-chart-echarts): normalize temporal string groupbys (#24134) --- .../plugin-chart-echarts/src/utils/series.ts | 8 +- .../test/utils/series.test.ts | 378 +++++++++--------- 2 files changed, 199 insertions(+), 187 deletions(-) diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/utils/series.ts b/superset-frontend/plugins/plugin-chart-echarts/src/utils/series.ts index 41554347d4da1..6247813742b26 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/utils/series.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/src/utils/series.ts @@ -29,6 +29,7 @@ import { NumberFormatter, TimeFormatter, SupersetTheme, + normalizeTimestamp, } from '@superset-ui/core'; import { SortSeriesType } from '@superset-ui/chart-controls'; import { format, LegendComponentOption, SeriesOption } from 'echarts'; @@ -336,7 +337,12 @@ export function formatSeriesName( return name.toString(); } if (name instanceof Date || coltype === GenericDataType.TEMPORAL) { - const d = name instanceof Date ? name : new Date(name); + const normalizedName = + typeof name === 'string' ? normalizeTimestamp(name) : name; + const d = + normalizedName instanceof Date + ? normalizedName + : new Date(normalizedName); return timeFormatter ? timeFormatter(d) : d.toISOString(); } diff --git a/superset-frontend/plugins/plugin-chart-echarts/test/utils/series.test.ts b/superset-frontend/plugins/plugin-chart-echarts/test/utils/series.test.ts index 8a2229fbe0580..c2de493c51a72 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/test/utils/series.test.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/test/utils/series.test.ts @@ -19,6 +19,7 @@ import { SortSeriesType } from '@superset-ui/chart-controls'; import { DataRecord, + GenericDataType, getNumberFormatter, getTimeFormatter, supersetTheme as theme, @@ -628,226 +629,231 @@ describe('formatSeriesName', () => { ); }); - describe('getLegendProps', () => { - it('should return the correct props for scroll type with top orientation without zoom', () => { - expect( - getLegendProps( - LegendType.Scroll, - LegendOrientation.Top, - true, - theme, - false, - ), - ).toEqual({ - show: true, - top: 0, - right: 0, - orient: 'horizontal', - type: 'scroll', - ...expectedThemeProps, - }); + it('should normalize non-UTC string based timestamp', () => { + const annualTimeFormatter = getTimeFormatter('%Y'); + expect( + formatSeriesName('1995-01-01 00:00:00.000000', { + timeFormatter: annualTimeFormatter, + coltype: GenericDataType.TEMPORAL, + }), + ).toEqual('1995'); + }); +}); + +describe('getLegendProps', () => { + it('should return the correct props for scroll type with top orientation without zoom', () => { + expect( + getLegendProps( + LegendType.Scroll, + LegendOrientation.Top, + true, + theme, + false, + ), + ).toEqual({ + show: true, + top: 0, + right: 0, + orient: 'horizontal', + type: 'scroll', + ...expectedThemeProps, }); + }); - it('should return the correct props for scroll type with top orientation with zoom', () => { - expect( - getLegendProps( - LegendType.Scroll, - LegendOrientation.Top, - true, - theme, - true, - ), - ).toEqual({ - show: true, - top: 0, - right: 55, - orient: 'horizontal', - type: 'scroll', - ...expectedThemeProps, - }); + it('should return the correct props for scroll type with top orientation with zoom', () => { + expect( + getLegendProps( + LegendType.Scroll, + LegendOrientation.Top, + true, + theme, + true, + ), + ).toEqual({ + show: true, + top: 0, + right: 55, + orient: 'horizontal', + type: 'scroll', + ...expectedThemeProps, }); + }); - it('should return the correct props for plain type with left orientation', () => { - expect( - getLegendProps(LegendType.Plain, LegendOrientation.Left, true, theme), - ).toEqual({ - show: true, - left: 0, - orient: 'vertical', - type: 'plain', - ...expectedThemeProps, - }); + it('should return the correct props for plain type with left orientation', () => { + expect( + getLegendProps(LegendType.Plain, LegendOrientation.Left, true, theme), + ).toEqual({ + show: true, + left: 0, + orient: 'vertical', + type: 'plain', + ...expectedThemeProps, }); + }); - it('should return the correct props for plain type with right orientation without zoom', () => { - expect( - getLegendProps( - LegendType.Plain, - LegendOrientation.Right, - false, - theme, - false, - ), - ).toEqual({ - show: false, - right: 0, - top: 0, - orient: 'vertical', - type: 'plain', - ...expectedThemeProps, - }); + it('should return the correct props for plain type with right orientation without zoom', () => { + expect( + getLegendProps( + LegendType.Plain, + LegendOrientation.Right, + false, + theme, + false, + ), + ).toEqual({ + show: false, + right: 0, + top: 0, + orient: 'vertical', + type: 'plain', + ...expectedThemeProps, }); + }); - it('should return the correct props for plain type with right orientation with zoom', () => { - expect( - getLegendProps( - LegendType.Plain, - LegendOrientation.Right, - false, - theme, - true, - ), - ).toEqual({ - show: false, - right: 0, - top: 30, - orient: 'vertical', - type: 'plain', - ...expectedThemeProps, - }); + it('should return the correct props for plain type with right orientation with zoom', () => { + expect( + getLegendProps( + LegendType.Plain, + LegendOrientation.Right, + false, + theme, + true, + ), + ).toEqual({ + show: false, + right: 0, + top: 30, + orient: 'vertical', + type: 'plain', + ...expectedThemeProps, }); + }); - it('should return the correct props for plain type with bottom orientation', () => { - expect( - getLegendProps( - LegendType.Plain, - LegendOrientation.Bottom, - false, - theme, - ), - ).toEqual({ - show: false, - bottom: 0, - orient: 'horizontal', - type: 'plain', - ...expectedThemeProps, - }); + it('should return the correct props for plain type with bottom orientation', () => { + expect( + getLegendProps(LegendType.Plain, LegendOrientation.Bottom, false, theme), + ).toEqual({ + show: false, + bottom: 0, + orient: 'horizontal', + type: 'plain', + ...expectedThemeProps, }); }); +}); - describe('getChartPadding', () => { - it('should handle top default', () => { - expect(getChartPadding(true, LegendOrientation.Top)).toEqual({ - bottom: 0, - left: 0, - right: 0, - top: defaultLegendPadding[LegendOrientation.Top], - }); +describe('getChartPadding', () => { + it('should handle top default', () => { + expect(getChartPadding(true, LegendOrientation.Top)).toEqual({ + bottom: 0, + left: 0, + right: 0, + top: defaultLegendPadding[LegendOrientation.Top], }); + }); - it('should handle left default', () => { - expect(getChartPadding(true, LegendOrientation.Left)).toEqual({ - bottom: 0, - left: defaultLegendPadding[LegendOrientation.Left], - right: 0, - top: 0, - }); + it('should handle left default', () => { + expect(getChartPadding(true, LegendOrientation.Left)).toEqual({ + bottom: 0, + left: defaultLegendPadding[LegendOrientation.Left], + right: 0, + top: 0, }); + }); - it('should return the default padding when show is false', () => { - expect( - getChartPadding(false, LegendOrientation.Left, 100, { - top: 10, - bottom: 20, - left: 30, - right: 40, - }), - ).toEqual({ + it('should return the default padding when show is false', () => { + expect( + getChartPadding(false, LegendOrientation.Left, 100, { + top: 10, bottom: 20, left: 30, right: 40, - top: 10, - }); + }), + ).toEqual({ + bottom: 20, + left: 30, + right: 40, + top: 10, }); + }); - it('should return the correct padding for left orientation', () => { - expect(getChartPadding(true, LegendOrientation.Left, 100)).toEqual({ - bottom: 0, - left: 100, - right: 0, - top: 0, - }); + it('should return the correct padding for left orientation', () => { + expect(getChartPadding(true, LegendOrientation.Left, 100)).toEqual({ + bottom: 0, + left: 100, + right: 0, + top: 0, }); + }); - it('should return the correct padding for right orientation', () => { - expect(getChartPadding(true, LegendOrientation.Right, 50)).toEqual({ - bottom: 0, - left: 0, - right: 50, - top: 0, - }); + it('should return the correct padding for right orientation', () => { + expect(getChartPadding(true, LegendOrientation.Right, 50)).toEqual({ + bottom: 0, + left: 0, + right: 50, + top: 0, }); + }); - it('should return the correct padding for top orientation', () => { - expect(getChartPadding(true, LegendOrientation.Top, 20)).toEqual({ - bottom: 0, - left: 0, - right: 0, - top: 20, - }); + it('should return the correct padding for top orientation', () => { + expect(getChartPadding(true, LegendOrientation.Top, 20)).toEqual({ + bottom: 0, + left: 0, + right: 0, + top: 20, }); + }); - it('should return the correct padding for bottom orientation', () => { - expect(getChartPadding(true, LegendOrientation.Bottom, 10)).toEqual({ - bottom: 10, - left: 0, - right: 0, - top: 0, - }); + it('should return the correct padding for bottom orientation', () => { + expect(getChartPadding(true, LegendOrientation.Bottom, 10)).toEqual({ + bottom: 10, + left: 0, + right: 0, + top: 0, }); }); +}); - describe('dedupSeries', () => { - it('should deduplicate ids in series', () => { - expect( - dedupSeries([ - { - id: 'foo', - }, - { - id: 'bar', - }, - { - id: 'foo', - }, - { - id: 'foo', - }, - ]), - ).toEqual([ - { id: 'foo' }, - { id: 'bar' }, - { id: 'foo (1)' }, - { id: 'foo (2)' }, - ]); - }); +describe('dedupSeries', () => { + it('should deduplicate ids in series', () => { + expect( + dedupSeries([ + { + id: 'foo', + }, + { + id: 'bar', + }, + { + id: 'foo', + }, + { + id: 'foo', + }, + ]), + ).toEqual([ + { id: 'foo' }, + { id: 'bar' }, + { id: 'foo (1)' }, + { id: 'foo (2)' }, + ]); }); +}); - describe('sanitizeHtml', () => { - it('should remove html tags from series name', () => { - expect(sanitizeHtml(NULL_STRING)).toEqual('<NULL>'); - }); +describe('sanitizeHtml', () => { + it('should remove html tags from series name', () => { + expect(sanitizeHtml(NULL_STRING)).toEqual('<NULL>'); }); +}); - describe('getOverMaxHiddenFormatter', () => { - it('should hide value if greater than max', () => { - const formatter = getOverMaxHiddenFormatter({ max: 81000 }); - expect(formatter.format(84500)).toEqual(''); - }); - it('should show value if less or equal than max', () => { - const formatter = getOverMaxHiddenFormatter({ max: 81000 }); - expect(formatter.format(81000)).toEqual('81000'); - expect(formatter.format(50000)).toEqual('50000'); - }); +describe('getOverMaxHiddenFormatter', () => { + it('should hide value if greater than max', () => { + const formatter = getOverMaxHiddenFormatter({ max: 81000 }); + expect(formatter.format(84500)).toEqual(''); + }); + it('should show value if less or equal than max', () => { + const formatter = getOverMaxHiddenFormatter({ max: 81000 }); + expect(formatter.format(81000)).toEqual('81000'); + expect(formatter.format(50000)).toEqual('50000'); }); }); From ba0bb20be54b7bfd2cfa6054e465c04a20726cff Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 19 May 2023 06:40:38 -0700 Subject: [PATCH 4/7] fix: Revert tox basepython (#24124) --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index c68e2919d0be8..c63fa947df6f4 100644 --- a/tox.ini +++ b/tox.ini @@ -18,7 +18,7 @@ # Remember to start celery workers to run celery tests, e.g. # celery --app=superset.tasks.celery_app:app worker -Ofair -c 2 [testenv] -basepython = python3.10 +basepython = python3.9 ignore_basepython_conflict = true commands = superset db upgrade From 488ec02e701c7b254e5114f942b77052b898791a Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 19 May 2023 06:41:17 -0700 Subject: [PATCH 5/7] chore(report): Use for/else clause (#24107) --- superset/reports/commands/execute.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index 61f72d4790ef4..f5f7bf4130254 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -670,7 +670,6 @@ def __init__( self._scheduled_dttm = scheduled_dttm def run(self) -> None: - state_found = False for state_cls in self.states_cls: if (self._report_schedule.last_state is None and state_cls.initial) or ( self._report_schedule.last_state in state_cls.current_states @@ -681,9 +680,8 @@ def run(self) -> None: self._scheduled_dttm, self._execution_id, ).next() - state_found = True break - if not state_found: + else: raise ReportScheduleStateNotFoundError() From 0496779434d166bf45767912dd49503323175cbe Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 19 May 2023 06:42:19 -0700 Subject: [PATCH 6/7] chore: Update QUERY_LOGGER and SQL_QUERY_MUTATOR signatures (#24029) --- UPDATING.md | 2 +- superset/config.py | 2 -- superset/connectors/sqla/models.py | 4 +--- superset/db_engine_specs/base.py | 3 +-- superset/models/core.py | 4 ---- superset/models/helpers.py | 1 - superset/sql_lab.py | 3 --- superset/sql_validators/presto_db.py | 3 +-- 8 files changed, 4 insertions(+), 18 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index 36f3645146ab7..7ecb299fc5285 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -24,7 +24,6 @@ assists people when migrating to a new version. ## Next -- [23785](https://github.com/apache/superset/pull/23785) Deprecated the following feature flags: `CLIENT_CACHE`, `DASHBOARD_CACHE`, `DASHBOARD_FILTERS_EXPERIMENTAL`, `DASHBOARD_NATIVE_FILTERS`, `DASHBOARD_NATIVE_FILTERS_SET`, `DISABLE_DATASET_SOURCE_EDIT`, `ENABLE_EXPLORE_JSON_CSRF_PROTECTION`, `REMOVE_SLICE_LEVEL_LABEL_COLORS`. It also removed `DASHBOARD_EDIT_CHART_IN_NEW_TAB` as the feature is supported without the need for a feature flag. - [23652](https://github.com/apache/superset/pull/23652) Enables GENERIC_CHART_AXES feature flag by default. - [23226](https://github.com/apache/superset/pull/23226) Migrated endpoint `/estimate_query_cost/` to `/api/v1/sqllab/estimate/`. Corresponding permissions are can estimate query cost on SQLLab. Make sure you add/replace the necessary permissions on any custom roles you may have. - [22809](https://github.com/apache/superset/pull/22809): Migrated endpoint `/superset/sql_json` and `/superset/results/` to `/api/v1/sqllab/execute/` and `/api/v1/sqllab/results/` respectively. Corresponding permissions are `can sql_json on Superset` to `can execute on SQLLab`, `can results on Superset` to `can results on SQLLab`. Make sure you add/replace the necessary permissions on any custom roles you may have. @@ -50,6 +49,7 @@ assists people when migrating to a new version. ### Breaking Changes +- [23785](https://github.com/apache/superset/pull/23785) Deprecated the following feature flags: `CLIENT_CACHE`, `DASHBOARD_CACHE`, `DASHBOARD_FILTERS_EXPERIMENTAL`, `DASHBOARD_NATIVE_FILTERS`, `DASHBOARD_NATIVE_FILTERS_SET`, `DISABLE_DATASET_SOURCE_EDIT`, `ENABLE_EXPLORE_JSON_CSRF_PROTECTION`, `REMOVE_SLICE_LEVEL_LABEL_COLORS`. It also removed `DASHBOARD_EDIT_CHART_IN_NEW_TAB` as the feature is supported without the need for a feature flag. - [22801](https://github.com/apache/superset/pull/22801): The Thumbnails feature has been changed to execute as the currently logged in user by default, falling back to the selenium user for anonymous users. To continue always using the selenium user, please add the following to your `superset_config.py`: `THUMBNAILS_EXECUTE_AS = ["selenium"]` - [22799](https://github.com/apache/superset/pull/22799): Alerts & Reports has been changed to execute as the owner of the alert/report by default, giving priority to the last modifier and then the creator if either is contained within the list of owners, otherwise the first owner will be used. To continue using the selenium user, please add the following to your `superset_config.py`: `ALERT_REPORTS_EXECUTE_AS = ["selenium"]` - [23651](https://github.com/apache/superset/pull/23651): Removes UX_BETA feature flag. diff --git a/superset/config.py b/superset/config.py index afb6504ee069b..63cabf1c45a4c 100644 --- a/superset/config.py +++ b/superset/config.py @@ -834,7 +834,6 @@ class D3Format(TypedDict, total=False): # database, # query, # schema=None, -# user=None, # TODO(john-bodley): Deprecate in 3.0. # client=None, # security_manager=None, # log_params=None, @@ -1188,7 +1187,6 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # # def SQL_QUERY_MUTATOR( # sql, -# user_name=user_name, # TODO(john-bodley): Deprecate in 3.0. # security_manager=security_manager, # database=database, # ): diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 5f487f60f6aa0..8833d6f6cb561 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -117,7 +117,7 @@ from superset.sql_parse import ParsedQuery, sanitize_clause from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict from superset.utils import core as utils -from superset.utils.core import GenericDataType, get_username, MediumText +from superset.utils.core import GenericDataType, MediumText config = app.config metadata = Model.metadata # pylint: disable=no-member @@ -829,8 +829,6 @@ def mutate_query_from_config(self, sql: str) -> str: if sql_query_mutator and not mutate_after_split: sql = sql_query_mutator( sql, - # TODO(john-bodley): Deprecate in 3.0. - user_name=get_username(), security_manager=security_manager, database=self.database, ) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 221872f544835..1b46bb35bb47c 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -68,7 +68,7 @@ from superset.sql_parse import ParsedQuery, Table from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils -from superset.utils.core import ColumnSpec, GenericDataType, get_username +from superset.utils.core import ColumnSpec, GenericDataType from superset.utils.hashing import md5_sha_from_str from superset.utils.network import is_hostname_valid, is_port_open @@ -1393,7 +1393,6 @@ def process_statement(cls, statement: str, database: Database) -> str: if sql_query_mutator and not mutate_after_split: sql = sql_query_mutator( sql, - user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. security_manager=security_manager, database=database, ) diff --git a/superset/models/core.py b/superset/models/core.py index 592207faba05a..ee50f063456f7 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -551,7 +551,6 @@ def get_df( # pylint: disable=too-many-locals ) -> pd.DataFrame: sqls = self.db_engine_spec.parse_sql(sql) engine = self._get_sqla_engine(schema) - username = utils.get_username() mutate_after_split = config["MUTATE_AFTER_SPLIT"] sql_query_mutator = config["SQL_QUERY_MUTATOR"] @@ -568,7 +567,6 @@ def _log_query(sql: str) -> None: engine.url, sql, schema, - get_username(), __name__, security_manager, ) @@ -579,7 +577,6 @@ def _log_query(sql: str) -> None: if mutate_after_split: sql_ = sql_query_mutator( sql_, - user_name=username, security_manager=security_manager, database=None, ) @@ -590,7 +587,6 @@ def _log_query(sql: str) -> None: if mutate_after_split: last_sql = sql_query_mutator( sqls[-1], - user_name=username, security_manager=security_manager, database=None, ) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 558ad15fc9afe..32c6f5ff6ab4a 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -883,7 +883,6 @@ def mutate_query_from_config(self, sql: str) -> str: if sql_query_mutator := config["SQL_QUERY_MUTATOR"]: sql = sql_query_mutator( sql, - user_name=utils.get_username(), # TODO(john-bodley): Deprecate in 3.0. security_manager=security_manager, database=self.database, ) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 0f373a3514e64..5cb52d4d1cc3f 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -53,7 +53,6 @@ from superset.sqllab.limiting_factor import LimitingFactor from superset.utils.celery import session_scope from superset.utils.core import ( - get_username, json_iso_dttm_ser, override_user, QuerySource, @@ -255,7 +254,6 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem # Hook to allow environment-specific mutation (usually comments) to the SQL sql = SQL_QUERY_MUTATOR( sql, - user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. security_manager=security_manager, database=database, ) @@ -266,7 +264,6 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem query.database.sqlalchemy_uri, query.executed_sql, query.schema, - get_username(), __name__, security_manager, log_params, diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 10ef1fc1e13c9..c5ecf4c96e20c 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -24,7 +24,7 @@ from superset.models.core import Database from superset.sql_parse import ParsedQuery from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation -from superset.utils.core import get_username, QuerySource +from superset.utils.core import QuerySource MAX_ERROR_ROWS = 10 @@ -57,7 +57,6 @@ def validate_statement( if sql_query_mutator := config["SQL_QUERY_MUTATOR"]: sql = sql_query_mutator( sql, - user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. security_manager=security_manager, database=database, ) From 8b4222ff9ec5bc4c9afb3bdb8a12f55720d5357c Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 19 May 2023 09:08:12 -0700 Subject: [PATCH 7/7] chore: Refactor command exceptions (#24117) --- .../annotations/commands/create.py | 4 +--- .../annotations/commands/update.py | 4 +--- superset/annotation_layers/commands/create.py | 4 +--- superset/annotation_layers/commands/update.py | 4 +--- superset/charts/commands/create.py | 4 +--- superset/charts/commands/update.py | 4 +--- superset/commands/exceptions.py | 20 +++++++++++-------- superset/commands/importers/v1/__init__.py | 7 ++++--- superset/commands/importers/v1/assets.py | 7 ++++--- superset/dashboards/commands/create.py | 8 ++------ superset/dashboards/commands/update.py | 8 ++------ superset/databases/commands/create.py | 2 +- superset/databases/commands/update.py | 4 +--- .../databases/ssh_tunnel/commands/create.py | 2 +- .../databases/ssh_tunnel/commands/update.py | 6 +++--- superset/datasets/commands/create.py | 4 +--- superset/datasets/commands/duplicate.py | 4 +--- superset/datasets/commands/update.py | 4 +--- superset/reports/commands/create.py | 4 +--- superset/reports/commands/update.py | 4 +--- superset/tags/commands/create.py | 4 +--- superset/tags/commands/delete.py | 8 ++------ 22 files changed, 44 insertions(+), 76 deletions(-) diff --git a/superset/annotation_layers/annotations/commands/create.py b/superset/annotation_layers/annotations/commands/create.py index 26cd968c5a1f4..0974624561142 100644 --- a/superset/annotation_layers/annotations/commands/create.py +++ b/superset/annotation_layers/annotations/commands/create.py @@ -73,6 +73,4 @@ def validate(self) -> None: exceptions.append(AnnotationDatesValidationError()) if exceptions: - exception = AnnotationInvalidError() - exception.add_list(exceptions) - raise exception + raise AnnotationInvalidError(exceptions=exceptions) diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py index c55a1cdaf768e..b644ddc3622d9 100644 --- a/superset/annotation_layers/annotations/commands/update.py +++ b/superset/annotation_layers/annotations/commands/update.py @@ -87,6 +87,4 @@ def validate(self) -> None: exceptions.append(AnnotationDatesValidationError()) if exceptions: - exception = AnnotationInvalidError() - exception.add_list(exceptions) - raise exception + raise AnnotationInvalidError(exceptions=exceptions) diff --git a/superset/annotation_layers/commands/create.py b/superset/annotation_layers/commands/create.py index d5af6c24a292a..97431568a9a68 100644 --- a/superset/annotation_layers/commands/create.py +++ b/superset/annotation_layers/commands/create.py @@ -54,6 +54,4 @@ def validate(self) -> None: exceptions.append(AnnotationLayerNameUniquenessValidationError()) if exceptions: - exception = AnnotationLayerInvalidError() - exception.add_list(exceptions) - raise exception + raise AnnotationLayerInvalidError(exceptions=exceptions) diff --git a/superset/annotation_layers/commands/update.py b/superset/annotation_layers/commands/update.py index f4a04cdeb703b..4a9cc31be5f8d 100644 --- a/superset/annotation_layers/commands/update.py +++ b/superset/annotation_layers/commands/update.py @@ -63,6 +63,4 @@ def validate(self) -> None: exceptions.append(AnnotationLayerNameUniquenessValidationError()) if exceptions: - exception = AnnotationLayerInvalidError() - exception.add_list(exceptions) - raise exception + raise AnnotationLayerInvalidError(exceptions=exceptions) diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py index 8238340794478..38076fb9cde8f 100644 --- a/superset/charts/commands/create.py +++ b/superset/charts/commands/create.py @@ -77,6 +77,4 @@ def validate(self) -> None: except ValidationError as ex: exceptions.append(ex) if exceptions: - exception = ChartInvalidError() - exception.add_list(exceptions) - raise exception + raise ChartInvalidError(exceptions=exceptions) diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index 042c85a930f93..f5fc2616a5a3c 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -114,6 +114,4 @@ def validate(self) -> None: self._properties["dashboards"] = dashboards if exceptions: - exception = ChartInvalidError() - exception.add_list(exceptions) - raise exception + raise ChartInvalidError(exceptions=exceptions) diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index a661ef4d6047d..db9d1b6c63916 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -56,22 +56,26 @@ class CommandInvalidError(CommandException): status = 422 - def __init__(self, message: str = "") -> None: - self._invalid_exceptions: List[ValidationError] = [] + def __init__( + self, + message: str = "", + exceptions: Optional[List[ValidationError]] = None, + ) -> None: + self._exceptions = exceptions or [] super().__init__(message) - def add(self, exception: ValidationError) -> None: - self._invalid_exceptions.append(exception) + def append(self, exception: ValidationError) -> None: + self._exceptions.append(exception) - def add_list(self, exceptions: List[ValidationError]) -> None: - self._invalid_exceptions.extend(exceptions) + def extend(self, exceptions: List[ValidationError]) -> None: + self._exceptions.extend(exceptions) def get_list_classnames(self) -> List[str]: - return list(sorted({ex.__class__.__name__ for ex in self._invalid_exceptions})) + return list(sorted({ex.__class__.__name__ for ex in self._exceptions})) def normalized_messages(self) -> Dict[Any, Any]: errors: Dict[Any, Any] = {} - for exception in self._invalid_exceptions: + for exception in self._exceptions: errors.update(exception.normalized_messages()) return errors diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py index a4208ded41563..a67828bdb283d 100644 --- a/superset/commands/importers/v1/__init__.py +++ b/superset/commands/importers/v1/__init__.py @@ -108,9 +108,10 @@ def validate(self) -> None: self._prevent_overwrite_existing_model(exceptions) if exceptions: - exception = CommandInvalidError(f"Error importing {self.model_name}") - exception.add_list(exceptions) - raise exception + raise CommandInvalidError( + f"Error importing {self.model_name}", + exceptions, + ) def _prevent_overwrite_existing_model( # pylint: disable=invalid-name self, exceptions: List[ValidationError] diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py index da4d7808d3038..ce8b46c2a0c46 100644 --- a/superset/commands/importers/v1/assets.py +++ b/superset/commands/importers/v1/assets.py @@ -172,6 +172,7 @@ def validate(self) -> None: ) if exceptions: - exception = CommandInvalidError("Error importing assets") - exception.add_list(exceptions) - raise exception + raise CommandInvalidError( + "Error importing assets", + exceptions, + ) diff --git a/superset/dashboards/commands/create.py b/superset/dashboards/commands/create.py index 811508c2e78fb..0ad8ddee7c4d7 100644 --- a/superset/dashboards/commands/create.py +++ b/superset/dashboards/commands/create.py @@ -63,9 +63,7 @@ def validate(self) -> None: except ValidationError as ex: exceptions.append(ex) if exceptions: - exception = DashboardInvalidError() - exception.add_list(exceptions) - raise exception + raise DashboardInvalidError(exceptions=exceptions) try: roles = populate_roles(role_ids) @@ -73,6 +71,4 @@ def validate(self) -> None: except ValidationError as ex: exceptions.append(ex) if exceptions: - exception = DashboardInvalidError() - exception.add_list(exceptions) - raise exception + raise DashboardInvalidError(exceptions=exceptions) diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py index 12ac241dccc22..11833a64be17d 100644 --- a/superset/dashboards/commands/update.py +++ b/superset/dashboards/commands/update.py @@ -92,9 +92,7 @@ def validate(self) -> None: except ValidationError as ex: exceptions.append(ex) if exceptions: - exception = DashboardInvalidError() - exception.add_list(exceptions) - raise exception + raise DashboardInvalidError(exceptions=exceptions) # Validate/Populate role if roles_ids is None: @@ -105,6 +103,4 @@ def validate(self) -> None: except ValidationError as ex: exceptions.append(ex) if exceptions: - exception = DashboardInvalidError() - exception.add_list(exceptions) - raise exception + raise DashboardInvalidError(exceptions=exceptions) diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index eb0582a980fca..16d27835b37c3 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -141,7 +141,7 @@ def validate(self) -> None: exceptions.append(DatabaseExistsValidationError()) if exceptions: exception = DatabaseInvalidError() - exception.add_list(exceptions) + exception.extend(exceptions) event_logger.log_with_context( action="db_connection_failed.{}.{}".format( exception.__class__.__name__, diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index 03531803553a5..746f7a8152a74 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -177,6 +177,4 @@ def validate(self) -> None: ): exceptions.append(DatabaseExistsValidationError()) if exceptions: - exception = DatabaseInvalidError() - exception.add_list(exceptions) - raise exception + raise DatabaseInvalidError(exceptions=exceptions) diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py index 9c17149ba3d00..45e5af5f44ea9 100644 --- a/superset/databases/ssh_tunnel/commands/create.py +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -82,7 +82,7 @@ def validate(self) -> None: exceptions.append(SSHTunnelRequiredFieldValidationError("private_key")) if exceptions: exception = SSHTunnelInvalidError() - exception.add_list(exceptions) + exception.extend(exceptions) event_logger.log_with_context( action="ssh_tunnel_creation_failed.{}.{}".format( exception.__class__.__name__, diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py index 2ac7856705401..42925d1caa317 100644 --- a/superset/databases/ssh_tunnel/commands/update.py +++ b/superset/databases/ssh_tunnel/commands/update.py @@ -58,6 +58,6 @@ def validate(self) -> None: "private_key_password" ) if private_key_password and private_key is None: - exception = SSHTunnelInvalidError() - exception.add(SSHTunnelRequiredFieldValidationError("private_key")) - raise exception + raise SSHTunnelInvalidError( + exceptions=[SSHTunnelRequiredFieldValidationError("private_key")] + ) diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 809eec7a1159a..04f54339d0847 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -87,6 +87,4 @@ def validate(self) -> None: except ValidationError as ex: exceptions.append(ex) if exceptions: - exception = DatasetInvalidError() - exception.add_list(exceptions) - raise exception + raise DatasetInvalidError(exceptions=exceptions) diff --git a/superset/datasets/commands/duplicate.py b/superset/datasets/commands/duplicate.py index c6b0bbea69257..5fc642cbe3e66 100644 --- a/superset/datasets/commands/duplicate.py +++ b/superset/datasets/commands/duplicate.py @@ -128,6 +128,4 @@ def validate(self) -> None: exceptions.append(ex) if exceptions: - exception = DatasetInvalidError() - exception.add_list(exceptions) - raise exception + raise DatasetInvalidError(exceptions=exceptions) diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py index b6bf1256d1904..cc9f480a41b54 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -122,9 +122,7 @@ def validate(self) -> None: self._validate_metrics(metrics, exceptions) if exceptions: - exception = DatasetInvalidError() - exception.add_list(exceptions) - raise exception + raise DatasetInvalidError(exceptions=exceptions) def _validate_columns( self, columns: List[Dict[str, Any]], exceptions: List[ValidationError] diff --git a/superset/reports/commands/create.py b/superset/reports/commands/create.py index aac5f07856e5e..27626170d6458 100644 --- a/superset/reports/commands/create.py +++ b/superset/reports/commands/create.py @@ -117,9 +117,7 @@ def validate(self) -> None: except ValidationError as ex: exceptions.append(ex) if exceptions: - exception = ReportScheduleInvalidError() - exception.add_list(exceptions) - raise exception + raise ReportScheduleInvalidError(exceptions=exceptions) def _validate_report_extra(self, exceptions: List[ValidationError]) -> None: extra: Optional[ReportScheduleExtra] = self._properties.get("extra") diff --git a/superset/reports/commands/update.py b/superset/reports/commands/update.py index 3399eca7b72cd..0c4f18f1b842f 100644 --- a/superset/reports/commands/update.py +++ b/superset/reports/commands/update.py @@ -124,6 +124,4 @@ def validate(self) -> None: except ValidationError as ex: exceptions.append(ex) if exceptions: - exception = ReportScheduleInvalidError() - exception.add_list(exceptions) - raise exception + raise ReportScheduleInvalidError(exceptions=exceptions) diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py index e9afe4a38d4a9..1e886e2af65a1 100644 --- a/superset/tags/commands/create.py +++ b/superset/tags/commands/create.py @@ -60,6 +60,4 @@ def validate(self) -> None: TagCreateFailedError(f"invalid object type {self._object_type}") ) if exceptions: - exception = TagInvalidError() - exception.add_list(exceptions) - raise exception + raise TagInvalidError(exceptions=exceptions) diff --git a/superset/tags/commands/delete.py b/superset/tags/commands/delete.py index 63a514e5996d7..acec01661935b 100644 --- a/superset/tags/commands/delete.py +++ b/superset/tags/commands/delete.py @@ -86,9 +86,7 @@ def validate(self) -> None: ) ) if exceptions: - exception = TagInvalidError() - exception.add_list(exceptions) - raise exception + raise TagInvalidError(exceptions=exceptions) class DeleteTagsCommand(DeleteMixin, BaseCommand): @@ -110,6 +108,4 @@ def validate(self) -> None: if not TagDAO.find_by_name(tag): exceptions.append(TagNotFoundError(tag)) if exceptions: - exception = TagInvalidError() - exception.add_list(exceptions) - raise exception + raise TagInvalidError(exceptions=exceptions)