From 318eff732764f89e49c6eb699238c8950f533726 Mon Sep 17 00:00:00 2001 From: Geido <60598000+geido@users.noreply.github.com> Date: Wed, 9 Oct 2024 23:26:32 +0300 Subject: [PATCH] fix(Jinja): Extra cache keys to consider vars with set (#30549) --- superset/jinja_context.py | 15 +- tests/integration_tests/sqla_models_tests.py | 194 ++++++++++++------- 2 files changed, 134 insertions(+), 75 deletions(-) diff --git a/superset/jinja_context.py b/superset/jinja_context.py index d7ae892301689..e4a83422315d8 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -104,13 +104,14 @@ class ExtraCache: # Regular expression for detecting the presence of templated methods which could # be added to the cache key. regex = re.compile( - r"\{\{.*(" - r"current_user_id\(.*\)|" - r"current_username\(.*\)|" - r"current_user_email\(.*\)|" - r"cache_key_wrapper\(.*\)|" - r"url_param\(.*\)" - r").*\}\}" + r"(\{\{|\{%)[^{}]*?(" + r"current_user_id\([^()]*\)|" + r"current_username\([^()]*\)|" + r"current_user_email\([^()]*\)|" + r"cache_key_wrapper\([^()]*\)|" + r"url_param\([^()]*\)" + r")" + r"[^{}]*?(\}\}|\%\})" ) def __init__( # pylint: disable=too-many-arguments diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index fb03f37e62c7b..922cbf67fd65e 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -133,74 +133,6 @@ def test_db_column_types(self): col = TableColumn(column_name="foo", type=str_type, table=tbl, is_dttm=True) assert col.is_temporal - @patch("superset.jinja_context.get_user_id", return_value=1) - @patch("superset.jinja_context.get_username", return_value="abc") - @patch("superset.jinja_context.get_user_email", return_value="abc@test.com") - def test_extra_cache_keys(self, mock_user_email, mock_username, mock_user_id): - base_query_obj = { - "granularity": None, - "from_dttm": None, - "to_dttm": None, - "groupby": ["id", "username", "email"], - "metrics": [], - "is_timeseries": False, - "filter": [], - } - - # Table with Jinja callable. - table1 = SqlaTable( - table_name="test_has_extra_cache_keys_table", - sql=""" - SELECT - '{{ current_user_id() }}' as id, - '{{ current_username() }}' as username, - '{{ current_user_email() }}' as email - """, - database=get_example_database(), - ) - - query_obj = dict(**base_query_obj, extras={}) - extra_cache_keys = table1.get_extra_cache_keys(query_obj) - assert table1.has_extra_cache_key_calls(query_obj) - assert set(extra_cache_keys) == {1, "abc", "abc@test.com"} - - # Table with Jinja callable disabled. - table2 = SqlaTable( - table_name="test_has_extra_cache_keys_disabled_table", - sql=""" - SELECT - '{{ current_user_id(False) }}' as id, - '{{ current_username(False) }}' as username, - '{{ current_user_email(False) }}' as email, - """, - database=get_example_database(), - ) - query_obj = dict(**base_query_obj, extras={}) - extra_cache_keys = table2.get_extra_cache_keys(query_obj) - assert table2.has_extra_cache_key_calls(query_obj) - self.assertListEqual(extra_cache_keys, []) # noqa: PT009 - - # Table with no Jinja callable. - query = "SELECT 'abc' as user" - table3 = SqlaTable( - table_name="test_has_no_extra_cache_keys_table", - sql=query, - database=get_example_database(), - ) - - query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"}) - extra_cache_keys = table3.get_extra_cache_keys(query_obj) - assert not table3.has_extra_cache_key_calls(query_obj) - self.assertListEqual(extra_cache_keys, []) # noqa: PT009 - - # With Jinja callable in SQL expression. - query_obj = dict( - **base_query_obj, extras={"where": "(user != '{{ current_username() }}')"} - ) - extra_cache_keys = table3.get_extra_cache_keys(query_obj) - assert table3.has_extra_cache_key_calls(query_obj) - assert extra_cache_keys == ["abc"] - @patch("superset.jinja_context.get_username", return_value="abc") def test_jinja_metrics_and_calc_columns(self, mock_username): base_query_obj = { @@ -859,6 +791,132 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset): ) +@pytest.mark.usefixtures("app_context") +@pytest.mark.parametrize( + "table_name,sql,expected_cache_keys,has_extra_cache_keys", + [ + ( + "test_has_extra_cache_keys_table", + """ + SELECT + '{{ current_user_id() }}' as id, + '{{ current_username() }}' as username, + '{{ current_user_email() }}' as email + """, + {1, "abc", "abc@test.com"}, + True, + ), + ( + "test_has_extra_cache_keys_table_with_set", + """ + {% set user_email = current_user_email() %} + SELECT + '{{ current_user_id() }}' as id, + '{{ current_username() }}' as username, + '{{ user_email }}' as email + """, + {1, "abc", "abc@test.com"}, + True, + ), + ( + "test_has_extra_cache_keys_table_with_se_multiple", + """ + {% set user_conditional_id = current_user_email() and current_user_id() %} + SELECT + '{{ user_conditional_id }}' as conditional + """, + {1, "abc@test.com"}, + True, + ), + ( + "test_has_extra_cache_keys_disabled_table", + """ + SELECT + '{{ current_user_id(False) }}' as id, + '{{ current_username(False) }}' as username, + '{{ current_user_email(False) }}' as email + """, + [], + True, + ), + ("test_has_no_extra_cache_keys_table", "SELECT 'abc' as user", [], False), + ], +) +@patch("superset.jinja_context.get_user_id", return_value=1) +@patch("superset.jinja_context.get_username", return_value="abc") +@patch("superset.jinja_context.get_user_email", return_value="abc@test.com") +def test_extra_cache_keys( + mock_user_email, + mock_username, + mock_user_id, + table_name, + sql, + expected_cache_keys, + has_extra_cache_keys, +): + table = SqlaTable( + table_name=table_name, + sql=sql, + database=get_example_database(), + ) + base_query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": ["id", "username", "email"], + "metrics": [], + "is_timeseries": False, + "filter": [], + } + + query_obj = dict(**base_query_obj, extras={}) + + extra_cache_keys = table.get_extra_cache_keys(query_obj) + assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys + assert set(extra_cache_keys) == set(expected_cache_keys) + + +@pytest.mark.usefixtures("app_context") +@pytest.mark.parametrize( + "sql_expression,expected_cache_keys,has_extra_cache_keys", + [ + ("(user != '{{ current_username() }}')", ["abc"], True), + ("(user != 'abc')", [], False), + ], +) +@patch("superset.jinja_context.get_user_id", return_value=1) +@patch("superset.jinja_context.get_username", return_value="abc") +@patch("superset.jinja_context.get_user_email", return_value="abc@test.com") +def test_extra_cache_keys_in_sql_expression( + mock_user_email, + mock_username, + mock_user_id, + sql_expression, + expected_cache_keys, + has_extra_cache_keys, +): + table = SqlaTable( + table_name="test_has_no_extra_cache_keys_table", + sql="SELECT 'abc' as user", + database=get_example_database(), + ) + base_query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": ["id", "username", "email"], + "metrics": [], + "is_timeseries": False, + "filter": [], + } + + query_obj = dict(**base_query_obj, extras={"where": sql_expression}) + + extra_cache_keys = table.get_extra_cache_keys(query_obj) + assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys + assert extra_cache_keys == expected_cache_keys + + @pytest.mark.usefixtures("app_context") @pytest.mark.parametrize( "row,dimension,result",