diff --git a/superset-frontend/src/components/ListView/types.ts b/superset-frontend/src/components/ListView/types.ts index 2efb5f0a183a6..634631a1833ec 100644 --- a/superset-frontend/src/components/ListView/types.ts +++ b/superset-frontend/src/components/ListView/types.ts @@ -53,7 +53,8 @@ export interface Filter { | 'rel_m_m' | 'rel_o_m' | 'title_or_slug' - | 'name_or_description'; + | 'name_or_description' + | 'all_text'; input?: 'text' | 'textarea' | 'select' | 'checkbox' | 'search'; unfilteredLabel?: string; selects?: SelectOption[]; diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py index 81204a8b1c98e..1defb198a5dd2 100644 --- a/superset/queries/saved_queries/api.py +++ b/superset/queries/saved_queries/api.py @@ -32,7 +32,10 @@ SavedQueryBulkDeleteFailedError, SavedQueryNotFoundError, ) -from superset.queries.saved_queries.filters import SavedQueryFilter +from superset.queries.saved_queries.filters import ( + SavedQueryAllTextFilter, + SavedQueryFilter, +) from superset.queries.saved_queries.schemas import ( get_delete_ids_schema, openapi_spec_methods_override, @@ -93,6 +96,8 @@ class SavedQueryRestApi(BaseSupersetModelRestApi): "database.database_name", ] + search_filters = {"label": [SavedQueryAllTextFilter]} + apispec_parameter_schemas = { "get_delete_ids_schema": get_delete_ids_schema, } diff --git a/superset/queries/saved_queries/filters.py b/superset/queries/saved_queries/filters.py index 498a061edce10..09636cc3a8a47 100644 --- a/superset/queries/saved_queries/filters.py +++ b/superset/queries/saved_queries/filters.py @@ -17,12 +17,33 @@ from typing import Any from flask import g +from flask_babel import lazy_gettext as _ from flask_sqlalchemy import BaseQuery +from sqlalchemy import or_ +from sqlalchemy.orm.query import Query from superset.models.sql_lab import SavedQuery from superset.views.base import BaseFilter +class SavedQueryAllTextFilter(BaseFilter): # pylint: disable=too-few-public-methods + name = _("All Text") + arg_name = "all_text" + + def apply(self, query: Query, value: Any) -> Query: + if not value: + return query + ilike_value = f"%{value}%" + return query.filter( + or_( + SavedQuery.schema.ilike(ilike_value), + SavedQuery.label.ilike(ilike_value), + SavedQuery.description.ilike(ilike_value), + SavedQuery.sql.ilike(ilike_value), + ) + ) + + class SavedQueryFilter(BaseFilter): # pylint: disable=too-few-public-methods def apply(self, query: BaseQuery, value: Any) -> BaseQuery: """ diff --git a/tests/queries/saved_queries/api_tests.py b/tests/queries/saved_queries/api_tests.py index b3ce625b1fd6b..f268b1dc06231 100644 --- a/tests/queries/saved_queries/api_tests.py +++ b/tests/queries/saved_queries/api_tests.py @@ -43,6 +43,7 @@ def insert_saved_query( db_id: Optional[int] = None, created_by=None, schema: Optional[str] = "", + description: Optional[str] = "", ) -> SavedQuery: database = None if db_id: @@ -53,6 +54,7 @@ def insert_saved_query( sql=sql, label=label, schema=schema, + description=description, ) db.session.add(query) db.session.commit() @@ -69,6 +71,7 @@ def insert_default_saved_query( db_id=example_db.id, created_by=admin, schema=schema, + description="cool description", ) @pytest.fixture() @@ -195,6 +198,95 @@ def test_get_list_filter_saved_query(self): data = json.loads(rv.data.decode("utf-8")) assert data["count"] == len(all_queries) + @pytest.mark.usefixtures("create_saved_queries") + def test_get_list_custom_filter_schema_saved_query(self): + """ + Saved Query API: Test get list and custom filter (schema) saved query + """ + self.login(username="admin") + admin = self.get_user("admin") + + all_queries = ( + db.session.query(SavedQuery) + .filter(SavedQuery.created_by == admin) + .filter(SavedQuery.schema.ilike("%2%")) + .all() + ) + query_string = { + "filters": [{"col": "label", "opr": "all_text", "value": "schema2"}], + } + uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" + rv = self.get_assert_metric(uri, "get_list") + assert rv.status_code == 200 + data = json.loads(rv.data.decode("utf-8")) + assert data["count"] == len(all_queries) + + @pytest.mark.usefixtures("create_saved_queries") + def test_get_list_custom_filter_label_saved_query(self): + """ + Saved Query API: Test get list and custom filter (label) saved query + """ + self.login(username="admin") + admin = self.get_user("admin") + all_queries = ( + db.session.query(SavedQuery) + .filter(SavedQuery.created_by == admin) + .filter(SavedQuery.label.ilike("%3%")) + .all() + ) + query_string = { + "filters": [{"col": "label", "opr": "all_text", "value": "label3"}], + } + uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" + rv = self.get_assert_metric(uri, "get_list") + assert rv.status_code == 200 + data = json.loads(rv.data.decode("utf-8")) + assert data["count"] == len(all_queries) + + @pytest.mark.usefixtures("create_saved_queries") + def test_get_list_custom_filter_sql_saved_query(self): + """ + Saved Query API: Test get list and custom filter (sql) saved query + """ + self.login(username="admin") + admin = self.get_user("admin") + all_queries = ( + db.session.query(SavedQuery) + .filter(SavedQuery.created_by == admin) + .filter(SavedQuery.sql.ilike("%table%")) + .all() + ) + query_string = { + "filters": [{"col": "label", "opr": "all_text", "value": "table"}], + } + uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" + rv = self.get_assert_metric(uri, "get_list") + assert rv.status_code == 200 + data = json.loads(rv.data.decode("utf-8")) + assert data["count"] == len(all_queries) + + @pytest.mark.usefixtures("create_saved_queries") + def test_get_list_custom_filter_description_saved_query(self): + """ + Saved Query API: Test get list and custom filter (description) saved query + """ + self.login(username="admin") + admin = self.get_user("admin") + all_queries = ( + db.session.query(SavedQuery) + .filter(SavedQuery.created_by == admin) + .filter(SavedQuery.description.ilike("%cool%")) + .all() + ) + query_string = { + "filters": [{"col": "label", "opr": "all_text", "value": "cool"}], + } + uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" + rv = self.get_assert_metric(uri, "get_list") + assert rv.status_code == 200 + data = json.loads(rv.data.decode("utf-8")) + assert data["count"] == len(all_queries) + def test_info_saved_query(self): """ SavedQuery API: Test info @@ -281,7 +373,7 @@ def test_get_saved_query(self): expected_result = { "id": saved_query.id, "database": {"id": saved_query.database.id, "database_name": "examples"}, - "description": None, + "description": "cool description", "created_by": { "first_name": saved_query.created_by.first_name, "id": saved_query.created_by.id,