Skip to content

Commit

Permalink
Support passing custom filters with the same name as built-in flags
Browse files Browse the repository at this point in the history
  • Loading branch information
cocolato committed Oct 24, 2024
1 parent ee988d2 commit c5aafa1
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 12 deletions.
31 changes: 27 additions & 4 deletions mako/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mako import filters
from mako import parsetree
from mako import util
from mako.filters import CONFLICT_PREFIX
from mako.pygen import PythonPrinter


Expand Down Expand Up @@ -522,6 +523,8 @@ def write_variable_declares(self, identifiers, toplevel=False, limit=None):
self.printer.writeline("loop = __M_loop = runtime.LoopStack()")

for ident in to_write:
if ident.startswith(CONFLICT_PREFIX):
ident = ident.replace(CONFLICT_PREFIX, "")
if ident in comp_idents:
comp = comp_idents[ident]
if comp.is_block:
Expand Down Expand Up @@ -785,16 +788,36 @@ def locate_encode(name):
else:
return filters.DEFAULT_ESCAPES.get(name, name)

if "n" not in args:
filter_args = []
conflict_n = "%sn" % CONFLICT_PREFIX
if conflict_n not in args:
if is_expression:
if self.compiler.pagetag:
args = self.compiler.pagetag.filter_args.args + args
if self.compiler.default_filters and "n" not in args:
filter_args = self.compiler.pagetag.filter_args.args
if self.compiler.default_filters and conflict_n not in args:
args = self.compiler.default_filters + args
for e in args:
# if filter given as a function, get just the identifier portion
if e == "n":
if e == conflict_n:
continue
if e.startswith(CONFLICT_PREFIX):
if e not in filter_args:
ident = e.replace(CONFLICT_PREFIX, "")
m = re.match(r"(.+?)(\(.*\))", e)
if m:
target = "%s(%s)" % (ident, target)
continue
target = "%s(%s) if %s is not UNDEFINED else %s(%s)" % (
ident,
target,
ident,
locate_encode(ident),
target,
)
continue
e = e.replace(CONFLICT_PREFIX, "")

# if filter given as a function, get just the identifier portion
m = re.match(r"(.+?)(\(.*\))", e)
if m:
ident, fargs = m.group(1, 2)
Expand Down
2 changes: 2 additions & 0 deletions mako/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,5 @@ def htmlentityreplace_errors(ex):
"str": "str",
"n": "n",
}

CONFLICT_PREFIX = "__ALIAS_"
21 changes: 19 additions & 2 deletions mako/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from mako import compat
from mako import exceptions
from mako import util
from mako.filters import CONFLICT_PREFIX
from mako.filters import DEFAULT_ESCAPES

# words that cannot be assigned to (notably
# smaller than the total keys in __builtins__)
Expand Down Expand Up @@ -196,9 +198,24 @@ def visit_Tuple(self, node):
p.declared_identifiers
)
lui = self.listener.undeclared_identifiers
self.listener.undeclared_identifiers = lui.union(
p.undeclared_identifiers
# self.listener.undeclared_identifiers = lui.union(
# p.undeclared_identifiers
# )
undeclared_identifiers = lui.union(p.undeclared_identifiers)
conflict_identifiers = undeclared_identifiers.intersection(
DEFAULT_ESCAPES
)
if conflict_identifiers:
_map = {i: CONFLICT_PREFIX + i for i in conflict_identifiers}
# for k, v in _map.items():
for i, arg in enumerate(self.listener.args):
if arg in _map:
self.listener.args[i] = _map[arg]
self.listener.undeclared_identifiers = (
undeclared_identifiers ^ conflict_identifiers
).union(_map.values())
else:
self.listener.undeclared_identifiers = undeclared_identifiers


class ParseFunc(_ast_util.NodeVisitor):
Expand Down
16 changes: 11 additions & 5 deletions test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,22 @@ def test_python_fragment(self):

def test_argument_list(self):
parsed = ast.ArgumentList(
"3, 5, 'hi', x+5, " "context.get('lala')", **exception_kwargs
"3, 5, 'hi', g+5, " "context.get('lala')", **exception_kwargs
)
eq_(parsed.undeclared_identifiers, {"x", "context"})
eq_(parsed.undeclared_identifiers, {"g", "context"})
eq_(
[x for x in parsed.args],
["3", "5", "'hi'", "(x + 5)", "context.get('lala')"],
["3", "5", "'hi'", "(g + 5)", "context.get('lala')"],
)

parsed = ast.ArgumentList("h", **exception_kwargs)
eq_(parsed.args, ["h"])
parsed = ast.ArgumentList("m", **exception_kwargs)
eq_(parsed.args, ["m"])

def test_conflict_argument_list(self):
parsed = ast.ArgumentList(
"3, 5, 'hi', n+5, " "context.get('lala')", **exception_kwargs
)
eq_(parsed.undeclared_identifiers, {"__ALIAS_n", "context"})

def test_function_decl(self):
"""test getting the arguments from a function"""
Expand Down
33 changes: 33 additions & 0 deletions test/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,36 @@ def test_capture_ccall(self):

# print t.render()
assert flatten_result(t.render()) == "this is foo. body: ccall body"

def test_conflict_filter_ident(self):
class h(object):
foo = str

t = Template(
"""
X:
${"asdf" | h.foo}
"""
)
assert flatten_result(t.render(h=h)) == "X: asdf"

def h(i):
return str(i) + "1"

t = Template(
"""
${123 | h}
"""
)
assert flatten_result(t.render()) == "123"
assert flatten_result(t.render(h=h)) == "1231"

t = Template(
"""
<%def name="foo()" filter="h">
this is foo</%def>
${foo()}
"""
)
assert flatten_result(t.render()) == "this is foo"
assert flatten_result(t.render(h=h)) == "this is foo1"
2 changes: 1 addition & 1 deletion test/test_lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,7 +1327,7 @@ def test_integration(self):
Text(" <tr>\n", (14, 1)),
ControlLine("for", "for x in j:", False, (15, 1)),
Text(" <td>Hello ", (16, 1)),
Expression("x", ["h"], (16, 23)),
Expression("x", ["__ALIAS_h"], (16, 23)),
Text("</td>\n", (16, 30)),
ControlLine("for", "endfor", True, (17, 1)),
Text(" </tr>\n", (18, 1)),
Expand Down

0 comments on commit c5aafa1

Please sign in to comment.