Skip to content

Commit

Permalink
fixing an issue when advanced indexing operation is performed on tran…
Browse files Browse the repository at this point in the history
…sformed store
  • Loading branch information
ipdemes committed Apr 6, 2022
1 parent f3012c0 commit 61e6847
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 25 deletions.
53 changes: 28 additions & 25 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,20 @@ def _zip_indices(self, start_index, arrays):

return output_arr

def copy_store(self, store):
store_to_copy = DeferredArray(
self.runtime,
base=store,
dtype=self.dtype,
)
store_copy = self.runtime.create_empty_thunk(
store_to_copy.shape,
self.dtype,
inputs=[store_to_copy],
)
store_copy.copy(store_to_copy, deep=True)
return store_copy, store_copy.base

def _create_indexing_array(self, key, is_set=False):
store = self.base
rhs = self
Expand Down Expand Up @@ -514,26 +528,15 @@ def _create_indexing_array(self, key, is_set=False):
"Unsupported entry type passed to advanced",
"indexing operation",
)

if copy_needed:
if copy_needed or (not store._transform.bottom):
# after store is transformed we need to to return a copy of
# the store since Copy operation can't be done on
# the store with transformation
store_to_copy = DeferredArray(
self.runtime,
base=store,
dtype=self.dtype,
)
store_copy = self.runtime.create_empty_thunk(
store_to_copy.shape,
self.dtype,
inputs=[store_to_copy],
)
store_copy.copy(store_to_copy, deep=True)
rhs = store_copy
store = store_copy.base
rhs, store = self.copy_store(store)
else:
assert isinstance(key, NumPyThunk)
if not store._transform.bottom:
rhs, store = self.copy_store(store)
# the use case when index array ndim >1 and input array ndim ==1
if key.ndim > store.ndim:
if store.ndim != 1:
Expand All @@ -544,23 +547,23 @@ def _create_indexing_array(self, key, is_set=False):

# Handle the boolean array case
if key.dtype == np.bool:
if key.shape == self.shape:
out_dtype = self.dtype
if key.shape == rhs.shape:
out_dtype = rhs.dtype
if is_set:
N = self.ndim
out_dtype = self.runtime.get_point_type(N)
N = rhs.ndim
out_dtype = rhs.runtime.get_point_type(N)

out = self.runtime.create_unbound_thunk(out_dtype)
task = self.context.create_task(
out = rhs.runtime.create_unbound_thunk(out_dtype)
task = rhs.context.create_task(
CuNumericOpCode.ADVANCED_INDEXING
)
task.add_output(out.base)
task.add_input(self.base)
task.add_input(rhs.base)
task.add_input(key.base)
task.add_scalar_arg(is_set, bool)
task.add_alignment(self.base, key.base)
task.add_alignment(rhs.base, key.base)
task.add_broadcast(
self.base, axes=tuple(range(1, len(self.shape)))
rhs.base, axes=tuple(range(1, len(rhs.shape)))
)
task.add_broadcast(
key.base, axes=tuple(range(1, len(key.shape)))
Expand All @@ -575,7 +578,7 @@ def _create_indexing_array(self, key, is_set=False):
output_arr = rhs._zip_indices(start_index, (key,))
return True, store, output_arr
else:
tuple_of_arrays = (self.runtime.to_deferred_array(key),)
tuple_of_arrays = (rhs.runtime.to_deferred_array(key),)

if len(tuple_of_arrays) > rhs.ndim:
raise TypeError("Advanced indexing dimension mismatch")
Expand Down
10 changes: 10 additions & 0 deletions tests/index_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,16 @@ def advanced_indexing():
x_num[indx0_num, indx1_num] = 2.0
assert np.array_equal(x, x_num)

# use case when advanced indexing is called on a transformed array:
print("advanced indexing test 11")
z = z[:, 1:]
z_num = z_num[:, 1:]
indx = np.array([1, 1])
indx_num = num.array(indx)
res = z[indx]
res_num = z_num[indx_num]
assert np.array_equal(res, res_num)

# we do less than LEGATE_MAX_DIM becasue the dimension will be increased by
# 1 when passig 2d index array
for ndim in range(2, LEGATE_MAX_DIM):
Expand Down

0 comments on commit 61e6847

Please sign in to comment.