Skip to content

Commit

Permalink
support getitem when index is a all-false bool tensor (#41297)
Browse files Browse the repository at this point in the history
* support getitem when index is a all-false bool tensor

* use cond to replace if

* add static_graph geitem unit test when index is a bool tensor
  • Loading branch information
FlyingQianMM authored Apr 4, 2022
1 parent 3e9ad09 commit eb6d7da
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 21 deletions.
11 changes: 6 additions & 5 deletions python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,24 +795,25 @@ def _test_bool_index(self):
np_value = np.random.random(shape).astype('float32')
var_tensor = paddle.to_tensor(np_value)
index = [[True, True, True, True], [True, False, True, True],
[True, False, False, True], [False, 0, 1, True, True]]
[True, False, False, True], [False, 0, 1, True, True],
[False, False, False, False]]
index2d = np.array([[True, True], [False, False], [True, False],
[True, True]])
tensor_index = paddle.to_tensor(index2d)
var = [
var_tensor[index[0]].numpy(),
var_tensor[index[1]].numpy(),
var_tensor[index[2]].numpy(),
var_tensor[index[3]].numpy(),
var_tensor[index[0]].numpy(), var_tensor[index[1]].numpy(),
var_tensor[index[2]].numpy(), var_tensor[index[3]].numpy(),
var_tensor[paddle.to_tensor(index[0])].numpy(),
var_tensor[tensor_index].numpy(),
var_tensor[paddle.to_tensor(index[4])].numpy()
]
self.assertTrue(np.array_equal(var[0], np_value[index[0]]))
self.assertTrue(np.array_equal(var[1], np_value[index[1]]))
self.assertTrue(np.array_equal(var[2], np_value[index[2]]))
self.assertTrue(np.array_equal(var[3], np_value[index[3]]))
self.assertTrue(np.array_equal(var[4], np_value[index[0]]))
self.assertTrue(np.array_equal(var[5], np_value[index2d]))
self.assertTrue(np.array_equal(var[6], np_value[index[4]]))
self.assertTrue(
np.array_equal(var_tensor[var_tensor > 0.67], np_value[np_value >
0.67]))
Expand Down
55 changes: 55 additions & 0 deletions python/paddle/fluid/tests/unittests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,61 @@ def test_dygraph_list_index_muti_dim(self):
y = x[index_t1, index_t2]
self.assertTrue(np.array_equal(y.numpy(), y_np))

def run_getitem_list_index(self, array, index):
x = paddle.static.data(name='x', shape=array.shape, dtype='float32')

y = x[index]
place = paddle.fluid.CPUPlace()

prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)

exe.run(paddle.static.default_startup_program())
fetch_list = [y.name]
array2 = array.copy()

try:
value_np = array2[index]
except:
with self.assertRaises(ValueError):
getitem_pp = exe.run(prog,
feed={x.name: array},
fetch_list=fetch_list)
return
getitem_pp = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list)

print(getitem_pp)
self.assertTrue(
np.array_equal(value_np, getitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(value_np, getitem_pp[0]))

def test_static_graph_getitem_bool_index(self):
paddle.enable_static()

# case 1:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, False, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_getitem_list_index(array, index)

# case 2:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([False, True, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_getitem_list_index(array, index)

# case 3:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, True, True, True])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_getitem_list_index(array, index)

def run_setitem_list_index(self, array, index, value_np):
x = paddle.static.data(name='x', shape=array.shape, dtype='float32')

Expand Down
49 changes: 33 additions & 16 deletions python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,37 @@ def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
attrs[attr_name] = attr


# the item is a tensor of bool
def get_value_for_bool_tensor(var, item):
if len(item.shape) > len(var.shape):
raise IndexError("The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(item.shape)))
for i, dim_len in enumerate(item.shape):
if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "\
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))

def idx_not_empty(var, item):
from .layers.nn import where
from ..tensor import gather_nd

bool_2_idx = where(item == True)
return gather_nd(var, bool_2_idx)

def idx_empty(var):
var_shape = list(var.shape)
var_shape[0] = 0
return paddle.empty(var_shape, dtype=var.dtype)

from .layers.control_flow import cond
return cond(item.any(), lambda: idx_not_empty(var, item),
lambda: idx_empty(var))


def _getitem_impl_(var, item):
"""
Slice the variable.
Expand Down Expand Up @@ -393,24 +424,10 @@ def _getitem_impl_(var, item):
elif isinstance(slice_item, (Variable, core.eager.Tensor)):
if len(item) == 1:

from ..tensor import index_select, gather_nd
from .layers.nn import where
from ..tensor import index_select

if slice_item.dtype == paddle.bool:
if len(slice_item.shape) > len(var.shape):
raise IndexError(
"The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(slice_item.shape)))
for i, dim_len in enumerate(slice_item.shape):
if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "\
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))
bool_2_idx = where(slice_item == True)
return gather_nd(var, bool_2_idx)
return get_value_for_bool_tensor(var, slice_item)
else:
if len(slice_item.shape) == 1:
return index_select(var, index=slice_item, axis=0)
Expand Down

0 comments on commit eb6d7da

Please sign in to comment.