Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: DB-specific quoting in Jinja macro #25779

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from jinja2 import DebugUndefined
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.expression import bindparam
from sqlalchemy.types import String

from superset.constants import LRU_CACHE_MAX_SIZE
Expand Down Expand Up @@ -396,23 +397,39 @@ def validate_template_context(
return validate_context_types(context)


def where_in(values: list[Any], mark: str = "'") -> str:
"""
Given a list of values, build a parenthesis list suitable for an IN expression.
class WhereInMacro: # pylint: disable=too-few-public-methods
def __init__(self, dialect: Dialect):
self.dialect = dialect

>>> where_in([1, "b", 3])
(1, 'b', 3)
def __call__(self, values: list[Any], mark: Optional[str] = None) -> str:
"""
Given a list of values, build a parenthesis list suitable for an IN expression.

"""
>>> from sqlalchemy.dialects import mysql
>>> where_in = WhereInMacro(dialect=mysql.dialect())
>>> where_in([1, "Joe's", 3])
(1, 'Joe''s', 3)

def quote(value: Any) -> str:
if isinstance(value, str):
value = value.replace(mark, mark * 2)
return f"{mark}{value}{mark}"
return str(value)
"""
binds = [bindparam(f"value_{i}", value) for i, value in enumerate(values)]
string_representations = [
str(
bind.compile(
dialect=self.dialect, compile_kwargs={"literal_binds": True}
)
)
for bind in binds
]
joined_values = ", ".join(string_representations)
result = f"({joined_values})"

if mark:
result += (
"\n-- WARNING: the `mark` parameter was removed from the `where_in` "
"macro for security reasons\n"
)

joined_values = ", ".join(quote(value) for value in values)
return f"({joined_values})"
return result


class BaseTemplateProcessor:
Expand Down Expand Up @@ -448,7 +465,7 @@ def __init__(
self.set_context(**kwargs)

# custom filters
self._env.filters["where_in"] = where_in
self._env.filters["where_in"] = WhereInMacro(database.get_dialect())

def set_context(self, **kwargs: Any) -> None:
self._context.update(kwargs)
Expand Down
9 changes: 7 additions & 2 deletions tests/unit_tests/jinja_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,22 @@

import pytest
from pytest_mock import MockFixture
from sqlalchemy.dialects import mysql

from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.jinja_context import dataset_macro, where_in
from superset.jinja_context import dataset_macro, WhereInMacro


def test_where_in() -> None:
"""
Test the ``where_in`` Jinja2 filter.
"""
where_in = WhereInMacro(mysql.dialect())
assert where_in([1, "b", 3]) == "(1, 'b', 3)"
assert where_in([1, "b", 3], '"') == '(1, "b", 3)'
assert where_in([1, "b", 3], '"') == (
"(1, 'b', 3)\n-- WARNING: the `mark` parameter was removed from the "
"`where_in` macro for security reasons\n"
)
assert where_in(["O'Malley's"]) == "('O''Malley''s')"


Expand Down
Loading