Skip to content

Commit

Permalink
Add in protections against call to eval(expression)
Browse files Browse the repository at this point in the history
  • Loading branch information
robbmcleod committed Jul 22, 2023
1 parent 74d5973 commit 4b2d89c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
26 changes: 16 additions & 10 deletions numexpr/necompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sys
import numpy
import threading
import re

is_cpu_amd_intel = False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE
from numexpr import interpreter, expressions, use_vml
Expand Down Expand Up @@ -259,10 +260,17 @@ def __init__(self, astnode):
def __str__(self):
    """Return a readable form ``Immediate(<value>)``.

    The wrapped node's ``value`` is rendered with the ``%d`` format.
    """
    return 'Immediate(%d)' % (self.node.value,)


# Matches any token that must be rejected in an expression string:
# `;`, `:`, `[`, or a double underscore `__`.
# A raw string is used so the pattern contains no invalid string escapes
# (a plain '\;' or '\:' triggers a warning on recent Python versions), and
# `[` is escaped explicitly so it is unambiguous inside the character class.
_forbidden_re = re.compile(r'[;:\[]|__')
def stringToExpression(s, types, context):
"""Given a string, convert it to a tree of ExpressionNode's.
"""
# sanitize the string for obvious attack vectors that NumExpr cannot
# parse into its homebrew AST. This is to protect the call to `eval` below.
# We forbid `;`, `:`, `[` and `__`.
# We would like to forbid `.` but it is both a reference and decimal point.
if _forbidden_re.search(s) is not None:
raise ValueError(f'Expression {s} has forbidden control characters.')

old_ctx = expressions._context.get_current_context()
try:
expressions._context.set_new_context(context)
Expand All @@ -285,8 +293,10 @@ def stringToExpression(s, types, context):
t = types.get(name, default_type)
names[name] = expressions.VariableNode(name, type_to_kind[t])
names.update(expressions.functions)

# now build the expression
ex = eval(c, names)

if expressions.isConstant(ex):
ex = expressions.ConstantNode(ex, expressions.getKind(ex))
elif not isinstance(ex, expressions.ExpressionNode):
Expand Down Expand Up @@ -611,9 +621,7 @@ def NumExpr(ex, signature=(), **kwargs):
Returns a `NumExpr` object containing the compiled function.
"""
# NumExpr can be called either directly by the end-user, in which case
# kwargs need to be sanitized by getContext, or by evaluate,
# in which case kwargs are already sanitized.

# In that case _frame_depth is wrong (it should be 2) but it doesn't matter
# since it will not be used (because truediv='auto' has already been
# translated to either True or False).
Expand Down Expand Up @@ -758,7 +766,7 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2):
_names_cache = CacheDict(256)
_numexpr_cache = CacheDict(256)
_numexpr_last = {}

_numexpr_sanity = set()
evaluate_lock = threading.Lock()

# MAYBE: decorate this function to add attributes instead of having the
Expand Down Expand Up @@ -861,7 +869,7 @@ def evaluate(ex: str,
out: numpy.ndarray = None,
order: str = 'K',
casting: str = 'safe',
_frame_depth: int=3,
_frame_depth: int = 3,
**kwargs) -> numpy.ndarray:
"""
Evaluate a simple array expression element-wise using the virtual machine.
Expand Down Expand Up @@ -909,6 +917,8 @@ def evaluate(ex: str,
_frame_depth: int
The calling frame depth. Unless you are a NumExpr developer you should
not set this value.
"""
# We could avoid code duplication if we called validate and then re_evaluate
# here, but then we have difficulties with the `sys.getframe(2)` call in
Expand All @@ -921,10 +931,6 @@ def evaluate(ex: str,
else:
raise e





def re_evaluate(local_dict: Optional[Dict] = None,
_frame_depth: int=2) -> numpy.ndarray:
"""
Expand Down
50 changes: 45 additions & 5 deletions numexpr/tests/test_numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,9 @@ def test_re_evaluate_dict(self):
a1 = array([1., 2., 3.])
b1 = array([4., 5., 6.])
c1 = array([7., 8., 9.])
x = evaluate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1})
x = re_evaluate()
local_dict={'a': a1, 'b': b1, 'c': c1}
x = evaluate("2*a + 3*b*c", local_dict=local_dict)
x = re_evaluate(local_dict=local_dict)
assert_array_equal(x, array([86., 124., 168.]))

def test_validate(self):
Expand All @@ -400,9 +401,10 @@ def test_validate_dict(self):
a1 = array([1., 2., 3.])
b1 = array([4., 5., 6.])
c1 = array([7., 8., 9.])
retval = validate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1})
local_dict={'a': a1, 'b': b1, 'c': c1}
retval = validate("2*a + 3*b*c", local_dict=local_dict)
assert(retval is None)
x = re_evaluate()
x = re_evaluate(local_dict=local_dict)
assert_array_equal(x, array([86., 124., 168.]))

# Test for issue #22
Expand Down Expand Up @@ -502,11 +504,49 @@ def test_illegal_value(self):
a = arange(3)
try:
evaluate("a < [0, 0, 0]")
except TypeError:
except (ValueError, TypeError):
pass
else:
self.fail()

def test_forbidden_tokens(self):
    """Check that `evaluate` rejects expressions containing forbidden tokens.

    Every expression listed below is expected to raise `ValueError`.
    """
    forbidden_cases = (
        # Forbid dunder
        ('dunder', '__builtins__'),
        # Forbid colon for lambda funcs
        ('colon', 'lambda x: x'),
        # Forbid indexing
        ('indexing', 'locals()[]'),
        # Forbid semicolon
        ('semicolon', 'import os; os.cpu_count()'),
    )
    for label, expression in forbidden_cases:
        # subTest reports each failing case on its own instead of stopping
        # at the first one.
        with self.subTest(token=label):
            with self.assertRaises(ValueError):
                evaluate(expression)

# I struggle to come up with cases for our ban on `'` and `"`




def test_disassemble(self):
assert_equal(disassemble(NumExpr(
"where(m, a, -1)", [('m', bool), ('a', float)])),
Expand Down

0 comments on commit 4b2d89c

Please sign in to comment.