Skip to content

Commit

Permalink
fixing logic for transpose operation in advanced indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
ipdemes committed Apr 5, 2022
1 parent dc287a6 commit f3012c0
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
57 changes: 32 additions & 25 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,15 +452,43 @@ def _create_indexing_array(self, key, is_set=False):
# the transformation, we need to return a copy
copy_needed = False
tuple_of_arrays = ()
index_map = []

# First, we need to check if transpose is needed
for dim, k in enumerate(key):
if np.isscalar(k) or isinstance(k, NumPyThunk):
if start_index == -1:
start_index = dim
transpose_indices += (dim,)
transpose_needed = transpose_needed or (
(dim - last_index) > 1
)
last_index = dim

if transpose_needed:
copy_needed = True
start_index = 0
post_indices = tuple(
i for i in range(store.ndim) if i not in transpose_indices
)
transpose_indices += post_indices
store = store.transpose(transpose_indices)
index_map = list(transpose_indices)
count = 0
for i in transpose_indices:
index_map[i] = count
count += 1
else:
index_map = tuple(range(len(key)))

for d, k in enumerate(key):
dim = index_map[d]
if np.isscalar(k):
if k < 0:
k += store.shape[dim + shift]
store = store.project(dim + shift, k)
shift -= 1
copy_needed = True
last_index = dim + shift
elif k is np.newaxis:
store = store.promote(dim + shift, 1)
copy_needed = True
Expand All @@ -469,22 +497,8 @@ def _create_indexing_array(self, key, is_set=False):
if k != slice(None):
copy_needed = True
elif isinstance(k, NumPyThunk):
# the very first time we get cunumeric array, record
# start_index
if start_index == -1:
start_index = dim + shift
if (start_index - last_index) > 1:
transpose_needed = True
last_index = dim + shift
transpose_indices += (dim + shift,)
else:
transpose_needed = transpose_needed or (
(dim + shift - last_index) > 1
)
transpose_indices += (dim + shift,)
last_index = dim + shift
if k.dtype == np.bool:
if k.shape[0] != self.shape[dim]:
if k.shape[0] != store.shape[dim]:
raise ValueError(
"boolean index did not match "
"indexed array along dimension "
Expand All @@ -500,14 +514,7 @@ def _create_indexing_array(self, key, is_set=False):
"Unsupported entry type passed to advanced",
"indexing operation",
)
if transpose_needed:
copy_needed = True
start_index = 0
post_indices = tuple(
i for i in range(store.ndim) if i not in transpose_indices
)
transpose_indices += post_indices
store = store.transpose(transpose_indices)

if copy_needed:
# after store is transformed we need to to return a copy of
# the store since Copy operation can't be done on
Expand Down Expand Up @@ -565,7 +572,7 @@ def _create_indexing_array(self, key, is_set=False):
# output regions when ND output regions are available
tuple_of_arrays = key.nonzero()
elif key.ndim < store.ndim:
output_arr = self._zip_indices(start_index, (key,))
output_arr = rhs._zip_indices(start_index, (key,))
return True, store, output_arr
else:
tuple_of_arrays = (self.runtime.to_deferred_array(key),)
Expand Down
8 changes: 8 additions & 0 deletions tests/index_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ def advanced_indexing():
res_num = x_num[..., [0, 1], 2]
assert np.array_equal(res, res_num)

res = x[:, [0, 1], :, -1]
res_num = x_num[:, [0, 1], :, -1]
assert np.array_equal(res, res_num)

res = x[:, [0, 1], :, 1:]
res_num = x_num[:, [0, 1], :, 1:]
assert np.array_equal(res, res_num)

# In-Place & Augmented Assignments via Advanced Indexing
# simple 1d case
y = np.array([0, -1, -2, -3, -4, -5])
Expand Down

0 comments on commit f3012c0

Please sign in to comment.