diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py index cb1d3682838b9..8a3d8de2dcdb7 100644 --- a/superset/utils/sqllab_execution_context.py +++ b/superset/utils/sqllab_execution_context.py @@ -118,25 +118,24 @@ def is_run_asynchronous(self) -> bool: def select_as_cta(self) -> bool: return self.create_table_as_select is not None - def set_database(self, db: Database) -> None: - self._validate_db(db) - self.database = db + def set_database(self, database: Database) -> None: + self._validate_db(database) + self.database = database if self.select_as_cta: - self.create_table_as_select.target_schema_name = self._get_ctas_target_schema_name( # type: ignore - db - ) + schema_name = self._get_ctas_target_schema_name(database) + self.create_table_as_select.target_schema_name = schema_name # type: ignore - def _get_ctas_target_schema_name(self, db: Database) -> Optional[str]: - if db.force_ctas_schema: - return db.force_ctas_schema - else: - return get_cta_schema_name(db, g.user, self.schema, self.sql) + def _get_ctas_target_schema_name(self, database: Database) -> Optional[str]: + if database.force_ctas_schema: + return database.force_ctas_schema + return get_cta_schema_name(database, g.user, self.schema, self.sql) - def _validate_db(self, db: Database) -> None: + def _validate_db(self, database: Database) -> None: # TODO validate db.id is equal to self.database_id pass def create_query(self) -> Query: + # pylint: disable=C0301 start_time = now_as_float() if self.select_as_cta: return Query( @@ -154,22 +153,21 @@ def create_query(self) -> Query: user_id=self.user_id, client_id=self.client_id_or_short_id, ) - else: - return Query( - database_id=self.database_id, - sql=self.sql, - schema=self.schema, - select_as_cta=False, - start_time=start_time, - tab_name=self.tab_name, - status=self.status, - sql_editor_id=self.sql_editor_id, - user_id=self.user_id, - client_id=self.client_id_or_short_id, - ) + return Query( + database_id=self.database_id, + sql=self.sql, + schema=self.schema, + select_as_cta=False, + start_time=start_time, + tab_name=self.tab_name, + status=self.status, + sql_editor_id=self.sql_editor_id, + user_id=self.user_id, + client_id=self.client_id_or_short_id, + ) -class CreateTableAsSelect: +class CreateTableAsSelect: # pylint: disable=R0903 ctas_method: CtasMethod target_schema_name: Optional[str] target_table_name: str diff --git a/superset/views/core.py b/superset/views/core.py index 402d2841d87f2..d3c56026cca29 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2566,7 +2566,7 @@ def is_query_handled(cls, query: Optional[Query]) -> bool: QueryStatus.TIMED_OUT, ] - def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals + def sql_json_exec( # pylint: disable=too-many-statements self, execution_context: SqlJsonExecutionContext, query_params: Dict[str, Any], diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index 1ca51065fcfec..eb55c7c924f88 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -72,6 +72,7 @@ def get_query_by_id(id: int): @pytest.fixture(autouse=True, scope="module") def setup_sqllab(): + with app.app_context(): yield @@ -216,7 +217,8 @@ def test_run_sync_query_cta_no_data(setup_sqllab): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( - "superset.utils.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME + "superset.utils.sqllab_execution_context.get_cta_schema_name", + lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) def test_run_sync_query_cta_config(setup_sqllab, ctas_method): if backend() == "sqlite": @@ -243,7 +245,8 @@ def test_run_sync_query_cta_config(setup_sqllab, ctas_method): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( - "superset.utils.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME + "superset.utils.sqllab_execution_context.get_cta_schema_name", + lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) def test_run_async_query_cta_config(setup_sqllab, ctas_method): if backend() == "sqlite": diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index f5b7859ba975a..d712daec4c97d 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -63,6 +63,7 @@ QUERY_3 = "SELECT * FROM birth_names LIMIT 10" +@pytest.mark.sqllab class TestSqlLab(SupersetTestCase): """Testings for Sql Lab""" @@ -188,7 +189,7 @@ def test_sql_json_cta_dynamic_db(self, ctas_method): return with mock.patch( - "superset.views.core.get_cta_schema_name", + "superset.utils.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: f"{u.username}_database", ): old_allow_ctas = examples_db.allow_ctas