From cade3cb5c87dda9e7454168a9d2511e88b2f444b Mon Sep 17 00:00:00 2001 From: John Bodley Date: Tue, 27 Jul 2021 08:32:01 -0700 Subject: [PATCH] fix: Ensure table uniqueness on update --- UPDATING.md | 3 + superset/connectors/sqla/models.py | 52 +++- .../31b2a1039d4a_drop_tables_constraint.py | 54 ++++ tests/integration_tests/access_tests.py | 8 +- tests/integration_tests/base_tests.py | 43 +-- tests/integration_tests/charts/api_tests.py | 2 +- tests/integration_tests/csv_upload_tests.py | 4 +- tests/integration_tests/dashboard_utils.py | 8 +- tests/integration_tests/datasets/api_tests.py | 22 +- tests/integration_tests/datasource_tests.py | 66 ++--- .../dict_import_export_tests.py | 5 +- .../integration_tests/fixtures/datasource.py | 274 +++++++++--------- .../fixtures/query_context.py | 4 +- .../integration_tests/import_export_tests.py | 29 +- tests/integration_tests/model_tests.py | 12 +- .../integration_tests/query_context_tests.py | 2 +- tests/integration_tests/security_tests.py | 12 +- tests/integration_tests/sqla_models_tests.py | 2 +- .../tasks/async_queries_tests.py | 7 +- 19 files changed, 340 insertions(+), 269 deletions(-) create mode 100644 superset/migrations/versions/31b2a1039d4a_drop_tables_constraint.py diff --git a/UPDATING.md b/UPDATING.md index a24280e9db016..0e4c3a0577d84 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -23,6 +23,9 @@ This file documents any backwards-incompatible changes in Superset and assists people when migrating to a new version. ## Next +- [15909](https://github.com/apache/incubator-superset/pull/15909): a change which +drops a uniqueness criterion (which may or may not have existed) to the tables table. This constraint was obsolete as it is handled by the ORM due to differences in how MySQL, PostgreSQL, etc. handle uniqueness for NULL values. + - [13772](https://github.com/apache/superset/pull/13772): Row level security (RLS) is now enabled by default. To activate the feature, please run `superset init` to expose the RLS menus to Admin users. - [13980](https://github.com/apache/superset/pull/13980): Data health checks no longer use the metadata database as an interim cache. Though non-breaking, deployments which implement complex logic should likely memoize the callback function. Refer to documentation in the confg.py file for more detail. diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 3ebbaf298bda8..5c068da7c4a8e 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -487,7 +487,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at owner_class = security_manager.user_model __tablename__ = "tables" - __table_args__ = (UniqueConstraint("database_id", "table_name"),) + + # Note this uniqueness constraint is not part of the physical schema, i.e., it does + # not exist in the migrations, but is required by `import_from_dict` to ensure the + # correct filters are applied in order to identify uniqueness. + # + # The reason it does not physically exist is MySQL, PostgreSQL have a different + # interpretation of uniqueness when it comes to NULL which is problematic given the + # schema is optional. + __table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),) table_name = Column(String(250), nullable=False) main_dttm_col = Column(String(250)) @@ -1669,6 +1677,47 @@ class and any keys added via `ExtraCache`. extra_cache_keys += sqla_query.extra_cache_keys return extra_cache_keys + @staticmethod + def before_update( + mapper: Mapper, # pylint: disable=unused-argument + connection: Connection, # pylint: disable=unused-argument + target: "SqlaTable", + ) -> None: + """ + Check whether before update if the target table already exists. + + Note this listener is called when any fields are being updated and thus it is + necessary to first check whether the reference table is being updated. + + Note this logic is temporary, given uniqueness is handled via the dataset DAO, + but is necessary until both the legacy datasource editor and datasource/save + endpoints are deprecated. + + :param mapper: The table mapper + :param connection: The DB-API connection + :param target: The mapped instance being persisted + :raises Exception: If the target table is not unique + """ + + from superset.datasets.commands.exceptions import get_dataset_exist_error_msg + from superset.datasets.dao import DatasetDAO + + # Check whether the relevant attributes have changed. + state = db.inspect(target) # pylint: disable=no-member + + for attr in ["database_id", "schema", "table_name"]: + history = state.get_history(attr, True) + + if history.has_changes(): + break + else: + return None + + if not DatasetDAO.validate_uniqueness( + target.database_id, target.schema, target.table_name + ): + raise Exception(get_dataset_exist_error_msg(target.full_name)) + def update_table( _mapper: Mapper, _connection: Connection, obj: Union[SqlMetric, TableColumn] @@ -1686,6 +1735,7 @@ def update_table( sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm) sa.event.listen(SqlaTable, "after_update", security_manager.set_perm) +sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update) sa.event.listen(SqlMetric, "after_update", update_table) sa.event.listen(TableColumn, "after_update", update_table) diff --git a/superset/migrations/versions/31b2a1039d4a_drop_tables_constraint.py b/superset/migrations/versions/31b2a1039d4a_drop_tables_constraint.py new file mode 100644 index 0000000000000..02123d73ef5d4 --- /dev/null +++ b/superset/migrations/versions/31b2a1039d4a_drop_tables_constraint.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""drop tables constraint + +Revision ID: 31b2a1039d4a +Revises: ae1ed299413b +Create Date: 2021-07-27 08:25:20.755453 + +""" + +from alembic import op +from sqlalchemy import engine +from sqlalchemy.exc import OperationalError, ProgrammingError + +from superset.utils.core import generic_find_uq_constraint_name + +# revision identifiers, used by Alembic. +revision = "31b2a1039d4a" +down_revision = "ae1ed299413b" + +conv = {"uq": "uq_%(table_name)s_%(column_0_name)s"} + + +def upgrade(): + bind = op.get_bind() + insp = engine.reflection.Inspector.from_engine(bind) + + # Drop the uniqueness constraint if it exists. + constraint = generic_find_uq_constraint_name("tables", {"table_name"}, insp) + + if constraint: + with op.batch_alter_table("tables", naming_convention=conv) as batch_op: + batch_op.drop_constraint(constraint, type_="unique") + + +def downgrade(): + + # One cannot simply re-add the uniqueness constraint as it may not have previously + # existed. + pass diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index 3d384cb65c694..d888dbf53c19f 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -161,7 +161,7 @@ def test_override_role_permissions_1_table(self): updated_override_me = security_manager.find_role("override_me") self.assertEqual(1, len(updated_override_me.permissions)) - birth_names = self.get_table_by_name("birth_names") + birth_names = self.get_table(name="birth_names") self.assertEqual( birth_names.perm, updated_override_me.permissions[0].view_menu.name ) @@ -190,7 +190,7 @@ def test_override_role_permissions_druid_and_table(self): "datasource_access", updated_role.permissions[1].permission.name ) - birth_names = self.get_table_by_name("birth_names") + birth_names = self.get_table(name="birth_names") self.assertEqual(birth_names.perm, perms[2].view_menu.name) self.assertEqual( "datasource_access", updated_role.permissions[2].permission.name @@ -204,7 +204,7 @@ def test_override_role_permissions_drops_absent_perms(self): override_me = security_manager.find_role("override_me") override_me.permissions.append( security_manager.find_permission_view_menu( - view_menu_name=self.get_table_by_name("energy_usage").perm, + view_menu_name=self.get_table(name="energy_usage").perm, permission_name="datasource_access", ) ) @@ -218,7 +218,7 @@ def test_override_role_permissions_drops_absent_perms(self): self.assertEqual(201, response.status_code) updated_override_me = security_manager.find_role("override_me") self.assertEqual(1, len(updated_override_me.permissions)) - birth_names = self.get_table_by_name("birth_names") + birth_names = self.get_table(name="birth_names") self.assertEqual( birth_names.perm, updated_override_me.permissions[0].view_menu.name ) diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 672b8d210aca0..7e4ebfd7df035 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -99,10 +99,6 @@ def post_assert_metric( return rv -def get_table_by_name(name: str) -> SqlaTable: - return db.session.query(SqlaTable).filter_by(table_name=name).one() - - @pytest.fixture def logged_in_admin(): """Fixture with app context and logged in admin user.""" @@ -132,12 +128,7 @@ def get_nonexistent_numeric_id(model): @staticmethod def get_birth_names_dataset() -> SqlaTable: - example_db = get_example_database() - return ( - db.session.query(SqlaTable) - .filter_by(database=example_db, table_name="birth_names") - .one() - ) + return SupersetTestCase.get_table(name="birth_names") @staticmethod def create_user_with_roles( @@ -254,13 +245,31 @@ def get_slice( return slc @staticmethod - def get_table_by_name(name: str) -> SqlaTable: - return get_table_by_name(name) + def get_table( + name: str, database_id: Optional[int] = None, schema: Optional[str] = None + ) -> SqlaTable: + return ( + db.session.query(SqlaTable) + .filter_by( + database_id=database_id + or SupersetTestCase.get_database_by_name("examples").id, + schema=schema, + table_name=name, + ) + .one() + ) @staticmethod def get_database_by_id(db_id: int) -> Database: return db.session.query(Database).filter_by(id=db_id).one() + @staticmethod + def get_database_by_name(database_name: str = "main") -> Database: + if database_name == "examples": + return get_example_database() + else: + raise ValueError("Database doesn't exist") + @staticmethod def get_druid_ds_by_name(name: str) -> DruidDatasource: return db.session.query(DruidDatasource).filter_by(datasource_name=name).first() @@ -340,12 +349,6 @@ def revoke_role_access_to_table(self, role_name, table): ): security_manager.del_permission_role(public_role, perm) - def _get_database_by_name(self, database_name="main"): - if database_name == "examples": - return get_example_database() - else: - raise ValueError("Database doesn't exist") - def run_sql( self, sql, @@ -364,7 +367,7 @@ def run_sql( if user_name: self.logout() self.login(username=(user_name or "admin")) - dbid = self._get_database_by_name(database_name).id + dbid = SupersetTestCase.get_database_by_name(database_name).id json_payload = { "database_id": dbid, "sql": sql, @@ -448,7 +451,7 @@ def validate_sql( if user_name: self.logout() self.login(username=(user_name if user_name else "admin")) - dbid = self._get_database_by_name(database_name).id + dbid = SupersetTestCase.get_database_by_name(database_name).id resp = self.get_json_resp( "/superset/validate_sql_json/", raise_on_error=False, diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index e74f50d88143a..803d81f01f1a4 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -545,7 +545,7 @@ def test_update_chart(self): """ admin = self.get_user("admin") gamma = self.get_user("gamma") - birth_names_table_id = SupersetTestCase.get_table_by_name("birth_names").id + birth_names_table_id = SupersetTestCase.get_table(name="birth_names").id chart_id = self.insert_chart( "title", [admin.id], birth_names_table_id, admin ).id diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py index 99afbeb5c8303..a8821fb0308d2 100644 --- a/tests/integration_tests/csv_upload_tests.py +++ b/tests/integration_tests/csv_upload_tests.py @@ -221,7 +221,7 @@ def test_import_csv_explore_database(setup_csv_upload, create_csv_files): f'CSV file "{CSV_FILENAME1}" uploaded to table "{CSV_UPLOAD_TABLE_W_EXPLORE}"' in resp ) - table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE_W_EXPLORE) + table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE_W_EXPLORE) assert table.database_id == utils.get_example_database().id @@ -267,7 +267,7 @@ def test_import_csv(setup_csv_upload, create_csv_files): ) assert success_msg_f2 in resp - table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE) + table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE) # make sure the new column name is reflected in the table metadata assert "d" in table.column_names diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py index 311dd5965adf1..85daa0b1b8d09 100644 --- a/tests/integration_tests/dashboard_utils.py +++ b/tests/integration_tests/dashboard_utils.py @@ -35,6 +35,7 @@ def create_table_for_dashboard( dtype: Dict[str, Any], table_description: str = "", fetch_values_predicate: Optional[str] = None, + schema: Optional[str] = None, ) -> SqlaTable: df.to_sql( table_name, @@ -44,14 +45,17 @@ def create_table_for_dashboard( dtype=dtype, index=False, method="multi", + schema=schema, ) table_source = ConnectorRegistry.sources["table"] table = ( - db.session.query(table_source).filter_by(table_name=table_name).one_or_none() + db.session.query(table_source) + .filter_by(database_id=database.id, schema=schema, table_name=table_name) + .one_or_none() ) if not table: - table = table_source(table_name=table_name) + table = table_source(schema=schema, table_name=table_name) if fetch_values_predicate: table.fetch_values_predicate = fetch_values_predicate table.database = database diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 894c6f6900f9e..5275a87173030 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -63,10 +63,10 @@ class TestDatasetApi(SupersetTestCase): @staticmethod def insert_dataset( table_name: str, - schema: str, owners: List[int], database: Database, sql: Optional[str] = None, + schema: Optional[str] = None, ) -> SqlaTable: obj_owners = list() for owner in owners: @@ -86,7 +86,7 @@ def insert_dataset( def insert_default_dataset(self): return self.insert_dataset( - "ab_permission", "", [self.get_user("admin").id], get_main_database() + "ab_permission", [self.get_user("admin").id], get_main_database() ) def get_fixture_datasets(self) -> List[SqlaTable]: @@ -105,11 +105,7 @@ def create_virtual_datasets(self): for table_name in self.fixture_virtual_table_names: datasets.append( self.insert_dataset( - table_name, - "", - [admin.id], - main_db, - "SELECT * from ab_view_menu;", + table_name, [admin.id], main_db, "SELECT * from ab_view_menu;", ) ) yield datasets @@ -126,9 +122,7 @@ def create_datasets(self): admin = self.get_user("admin") main_db = get_main_database() for tables_name in self.fixture_tables_names: - datasets.append( - self.insert_dataset(tables_name, "", [admin.id], main_db) - ) + datasets.append(self.insert_dataset(tables_name, [admin.id], main_db)) yield datasets # rollback changes @@ -270,11 +264,13 @@ def pg_test_query_parameter(query_parameter, expected_response): datasets = [] if example_db.backend == "postgresql": datasets.append( - self.insert_dataset("ab_permission", "public", [], get_main_database()) + self.insert_dataset( + "ab_permission", [], get_main_database(), schema="public" + ) ) datasets.append( self.insert_dataset( - "columns", "information_schema", [], get_main_database() + "columns", [], get_main_database(), schema="information_schema", ) ) schema_values = [ @@ -921,7 +917,7 @@ def test_update_dataset_item_uniqueness(self): dataset = self.insert_default_dataset() self.login(username="admin") ab_user = self.insert_dataset( - "ab_user", "", [self.get_user("admin").id], get_main_database() + "ab_user", [self.get_user("admin").id], get_main_database() ) table_data = {"table_name": "ab_user"} uri = f"api/v1/dataset/{dataset.id}" diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 14280264acecf..be58e4e9b8e85 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -16,7 +16,6 @@ # under the License. """Unit tests for Superset""" import json -from copy import deepcopy from unittest import mock import pytest @@ -26,30 +25,24 @@ from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.exceptions import SupersetException, SupersetGenericDBErrorException from superset.utils.core import get_example_database +from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, ) - -from .base_tests import db_insert_temp_object, SupersetTestCase -from .fixtures.datasource import datasource_post +from tests.integration_tests.fixtures.datasource import get_datasource_post class TestDatasource(SupersetTestCase): def setUp(self): - self.original_attrs = {} - self.datasource = None + db.session.begin(subtransactions=True) def tearDown(self): - if self.datasource: - for key, value in self.original_attrs.items(): - setattr(self.datasource, key, value) - - db.session.commit() + db.session.rollback() @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_external_metadata_for_physical_table(self): self.login(username="admin") - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") url = f"/datasource/external_metadata/table/{tbl.id}/" resp = self.get_json_resp(url) col_names = {o.get("name") for o in resp} @@ -68,7 +61,7 @@ def test_external_metadata_for_virtual_table(self): session.add(table) session.commit() - table = self.get_table_by_name("dummy_sql_table") + table = self.get_table(name="dummy_sql_table") url = f"/datasource/external_metadata/table/{table.id}/" resp = self.get_json_resp(url) assert {o.get("name") for o in resp} == {"intcol", "strcol"} @@ -87,7 +80,7 @@ def test_external_metadata_for_virtual_table_template_params(self): session.add(table) session.commit() - table = self.get_table_by_name("dummy_sql_table_with_template_params") + table = self.get_table(name="dummy_sql_table_with_template_params") url = f"/datasource/external_metadata/table/{table.id}/" resp = self.get_json_resp(url) assert {o.get("name") for o in resp} == {"intcol"} @@ -123,7 +116,7 @@ def test_external_metadata_for_mutistatement_virtual_table(self): @mock.patch("superset.connectors.sqla.models.SqlaTable.external_metadata") def test_external_metadata_error_return_400(self, mock_get_datasource): self.login(username="admin") - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") url = f"/datasource/external_metadata/table/{tbl.id}/" mock_get_datasource.side_effect = SupersetGenericDBErrorException("oops") @@ -148,13 +141,9 @@ def compare_lists(self, l1, l2, key): def test_save(self): self.login(username="admin") - tbl_id = self.get_table_by_name("birth_names").id - - self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session) - - for key in self.datasource.export_fields: - self.original_attrs[key] = getattr(self.datasource, key) + tbl_id = self.get_table(name="birth_names").id + datasource_post = get_datasource_post() datasource_post["id"] = tbl_id data = dict(data=json.dumps(datasource_post)) resp = self.get_json_resp("/datasource/save/", data) @@ -168,25 +157,21 @@ def test_save(self): else: self.assertEqual(resp[k], datasource_post[k]) - def save_datasource_from_dict(self, datasource_dict): + def save_datasource_from_dict(self, datasource_post): data = dict(data=json.dumps(datasource_post)) resp = self.get_json_resp("/datasource/save/", data) return resp + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_change_database(self): self.login(username="admin") - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") tbl_id = tbl.id db_id = tbl.database_id + datasource_post = get_datasource_post() datasource_post["id"] = tbl_id - self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session) - - for key in self.datasource.export_fields: - self.original_attrs[key] = getattr(self.datasource, key) - new_db = self.create_fake_db() - datasource_post["database"]["id"] = new_db.id resp = self.save_datasource_from_dict(datasource_post) self.assertEqual(resp["database"]["id"], new_db.id) @@ -199,15 +184,11 @@ def test_change_database(self): def test_save_duplicate_key(self): self.login(username="admin") - tbl_id = self.get_table_by_name("birth_names").id - self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session) - - for key in self.datasource.export_fields: - self.original_attrs[key] = getattr(self.datasource, key) + tbl_id = self.get_table(name="birth_names").id - datasource_post_copy = deepcopy(datasource_post) - datasource_post_copy["id"] = tbl_id - datasource_post_copy["columns"].extend( + datasource_post = get_datasource_post() + datasource_post["id"] = tbl_id + datasource_post["columns"].extend( [ { "column_name": "", @@ -225,18 +206,15 @@ def test_save_duplicate_key(self): }, ] ) - data = dict(data=json.dumps(datasource_post_copy)) + data = dict(data=json.dumps(datasource_post)) resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False) self.assertIn("Duplicate column name(s): ", resp["error"]) def test_get_datasource(self): self.login(username="admin") - tbl = self.get_table_by_name("birth_names") - self.datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session) - - for key in self.datasource.export_fields: - self.original_attrs[key] = getattr(self.datasource, key) + tbl = self.get_table(name="birth_names") + datasource_post = get_datasource_post() datasource_post["id"] = tbl.id data = dict(data=json.dumps(datasource_post)) self.get_json_resp("/datasource/save/", data) @@ -264,7 +242,7 @@ def my_check(datasource): app.config["DATASET_HEALTH_CHECK"] = my_check self.login(username="admin") - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session) assert datasource.health_check_message == "Warning message!" app.config["DATASET_HEALTH_CHECK"] = None diff --git a/tests/integration_tests/dict_import_export_tests.py b/tests/integration_tests/dict_import_export_tests.py index 922b10be889f1..fe7ff512f1fe3 100644 --- a/tests/integration_tests/dict_import_export_tests.py +++ b/tests/integration_tests/dict_import_export_tests.py @@ -66,7 +66,7 @@ def tearDownClass(cls): cls.delete_imports() def create_table( - self, name, schema="", id=0, cols_names=[], cols_uuids=None, metric_names=[] + self, name, schema=None, id=0, cols_names=[], cols_uuids=None, metric_names=[] ): database_name = "main" name = "{0}{1}".format(NAME_PREFIX, name) @@ -128,9 +128,6 @@ def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]): def get_datasource(self, datasource_id): return db.session.query(DruidDatasource).filter_by(id=datasource_id).first() - def get_table_by_name(self, name): - return db.session.query(SqlaTable).filter_by(table_name=name).first() - def yaml_compare(self, obj_1, obj_2): obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False) obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False) diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py index fad2bfeea179a..e6cd7e8229cc5 100644 --- a/tests/integration_tests/fixtures/datasource.py +++ b/tests/integration_tests/fixtures/datasource.py @@ -15,138 +15,142 @@ # specific language governing permissions and limitations # under the License. """Fixtures for test_datasource.py""" -datasource_post = { - "id": None, - "column_formats": {"ratio": ".2%"}, - "database": {"id": 1}, - "description": "Adding a DESCRip", - "default_endpoint": "", - "filter_select_enabled": True, - "name": "birth_names", - "table_name": "birth_names", - "datasource_name": "birth_names", - "type": "table", - "schema": "", - "offset": 66, - "cache_timeout": 55, - "sql": "", - "columns": [ - { - "id": 504, - "column_name": "ds", - "verbose_name": "", - "description": None, - "expression": "", - "filterable": True, - "groupby": True, - "is_dttm": True, - "type": "DATETIME", - }, - { - "id": 505, - "column_name": "gender", - "verbose_name": None, - "description": None, - "expression": "", - "filterable": True, - "groupby": True, - "is_dttm": False, - "type": "VARCHAR(16)", - }, - { - "id": 506, - "column_name": "name", - "verbose_name": None, - "description": None, - "expression": None, - "filterable": True, - "groupby": True, - "is_dttm": None, - "type": "VARCHAR(255)", - }, - { - "id": 508, - "column_name": "state", - "verbose_name": None, - "description": None, - "expression": None, - "filterable": True, - "groupby": True, - "is_dttm": None, - "type": "VARCHAR(10)", - }, - { - "id": 509, - "column_name": "num_boys", - "verbose_name": None, - "description": None, - "expression": None, - "filterable": True, - "groupby": True, - "is_dttm": None, - "type": "BIGINT(20)", - }, - { - "id": 510, - "column_name": "num_girls", - "verbose_name": None, - "description": None, - "expression": "", - "filterable": False, - "groupby": False, - "is_dttm": False, - "type": "BIGINT(20)", - }, - { - "id": 532, - "column_name": "num", - "verbose_name": None, - "description": None, - "expression": None, - "filterable": True, - "groupby": True, - "is_dttm": None, - "type": "BIGINT(20)", - }, - { - "id": 522, - "column_name": "num_california", - "verbose_name": None, - "description": None, - "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", - "filterable": False, - "groupby": False, - "is_dttm": False, - "type": "NUMBER", - }, - ], - "metrics": [ - { - "id": 824, - "metric_name": "sum__num", - "verbose_name": "Babies", - "description": "", - "expression": "SUM(num)", - "warning_text": "", - "d3format": "", - }, - { - "id": 836, - "metric_name": "count", - "verbose_name": "", - "description": None, - "expression": "count(1)", - "warning_text": None, - "d3format": None, - }, - { - "id": 843, - "metric_name": "ratio", - "verbose_name": "Ratio Boys/Girls", - "description": "This represents the ratio of boys/girls", - "expression": "sum(num_boys) / sum(num_girls)", - "warning_text": "no warning", - "d3format": ".2%", - }, - ], -} +from typing import Any, Dict + + +def get_datasource_post() -> Dict[str, Any]: + return { + "id": None, + "column_formats": {"ratio": ".2%"}, + "database": {"id": 1}, + "description": "Adding a DESCRip", + "default_endpoint": "", + "filter_select_enabled": True, + "name": "birth_names", + "table_name": "birth_names", + "datasource_name": "birth_names", + "type": "table", + "schema": None, + "offset": 66, + "cache_timeout": 55, + "sql": "", + "columns": [ + { + "id": 504, + "column_name": "ds", + "verbose_name": "", + "description": None, + "expression": "", + "filterable": True, + "groupby": True, + "is_dttm": True, + "type": "DATETIME", + }, + { + "id": 505, + "column_name": "gender", + "verbose_name": None, + "description": None, + "expression": "", + "filterable": True, + "groupby": True, + "is_dttm": False, + "type": "VARCHAR(16)", + }, + { + "id": 506, + "column_name": "name", + "verbose_name": None, + "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_dttm": None, + "type": "VARCHAR(255)", + }, + { + "id": 508, + "column_name": "state", + "verbose_name": None, + "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_dttm": None, + "type": "VARCHAR(10)", + }, + { + "id": 509, + "column_name": "num_boys", + "verbose_name": None, + "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_dttm": None, + "type": "BIGINT(20)", + }, + { + "id": 510, + "column_name": "num_girls", + "verbose_name": None, + "description": None, + "expression": "", + "filterable": False, + "groupby": False, + "is_dttm": False, + "type": "BIGINT(20)", + }, + { + "id": 532, + "column_name": "num", + "verbose_name": None, + "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_dttm": None, + "type": "BIGINT(20)", + }, + { + "id": 522, + "column_name": "num_california", + "verbose_name": None, + "description": None, + "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", + "filterable": False, + "groupby": False, + "is_dttm": False, + "type": "NUMBER", + }, + ], + "metrics": [ + { + "id": 824, + "metric_name": "sum__num", + "verbose_name": "Babies", + "description": "", + "expression": "SUM(num)", + "warning_text": "", + "d3format": "", + }, + { + "id": 836, + "metric_name": "count", + "verbose_name": "", + "description": None, + "expression": "count(1)", + "warning_text": None, + "d3format": None, + }, + { + "id": 843, + "metric_name": "ratio", + "verbose_name": "Ratio Boys/Girls", + "description": "This represents the ratio of boys/girls", + "expression": "sum(num_boys) / sum(num_girls)", + "warning_text": "no warning", + "d3format": ".2%", + }, + ], + } diff --git a/tests/integration_tests/fixtures/query_context.py b/tests/integration_tests/fixtures/query_context.py index 268e403875e5f..3cd69b93d1389 100644 --- a/tests/integration_tests/fixtures/query_context.py +++ b/tests/integration_tests/fixtures/query_context.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List from superset.utils.core import AnnotationType, DTTM_ALIAS, TimeRangeEndpoint -from tests.integration_tests.base_tests import get_table_by_name +from tests.integration_tests.base_tests import SupersetTestCase query_birth_names = { "extras": { @@ -239,7 +239,7 @@ def get_query_context( :return: Request payload """ table_name = query_name.split(":")[0] - table = get_table_by_name(table_name) + table = SupersetTestCase.get_table(name=table_name) return { "datasource": {"id": table.id, "type": table.type}, "queries": [get_query_object(query_name, add_postprocessing_operations)], diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index 7dbc7beea3edd..bcf8187f518d5 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -89,19 +89,20 @@ def create_slice( id=None, db_name="examples", table_name="wb_health_population", + schema=None, ): params = { "num_period_compare": "10", "remote_id": id, "datasource_name": table_name, "database_name": db_name, - "schema": "", + "schema": schema, # Test for trailing commas "metrics": ["sum__signup_attempt_email", "sum__signup_attempt_facebook"], } if table_name and not ds_id: - table = self.get_table_by_name(table_name) + table = self.get_table(schema=schema, name=table_name) if table: ds_id = table.id @@ -167,9 +168,6 @@ def get_dash(self, dash_id): def get_datasource(self, datasource_id): return db.session.query(DruidDatasource).filter_by(id=datasource_id).first() - def get_table_by_name(self, name): - return db.session.query(SqlaTable).filter_by(table_name=name).first() - def assert_dash_equals( self, expected_dash, actual_dash, check_position=True, check_slugs=True ): @@ -273,9 +271,7 @@ def test_export_1_dashboard(self): resp.data.decode("utf-8"), object_hook=decode_dashboards )["datasources"] self.assertEqual(1, len(exported_tables)) - self.assert_table_equals( - self.get_table_by_name("birth_names"), exported_tables[0] - ) + self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0]) @pytest.mark.usefixtures( "load_world_bank_dashboard_with_slices", @@ -314,11 +310,9 @@ def test_export_2_dashboards(self): resp_data.get("datasources"), key=lambda t: t.table_name ) self.assertEqual(2, len(exported_tables)) + self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0]) self.assert_table_equals( - self.get_table_by_name("birth_names"), exported_tables[0] - ) - self.assert_table_equals( - self.get_table_by_name("wb_health_population"), exported_tables[1] + self.get_table(name="wb_health_population"), exported_tables[1] ) @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") @@ -329,12 +323,12 @@ def test_import_1_slice(self): self.assertEqual(slc.datasource.perm, slc.perm) self.assert_slice_equals(expected_slice, slc) - table_id = self.get_table_by_name("wb_health_population").id + table_id = self.get_table(name="wb_health_population").id self.assertEqual(table_id, self.get_slice(slc_id).datasource_id) @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_import_2_slices_for_same_table(self): - table_id = self.get_table_by_name("wb_health_population").id + table_id = self.get_table(name="wb_health_population").id # table_id != 666, import func will have to find the table slc_1 = self.create_slice("Import Me 1", ds_id=666, id=10002) slc_id_1 = import_chart(slc_1, None) @@ -351,13 +345,6 @@ def test_import_2_slices_for_same_table(self): self.assert_slice_equals(slc_2, imported_slc_2) self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm) - def test_import_slices_for_non_existent_table(self): - with self.assertRaises(AttributeError): - import_chart( - self.create_slice("Import Me 3", id=10004, table_name="non_existent"), - None, - ) - def test_import_slices_override(self): slc = self.create_slice("Import Me New", id=10005) slc_1_id = import_chart(slc, None, import_time=1990) diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 74387a18d15e7..665f1cd27e150 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -339,7 +339,7 @@ def test_multi_statement(self): class TestSqlaTableModel(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_timestamp_expression(self): - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") ds_col = tbl.get_column("ds") sqla_literal = ds_col.get_timestamp_expression(None) self.assertEqual(str(sqla_literal.compile()), "ds") @@ -359,7 +359,7 @@ def test_get_timestamp_expression(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_timestamp_expression_epoch(self): - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") ds_col = tbl.get_column("ds") ds_col.expression = None @@ -384,7 +384,7 @@ def test_get_timestamp_expression_epoch(self): ds_col.expression = prev_ds_expr def query_with_expr_helper(self, is_timeseries, inner_join=True): - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") ds_col = tbl.get_column("ds") ds_col.expression = None ds_col.python_date_format = None @@ -447,7 +447,7 @@ def test_query_with_expr_groupby(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_mutator(self): - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") query_obj = dict( groupby=[], metrics=None, @@ -472,7 +472,7 @@ def mutator(*args): app.config["SQL_QUERY_MUTATOR"] = None def test_query_with_non_existent_metrics(self): - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") query_obj = dict( groupby=[], @@ -493,7 +493,7 @@ def test_query_with_non_existent_metrics(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_data_for_slices(self): - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") slc = ( metadata_db.session.query(Slice) .filter_by( diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 9c04c62307a3c..b259aa5746353 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -92,7 +92,7 @@ def test_schema_deserialization(self): def test_cache(self): table_name = "birth_names" - table = self.get_table_by_name(table_name) + table = self.get_table(name=table_name) payload = get_query_context(table.name, table.id) payload["force"] = True diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 009082c2f4953..56bfe846957b1 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -1151,7 +1151,7 @@ def tearDown(self): @pytest.mark.usefixtures("load_energy_table_with_slice") def test_rls_filter_alters_energy_query(self): g.user = self.get_user(username="alpha") - tbl = self.get_table_by_name("energy_usage") + tbl = self.get_table(name="energy_usage") sql = tbl.get_query_str(self.query_obj) assert tbl.get_extra_cache_keys(self.query_obj) == [1] assert "value > 1" in sql @@ -1161,7 +1161,7 @@ def test_rls_filter_doesnt_alter_energy_query(self): g.user = self.get_user( username="admin" ) # self.login() doesn't actually set the user - tbl = self.get_table_by_name("energy_usage") + tbl = self.get_table(name="energy_usage") sql = tbl.get_query_str(self.query_obj) assert tbl.get_extra_cache_keys(self.query_obj) == [] assert "value > 1" not in sql @@ -1171,7 +1171,7 @@ def test_multiple_table_filter_alters_another_tables_query(self): g.user = self.get_user( username="alpha" ) # self.login() doesn't actually set the user - tbl = self.get_table_by_name("unicode_test") + tbl = self.get_table(name="unicode_test") sql = tbl.get_query_str(self.query_obj) assert tbl.get_extra_cache_keys(self.query_obj) == [1] assert "value > 1" in sql @@ -1179,7 +1179,7 @@ def test_multiple_table_filter_alters_another_tables_query(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_rls_filter_alters_gamma_birth_names_query(self): g.user = self.get_user(username="gamma") - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") sql = tbl.get_query_str(self.query_obj) # establish that the filters are grouped together correctly with @@ -1192,7 +1192,7 @@ def test_rls_filter_alters_gamma_birth_names_query(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_rls_filter_alters_no_role_user_birth_names_query(self): g.user = self.get_user(username="NoRlsRoleUser") - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") sql = tbl.get_query_str(self.query_obj) # gamma's filters should not be present query @@ -1205,7 +1205,7 @@ def test_rls_filter_alters_no_role_user_birth_names_query(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_rls_filter_doesnt_alter_admin_birth_names_query(self): g.user = self.get_user(username="admin") - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table(name="birth_names") sql = tbl.get_query_str(self.query_obj) # no filters are applied for admin user diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index b60cdcca77d79..ed1358ac7f45e 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -241,7 +241,7 @@ class FilterTestCase(NamedTuple): FilterTestCase(FilterOperator.IN, ["1", "2"], "IN (1, 2)"), FilterTestCase(FilterOperator.NOT_IN, ["1", "2"], "NOT IN (1, 2)"), ) - table = self.get_table_by_name("birth_names") + table = self.get_table(name="birth_names") for filter_ in filters: query_obj = { "granularity": None, diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index 8743374569dd4..57b0df5ad2273 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -42,11 +42,6 @@ from tests.integration_tests.test_app import app -def get_table_by_name(name: str) -> SqlaTable: - with app.app_context(): - return db.session.query(SqlaTable).filter_by(table_name=name).one() - - class TestAsyncQueries(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") @@ -127,7 +122,7 @@ def test_soft_timeout_load_chart_data_into_cache( @mock.patch.object(async_query_manager, "update_job") def test_load_explore_json_into_cache(self, mock_update_job): async_query_manager.init_app(app) - table = get_table_by_name("birth_names") + table = self.get_table(name="birth_names") user = security_manager.find_user("gamma") form_data = { "datasource": f"{table.id}__table",