Skip to content

Commit

Permalink
[Dy2St] Fix bug with potential security vulnerabilities (#60100)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 authored Dec 18, 2023
1 parent e8ee704 commit fbba94f
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 194 deletions.
1 change: 0 additions & 1 deletion python/paddle/jit/dy2static/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
convert_logical_or as Or,
convert_pop as Pop,
convert_shape as Shape,
convert_shape_compare,
convert_var_dtype as AsDtype,
convert_while_loop as While,
indexable as Indexable,
Expand Down
71 changes: 0 additions & 71 deletions python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,77 +693,6 @@ def has_negative(list_shape):
return x.shape


def convert_shape_compare(left, *args):
    """
    A function handles comparison difference between Paddle and Python.
    For example, if x and y are Tensors, x.shape == y.shape will return single
    boolean Value (True/False). However, paddle.shape(x) == paddle.shape(y) is
    an element-wise comparison. The difference can cause dy2stat error. So we
    create this function to handle the difference.
    Args:
        left: variable
        *args: compare_op(str), variable, compare_op(str), variable, where
            compare_op means "<", ">", "==", "!=", etc.
    Returns:
        If the variables to compare are NOT Paddle Variables, we will return as
        Python like "a op1 b and b op2 c and ... ".
        If the variables to compare are Paddle Variables, we will do elementwise
        comparison first and then reduce to a boolean whose numel is 1.
    Raises:
        ValueError: if an operator string is not one of the supported
            comparison/identity/membership operators.
    """
    import operator

    # Security fix: dispatch through a fixed table of operator functions
    # instead of eval()-ing the op string. convert_shape_compare is invoked
    # on strings produced from transformed user code, and eval() would let a
    # crafted op string execute arbitrary expressions.
    op_funcs = {
        "<": operator.lt,
        "<=": operator.le,
        ">": operator.gt,
        ">=": operator.ge,
        "==": operator.eq,
        "!=": operator.ne,
        "is": operator.is_,
        "is not": operator.is_not,
        "in": lambda x, y: x in y,
        "not in": lambda x, y: x not in y,
    }

    def apply_op(x, op_str, y):
        # Reject unknown operators explicitly rather than eval-ing them.
        if op_str not in op_funcs:
            raise ValueError(
                f"Unsupported comparison operator: {op_str!r}"
            )
        return op_funcs[op_str](x, y)

    args_len = len(args)
    assert (
        args_len >= 2
    ), "convert_shape_compare needs at least one right compare variable"
    assert (
        args_len % 2 == 0
    ), "Illegal input for convert_shape_compare, *args should be op(str), var, op(str), var ..."
    num_cmp = args_len // 2
    if isinstance(left, (Variable, Value)):

        def reduce_compare(x, op_str, y):
            element_wise_result = apply_op(x, op_str, y)
            if op_str == "!=":
                # Shapes differ when ANY element differs.
                return paddle.any(element_wise_result)
            elif (
                op_str == "is"
                or op_str == "is not"
                or op_str == "in"
                or op_str == "not in"
            ):
                # Identity/membership already yield a single boolean result.
                return element_wise_result
            else:
                # Shapes are comparable only when ALL elements satisfy op.
                return paddle.all(element_wise_result)

        final_result = reduce_compare(left, args[0], args[1])
        for i in range(1, num_cmp):
            cmp_left = args[i * 2 - 1]
            cmp_op = args[i * 2]
            cmp_right = args[i * 2 + 1]
            cur_result = reduce_compare(cmp_left, cmp_op, cmp_right)
            final_result = convert_logical_and(
                lambda: final_result, lambda: cur_result
            )
        return final_result
    else:
        # Pure-Python path: emulate chained comparison "a op1 b op2 c ..."
        # with short-circuiting, so later operands (which may be expensive
        # or deliberately failing callables) are never evaluated once the
        # chain is already known to be False.
        cmp_left = left
        final_result = None
        for i in range(num_cmp):
            cmp_op = args[i * 2]
            cmp_right = args[i * 2 + 1]
            cur_result = apply_op(cmp_left, cmp_op, cmp_right)
            if final_result is None:
                final_result = cur_result
            else:
                final_result = final_result and cur_result

            if final_result is False:
                return False
            cmp_left = cmp_right
        return final_result


def cast_bool_if_necessary(var):
assert isinstance(var, (Variable, Value))
if convert_dtype(var.dtype) not in ['bool']:
Expand Down
122 changes: 0 additions & 122 deletions test/dygraph_to_static/test_convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,128 +71,6 @@ def callable_list(x, y):
self.assertEqual(paddle.jit.to_static(callable_list)(1, 2), 3)


class TestConvertShapeCompare(Dy2StTestBase):
    """Tests for ``paddle.jit.dy2static.convert_shape_compare``.

    Covers two regimes: plain-Python operands (result is an ordinary bool,
    evaluated with chained-comparison short-circuit semantics) and static
    graph ``Variable`` operands (result is an element-wise comparison reduced
    to a single boolean tensor).
    """

    @test_legacy_and_pt_and_pir
    def test_non_variable(self):
        # Simple chained comparisons over plain ints behave like Python's
        # native "a op1 b op2 c" chaining.
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True
        )
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "<=", 3), True
        )
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, ">", 2, "<=", 3),
            False,
        )

        def error_func():
            """
            Function used to test that comparison doesn't run after first False
            """
            raise ValueError("Used for test")

        # Short-circuit check: "1 > 2" is False, so the lambda operand must
        # never be compared (calling error_func would raise ValueError).
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(
                1, ">", 2, "<=", lambda: error_func()
            ),
            False,
        )

        # Membership operators ("in" / "not in") mixed into the chain.
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(
                1, "<", 2, "in", [1, 2, 3]
            ),
            True,
        )
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(
                1, "<", 2, "not in", [1, 2, 3]
            ),
            False,
        )
        # Identity operators ("is" / "is not") mixed into the chain.
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is", 3),
            False,
        )
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(
                1, "<", 2, "is not", [1, 2, 3]
            ),
            True,
        )

        # List operands: equality/inequality chained over Python lists.
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(
                [1, 2], "==", [1, 2], "!=", [1, 2, 3]
            ),
            True,
        )
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(
                [1, 2], "!=", [1, 2, 3], "==", [1, 2]
            ),
            False,
        )

    def test_variable(self):
        # Static-graph regime: build comparisons inside a program guard, then
        # execute them to verify the element-wise-then-reduce semantics.
        paddle.enable_static()
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
            y = paddle.static.data(name='y', shape=[3, 2], dtype='float32')
            # "is" / "is not" on Variables resolve at graph-build time to
            # plain Python bools (identity, not element-wise comparison).
            self.assertEqual(
                paddle.jit.dy2static.convert_shape_compare(
                    x, "is", x, "is not", y
                ),
                True,
            )
            self.assertEqual(
                paddle.jit.dy2static.convert_shape_compare(
                    x, "is not", x, "is not", y
                ),
                False,
            )
            self.assertEqual(
                paddle.jit.dy2static.convert_shape_compare(x, "is", x, "is", y),
                False,
            )

            # "==" reduces with ALL, "!=" reduces with ANY; the chained form
            # combines both with a logical AND.
            eq_out = paddle.jit.dy2static.convert_shape_compare(x, "==", y)
            not_eq_out = paddle.jit.dy2static.convert_shape_compare(x, "!=", y)
            long_eq_out = paddle.jit.dy2static.convert_shape_compare(
                x, "==", x, "!=", y
            )

            place = (
                paddle.CUDAPlace(0)
                if paddle.is_compiled_with_cuda()
                else paddle.CPUPlace()
            )
            exe = paddle.static.Executor(place)
            # Identical inputs: x == y everywhere, so eq is True, not-eq and
            # the "== then !=" chain are False.
            x_y_eq_out = exe.run(
                feed={
                    "x": np.ones([3, 2]).astype(np.float32),
                    "y": np.ones([3, 2]).astype(np.float32),
                },
                fetch_list=[eq_out, not_eq_out, long_eq_out],
            )
            np.testing.assert_array_equal(
                np.array(x_y_eq_out), np.array([True, False, False])
            )

            # Perturb a single element: one differing element makes "=="
            # (ALL-reduced) False and "!=" (ANY-reduced) True.
            set_a_zero = np.ones([3, 2]).astype(np.float32)
            set_a_zero[0][0] = 0.0
            x_y_not_eq_out = exe.run(
                feed={"x": np.ones([3, 2]).astype(np.float32), "y": set_a_zero},
                fetch_list=[eq_out, not_eq_out, long_eq_out],
            )
            np.testing.assert_array_equal(
                np.array(x_y_not_eq_out), np.array([False, True, True])
            )
        paddle.disable_static()


class ShapeLayer(paddle.nn.Layer):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit fbba94f

Please sign in to comment.