diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 230066487..441ea4362 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -434,6 +434,20 @@ def _zip_indices(self, start_index, arrays): return output_arr + def copy_store(self, store): + store_to_copy = DeferredArray( + self.runtime, + base=store, + dtype=self.dtype, + ) + store_copy = self.runtime.create_empty_thunk( + store_to_copy.shape, + self.dtype, + inputs=[store_to_copy], + ) + store_copy.copy(store_to_copy, deep=True) + return store_copy, store_copy.base + def _create_indexing_array(self, key, is_set=False): store = self.base rhs = self @@ -514,26 +528,15 @@ def _create_indexing_array(self, key, is_set=False): "Unsupported entry type passed to advanced", "indexing operation", ) - - if copy_needed: + if copy_needed or (not store._transform.bottom): # after store is transformed we need to to return a copy of # the store since Copy operation can't be done on # the store with transformation - store_to_copy = DeferredArray( - self.runtime, - base=store, - dtype=self.dtype, - ) - store_copy = self.runtime.create_empty_thunk( - store_to_copy.shape, - self.dtype, - inputs=[store_to_copy], - ) - store_copy.copy(store_to_copy, deep=True) - rhs = store_copy - store = store_copy.base + rhs, store = self.copy_store(store) else: assert isinstance(key, NumPyThunk) + if not store._transform.bottom: + rhs, store = self.copy_store(store) # the use case when index array ndim >1 and input array ndim ==1 if key.ndim > store.ndim: if store.ndim != 1: @@ -544,23 +547,23 @@ def _create_indexing_array(self, key, is_set=False): # Handle the boolean array case if key.dtype == np.bool: - if key.shape == self.shape: - out_dtype = self.dtype + if key.shape == rhs.shape: + out_dtype = rhs.dtype if is_set: - N = self.ndim - out_dtype = self.runtime.get_point_type(N) + N = rhs.ndim + out_dtype = rhs.runtime.get_point_type(N) - out = self.runtime.create_unbound_thunk(out_dtype) - task = self.context.create_task( + out = rhs.runtime.create_unbound_thunk(out_dtype) + task = rhs.context.create_task( CuNumericOpCode.ADVANCED_INDEXING ) task.add_output(out.base) - task.add_input(self.base) + task.add_input(rhs.base) task.add_input(key.base) task.add_scalar_arg(is_set, bool) - task.add_alignment(self.base, key.base) + task.add_alignment(rhs.base, key.base) task.add_broadcast( - self.base, axes=tuple(range(1, len(self.shape))) + rhs.base, axes=tuple(range(1, len(rhs.shape))) ) task.add_broadcast( key.base, axes=tuple(range(1, len(key.shape))) @@ -575,7 +578,7 @@ def _create_indexing_array(self, key, is_set=False): output_arr = rhs._zip_indices(start_index, (key,)) return True, store, output_arr else: - tuple_of_arrays = (self.runtime.to_deferred_array(key),) + tuple_of_arrays = (rhs.runtime.to_deferred_array(key),) if len(tuple_of_arrays) > rhs.ndim: raise TypeError("Advanced indexing dimension mismatch") diff --git a/tests/index_routines.py b/tests/index_routines.py index a3c0231a2..033e9bdd3 100644 --- a/tests/index_routines.py +++ b/tests/index_routines.py @@ -283,6 +283,16 @@ def advanced_indexing(): x_num[indx0_num, indx1_num] = 2.0 assert np.array_equal(x, x_num) + # use case when advanced indexing is called on a transformed array: + print("advanced indexing test 11") + z = z[:, 1:] + z_num = z_num[:, 1:] + indx = np.array([1, 1]) + indx_num = num.array(indx) + res = z[indx] + res_num = z_num[indx_num] + assert np.array_equal(res, res_num) + # we do less than LEGATE_MAX_DIM becasue the dimension will be increased by # 1 when passig 2d index array for ndim in range(2, LEGATE_MAX_DIM):