diff --git a/cunumeric/array.py b/cunumeric/array.py
index 64784d59a..cd14eda7c 100644
--- a/cunumeric/array.py
+++ b/cunumeric/array.py
@@ -920,12 +920,8 @@ def _convert_key(self, key: Any, first: bool = True) -> Any:
         key = convert_to_cunumeric_ndarray(key)
         if key.dtype != bool and not np.issubdtype(key.dtype, np.integer):
             raise TypeError("index arrays should be int or bool type")
-        if key.dtype != bool and key.dtype != np.int64:
-            runtime.warn(
-                "converting index array to int64 type",
-                category=RuntimeWarning,
-            )
-            key = key.astype(np.int64)
+        if key.dtype != bool:
+            key = key._warn_and_convert(np.dtype(np.int64))
 
         return key._thunk
 
@@ -2104,12 +2100,8 @@ def compress(
             raise ValueError(
                 "Dimension mismatch: condition must be a 1D array"
             )
-        if condition.dtype != bool:
-            runtime.warn(
-                "converting condition to bool type",
-                category=RuntimeWarning,
-            )
-            condition = condition.astype(bool)
+
+        condition = condition._warn_and_convert(np.dtype(bool))
 
         if axis is None:
             axis = 0
@@ -2476,6 +2468,62 @@ def diagonal(
             raise ValueError("Either axis1/axis2 or axes must be supplied")
         return self._diag_helper(offset=offset, axes=axes, extract=extract)
 
+    @add_boilerplate("indices", "values")
+    def put(
+        self, indices: ndarray, values: ndarray, mode: str = "raise"
+    ) -> None:
+        """
+        Replaces specified elements of the array with given values.
+
+        Refer to :func:`cunumeric.put` for full documentation.
+
+        See Also
+        --------
+        cunumeric.put : equivalent function
+
+        Availability
+        --------
+        Multiple GPUs, Multiple CPUs
+
+        """
+
+        if values.size == 0 or indices.size == 0 or self.size == 0:
+            return
+
+        if mode not in ("raise", "wrap", "clip"):
+            raise ValueError(
+                "mode must be one of 'clip', 'raise', or 'wrap' "
+                f"(got {mode})"
+            )
+
+        if mode == "wrap":
+            indices = indices % self.size
+        elif mode == "clip":
+            indices = indices.clip(0, self.size - 1)
+
+        indices = indices._warn_and_convert(np.dtype(np.int64))
+        values = values._warn_and_convert(self.dtype)
+
+        if indices.ndim > 1:
+            indices = indices.ravel()
+
+        if self.shape == ():
+            if mode == "raise":
+                if indices.min() < -1 or indices.max() > 0:
+                    raise ValueError("Indices out of bounds")
+            if values.shape == ():
+                v = values
+            else:
+                v = values[0]
+            self._thunk.copy(v._thunk, deep=False)
+            return
+
+        # call _wrap on the values if they need to be wrapped
+        if values.ndim != indices.ndim or values.size != indices.size:
+            values = values._wrap(indices.size)
+
+        self._thunk.put(indices._thunk, values._thunk)
+
     @add_boilerplate()
     def trace(
         self,
@@ -3822,6 +3870,16 @@ def _maybe_convert(self, dtype: np.dtype[Any], hints: Any) -> ndarray:
             copy._thunk.convert(self._thunk)
         return copy
 
+    def _warn_and_convert(self, dtype: np.dtype[Any]) -> ndarray:
+        if self.dtype != dtype:
+            runtime.warn(
+                f"converting array to {dtype} type",
+                category=RuntimeWarning,
+            )
+            return self.astype(dtype)
+        else:
+            return self
+
     # For performing normal/broadcast unary operations
     @classmethod
     def _perform_unary_op(
diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py
index 54f481977..3bb5c4db7 100644
--- a/cunumeric/deferred.py
+++ b/cunumeric/deferred.py
@@ -796,10 +796,16 @@ def _broadcast(self, shape: NdShape) -> Any:
 
         return result
 
-    def _convert_future_to_regionfield(self) -> DeferredArray:
+    def _convert_future_to_regionfield(
+        self, change_shape: bool = False
+    ) -> DeferredArray:
+        if change_shape and self.shape == ():
+            shape: NdShape = (1,)
+        else:
+            shape = self.shape
         store = self.context.create_store(
             self.dtype,
-            shape=self.shape,
+            shape=shape,
             optimize_scalar=False,
         )
         thunk_copy = DeferredArray(
@@ -1679,6 +1685,60 @@ def _diag_helper(
 
         task.execute()
 
+    @auto_convert("indices", "values")
+    def put(self, indices: Any, values: Any) -> None:
+
+        if indices.base.kind == Future or indices.base.transformed:
+            change_shape = indices.base.kind == Future
+            indices = indices._convert_future_to_regionfield(change_shape)
+        if values.base.kind == Future or values.base.transformed:
+            change_shape = values.base.kind == Future
+            values = values._convert_future_to_regionfield(change_shape)
+
+        if self.base.kind == Future or self.base.transformed:
+            change_shape = self.base.kind == Future
+            self_tmp = self._convert_future_to_regionfield(change_shape)
+        else:
+            self_tmp = self
+
+        assert indices.size == values.size
+
+        # first, we create an indirect array with PointN type that
+        # has (indices.size,) shape and is used to copy data from
+        # values to the target ND array (self)
+        N = self_tmp.ndim
+        pointN_dtype = self.runtime.get_point_type(N)
+        indirect = cast(
+            DeferredArray,
+            self.runtime.create_empty_thunk(
+                shape=indices.shape,
+                dtype=pointN_dtype,
+                inputs=[indices],
+            ),
+        )
+
+        shape = self_tmp.shape
+        task = self.context.create_task(CuNumericOpCode.WRAP)
+        task.add_output(indirect.base)
+        task.add_scalar_arg(shape, (ty.int64,))
+        task.add_scalar_arg(True, bool)  # has_input
+        task.add_input(indices.base)
+        task.add_alignment(indices.base, indirect.base)
+        task.throws_exception(IndexError)
+        task.execute()
+        if indirect.base.kind == Future:
+            indirect = indirect._convert_future_to_regionfield()
+
+        copy = self.context.create_copy()
+        copy.set_target_indirect_out_of_range(False)
+        copy.add_input(values.base)
+        copy.add_target_indirect(indirect.base)
+        copy.add_output(self_tmp.base)
+        copy.execute()
+
+        if self_tmp is not self:
+            self.copy(self_tmp, deep=True)
+
     # Create an identity array with the ones offset from the diagonal by k
     def eye(self, k: int) -> None:
         assert self.ndim == 2  # Only 2-D arrays should be here
@@ -2896,6 +2956,7 @@ def unary_op(
         args: Any,
         multiout: Optional[Any] = None,
     ) -> None:
+
         lhs = self.base
         rhs = src._broadcast(lhs.shape)
@@ -3355,7 +3416,8 @@ def unpackbits(
     @auto_convert("src")
     def _wrap(self, src: Any, new_len: int) -> None:
         if src.base.kind == Future or src.base.transformed:
-            src = src._convert_future_to_regionfield()
+            change_shape = src.base.kind == Future
+            src = src._convert_future_to_regionfield(change_shape)
 
         # first, we create an indirect array with PointN type that
         # has (len,) shape and is used to copy data from the original array
@@ -3374,6 +3436,7 @@ def _wrap(self, src: Any, new_len: int) -> None:
         task = self.context.create_task(CuNumericOpCode.WRAP)
         task.add_output(indirect.base)
         task.add_scalar_arg(src.shape, (ty.int64,))
+        task.add_scalar_arg(False, bool)  # has_input
        task.execute()
 
         copy = self.context.create_copy()
diff --git a/cunumeric/eager.py b/cunumeric/eager.py
index fdb8f7989..b8cb36ecd 100644
--- a/cunumeric/eager.py
+++ b/cunumeric/eager.py
@@ -620,6 +620,13 @@ def _diag_helper(
             axes = tuple(range(ndims - naxes, ndims))
         self.array = diagonal_reference(rhs.array, axes)
 
+    def put(self, indices: Any, values: Any) -> None:
+        self.check_eager_args(indices, values)
+        if self.deferred is not None:
+            self.deferred.put(indices, values)
+        else:
+            np.put(self.array, indices.array, values.array)
+
     def eye(self, k: int) -> None:
         if self.deferred is not None:
             self.deferred.eye(k)
diff --git a/cunumeric/module.py b/cunumeric/module.py
index 69647b3cb..0a4e97a5a 100644
--- a/cunumeric/module.py
+++ b/cunumeric/module.py
@@ -2410,12 +2410,7 @@ def repeat(a: ndarray, repeats: Any, axis: Optional[int] = None) -> ndarray:
     # repeats is an array
     else:
         # repeats should be integer type
-        if repeats.dtype != np.int64:
-            runtime.warn(
-                "converting repeats to an integer type",
-                category=RuntimeWarning,
-            )
-            repeats = repeats.astype(np.int64)
+        repeats = repeats._warn_and_convert(np.dtype(np.int64))
         if repeats.shape[0] != array.shape[axis]:
             raise ValueError("incorrect shape of repeats array")
         result = array._thunk.repeat(
@@ -3473,6 +3468,44 @@ def diagonal(
     )
 
 
+@add_boilerplate("a", "indices", "values")
+def put(
+    a: ndarray, indices: ndarray, values: ndarray, mode: str = "raise"
+) -> None:
+    """
+    Replaces specified elements of an array with given values.
+    The indexing works as if the target array is first flattened.
+
+    Parameters
+    ----------
+    a : array_like
+        Array to put data into.
+    indices : array_like
+        Target indices, interpreted as integers.
+        WARNING: When there are repeated entries in the indices
+        array, Legate doesn't guarantee the order in which the
+        values are updated.
+
+    values : array_like
+        Values to place in `a` at target indices. If the values array is
+        shorter than indices, it will be repeated as necessary.
+    mode : {'raise', 'wrap', 'clip'}, optional
+        Specifies how out-of-bounds indices will behave.
+        'raise' : raise an error.
+        'wrap' : wrap around.
+        'clip' : clip to the range.
+
+    See Also
+    --------
+    numpy.put
+
+    Availability
+    --------
+    Multiple GPUs, Multiple CPUs
+    """
+    a.put(indices=indices, values=values, mode=mode)
+
+
 @add_boilerplate("a", "val")
 def fill_diagonal(a: ndarray, val: ndarray, wrap: bool = False) -> None:
     """
diff --git a/cunumeric/thunk.py b/cunumeric/thunk.py
index bdc773aeb..e1f1dab77 100644
--- a/cunumeric/thunk.py
+++ b/cunumeric/thunk.py
@@ -197,6 +197,10 @@ def _diag_helper(
     ) -> None:
         ...
 
+    @abstractmethod
+    def put(self, indices: Any, values: Any) -> None:
+        ...
+
     @abstractmethod
     def eye(self, k: int) -> None:
         ...
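
For reference, a minimal usage sketch of the new API (a hypothetical session, not part of the patch; the expected results assume the NumPy-compatible semantics documented in the `put` docstring above):

    import cunumeric as num

    a = num.arange(5)
    num.put(a, [0, 2], -44)           # a is now [-44, 1, -44, 3, 4]
    num.put(a, [7], 99, mode="wrap")  # 7 % a.size == 2, so a[2] becomes 99
    num.put(a, [7], 55, mode="clip")  # 7 is clipped to 4, so a[4] becomes 55
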
diff --git a/docs/cunumeric/source/api/indexing.rst b/docs/cunumeric/source/api/indexing.rst
index 1023ed1d4..1ace111d4 100644
--- a/docs/cunumeric/source/api/indexing.rst
+++ b/docs/cunumeric/source/api/indexing.rst
@@ -43,5 +43,6 @@ Inserting data into arrays
    :toctree: generated/
 
    fill_diagonal
+   put
    put_along_axis
    place
diff --git a/src/cunumeric/index/wrap.cc b/src/cunumeric/index/wrap.cc
index 33dfcfe4b..a5483cbdd 100644
--- a/src/cunumeric/index/wrap.cc
+++ b/src/cunumeric/index/wrap.cc
@@ -24,28 +24,30 @@ using namespace legate;
 
 template <int DIM>
 struct WrapImplBody<VariantKind::CPU, DIM> {
+  template <typename IND>
   void operator()(const AccessorWO<Point<DIM>, 1>& out,
                   const Pitches<0>& pitches_out,
                   const Rect<1>& out_rect,
                   const Pitches<DIM - 1>& pitches_in,
                   const Rect<DIM>& in_rect,
-                  const bool dense) const
+                  const bool dense,
+                  const IND& indices) const
   {
     const int64_t start  = out_rect.lo[0];
     const int64_t end    = out_rect.hi[0];
     const auto in_volume = in_rect.volume();
     if (dense) {
-      int64_t out_idx = 0;
-      auto outptr     = out.ptr(out_rect);
+      auto outptr = out.ptr(out_rect);
       for (int64_t i = start; i <= end; i++) {
-        const int64_t input_idx = i % in_volume;
+        check_idx(i, in_volume, indices);
+        const int64_t input_idx = compute_idx(i, in_volume, indices);
         auto point = pitches_in.unflatten(input_idx, in_rect.lo);
-        outptr[out_idx] = point;
-        out_idx++;
+        outptr[i - start] = point;
       }
     } else {
       for (int64_t i = start; i <= end; i++) {
-        const int64_t input_idx = i % in_volume;
+        check_idx(i, in_volume, indices);
+        const int64_t input_idx = compute_idx(i, in_volume, indices);
         auto point = pitches_in.unflatten(input_idx, in_rect.lo);
         out[i] = point;
       }
diff --git a/src/cunumeric/index/wrap.cu b/src/cunumeric/index/wrap.cu
index 0f118eadf..af81073d6 100644
--- a/src/cunumeric/index/wrap.cu
+++ b/src/cunumeric/index/wrap.cu
@@ -23,7 +23,28 @@ namespace cunumeric {
 using namespace Legion;
 using namespace legate;
 
-template <int DIM>
+template <typename Output>
+__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
+  check_kernel(Output out,
+               const AccessorRO<int64_t, 1> indices,
+               const int64_t start,
+               const int64_t volume,
+               const int64_t in_volume,
+               const int64_t iters)
+{
+  bool value = false;
+  for (size_t i = 0; i < iters; i++) {
+    const auto idx = (i * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
+    if (idx >= volume) break;
+    auto index_tmp = indices[idx + start];
+    int64_t index  = index_tmp < 0 ? index_tmp + in_volume : index_tmp;
+    bool val       = (index < 0 || index >= in_volume);
+    SumReduction<bool>::fold<true>(value, val);
+  }
+  reduce_output(out, value);
+}
+
+template <int DIM, typename IND>
 __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   wrap_kernel(const AccessorWO<Point<DIM>, 1> out,
               const int64_t start,
@@ -32,53 +53,93 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
               const Point<1> out_lo,
               const Pitches<DIM - 1> pitches_in,
               const Point<DIM> in_lo,
-              const size_t in_volume)
+              const size_t in_volume,
+              const IND indices)
 {
   const auto idx = global_tid_1d();
   if (idx >= volume) return;
-  const int64_t input_idx = (idx + start) % in_volume;
+  const int64_t input_idx = compute_idx((idx + start), in_volume, indices);
   auto out_p = pitches_out.unflatten(idx, out_lo);
   auto p     = pitches_in.unflatten(input_idx, in_lo);
   out[out_p] = p;
 }
 
-template <int DIM>
+template <int DIM, typename IND>
 __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   wrap_kernel_dense(Point<DIM>* out,
                     const int64_t start,
                     const int64_t volume,
                     const Pitches<DIM - 1> pitches_in,
                     const Point<DIM> in_lo,
-                    const size_t in_volume)
+                    const size_t in_volume,
+                    const IND indices)
 {
   const auto idx = global_tid_1d();
   if (idx >= volume) return;
-  const int64_t input_idx = (idx + start) % in_volume;
+  const int64_t input_idx = compute_idx((idx + start), in_volume, indices);
   auto p   = pitches_in.unflatten(input_idx, in_lo);
   out[idx] = p;
 }
 
+// don't do anything when indices is a boolean
+void check_out_of_bounds(const bool& indices,
+                         const int64_t start,
+                         const int64_t volume,
+                         const int64_t volume_in,
+                         cudaStream_t stream)
+{
+}
+
+void check_out_of_bounds(const AccessorRO<int64_t, 1>& indices,
+                         const int64_t start,
+                         const int64_t volume,
+                         const int64_t volume_in,
+                         cudaStream_t stream)
+{
+  const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+  size_t shmem_size   = THREADS_PER_BLOCK / 32 * sizeof(bool);
+  DeviceScalarReductionBuffer<SumReduction<bool>> out_of_bounds(stream);
+
+  if (blocks >= MAX_REDUCTION_CTAS) {
+    const size_t iters = (blocks + MAX_REDUCTION_CTAS - 1) / MAX_REDUCTION_CTAS;
+    check_kernel<<<MAX_REDUCTION_CTAS, THREADS_PER_BLOCK, shmem_size, stream>>>(
+      out_of_bounds, indices, start, volume, volume_in, iters);
+  } else {
+    check_kernel<<<blocks, THREADS_PER_BLOCK, shmem_size, stream>>>(
+      out_of_bounds, indices, start, volume, volume_in, 1);
+  }
+  CHECK_CUDA_STREAM(stream);
+
+  bool res = out_of_bounds.read(stream);
+  if (res) throw legate::TaskException("index is out of bounds in index array");
+}
+
 template <int DIM>
 struct WrapImplBody<VariantKind::GPU, DIM> {
+  template <typename IND>
   void operator()(const AccessorWO<Point<DIM>, 1>& out,
                   const Pitches<0>& pitches_out,
                   const Rect<1>& out_rect,
                   const Pitches<DIM - 1>& pitches_in,
                   const Rect<DIM>& in_rect,
-                  const bool dense) const
+                  const bool dense,
+                  const IND& indices) const
   {
     auto stream = get_cached_stream();
     const int64_t start  = out_rect.lo[0];
     const int64_t volume = out_rect.volume();
     const auto in_volume = in_rect.volume();
     const size_t blocks  = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+
+    check_out_of_bounds(indices, start, volume, in_volume, stream);
+
     if (dense) {
       auto outptr = out.ptr(out_rect);
-      wrap_kernel_dense<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
-        outptr, start, volume, pitches_in, in_rect.lo, in_volume);
+      wrap_kernel_dense<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
+        outptr, start, volume, pitches_in, in_rect.lo, in_volume, indices);
     } else {
-      wrap_kernel<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
-        out, start, volume, pitches_out, out_rect.lo, pitches_in, in_rect.lo, in_volume);
+      wrap_kernel<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
+        out, start, volume, pitches_out, out_rect.lo, pitches_in, in_rect.lo, in_volume, indices);
     }
     CHECK_CUDA_STREAM(stream);
   }
diff --git a/src/cunumeric/index/wrap.h b/src/cunumeric/index/wrap.h
index 91c3f2326..181a9b97c 100644
--- a/src/cunumeric/index/wrap.h
+++ b/src/cunumeric/index/wrap.h
@@ -25,6 +25,8 @@ struct WrapArgs {
   // copy information from original array to the
   // `wrapped` one
   const Legion::DomainPoint shape;  // shape of the original array
+  const bool has_input;
+  const Array& in = Array();
 };
 
 class WrapTask : public CuNumericTask<WrapTask> {
@@ -41,4 +43,31 @@ class WrapTask : public CuNumericTask<WrapTask> {
 #endif
 };
 
+__CUDA_HD__ static int64_t compute_idx(const int64_t i, const int64_t volume, const bool&)
+{
+  return i % volume;
+}
+
+__CUDA_HD__ static int64_t compute_idx(const int64_t i,
+                                       const int64_t volume,
+                                       const legate::AccessorRO<int64_t, 1>& indices)
+{
+  int64_t idx   = indices[i];
+  int64_t index = idx < 0 ? idx + volume : idx;
+  return index;
+}
+
+static void check_idx(const int64_t i,
+                      const int64_t volume,
+                      const legate::AccessorRO<int64_t, 1>& indices)
+{
+  int64_t idx   = indices[i];
+  int64_t index = idx < 0 ? idx + volume : idx;
+  if (index < 0 || index >= volume)
+    throw legate::TaskException("index is out of bounds in index array");
+}
+
+static void check_idx(const int64_t i, const int64_t volume, const bool&)
+{
+  // don't do anything when wrapping indices
+}
+
 }  // namespace cunumeric
diff --git a/src/cunumeric/index/wrap_omp.cc b/src/cunumeric/index/wrap_omp.cc
index f95e9123c..531592df9 100644
--- a/src/cunumeric/index/wrap_omp.cc
+++ b/src/cunumeric/index/wrap_omp.cc
@@ -24,12 +24,14 @@ using namespace legate;
 
 template <int DIM>
 struct WrapImplBody<VariantKind::OMP, DIM> {
+  template <typename IND>
   void operator()(const AccessorWO<Point<DIM>, 1>& out,
                   const Pitches<0>& pitches_out,
                   const Rect<1>& out_rect,
                   const Pitches<DIM - 1>& pitches_in,
                   const Rect<DIM>& in_rect,
-                  const bool dense) const
+                  const bool dense,
+                  const IND& indices) const
   {
     const int64_t start = out_rect.lo[0];
     const int64_t end   = out_rect.hi[0];
@@ -38,14 +40,16 @@ struct WrapImplBody<VariantKind::OMP, DIM> {
       auto outptr = out.ptr(out_rect);
 #pragma omp parallel for schedule(static)
       for (int64_t i = start; i <= end; i++) {
-        const int64_t input_idx = i % in_volume;
+        check_idx(i, in_volume, indices);
+        const int64_t input_idx = compute_idx(i, in_volume, indices);
         auto point        = pitches_in.unflatten(input_idx, in_rect.lo);
         outptr[i - start] = point;
       }
     } else {
 #pragma omp parallel for schedule(static)
       for (int64_t i = start; i <= end; i++) {
-        const int64_t input_idx = i % in_volume;
+        check_idx(i, in_volume, indices);
+        const int64_t input_idx = compute_idx(i, in_volume, indices);
         auto point = pitches_in.unflatten(input_idx, in_rect.lo);
         out[i]     = point;
       }
diff --git a/src/cunumeric/index/wrap_template.inl b/src/cunumeric/index/wrap_template.inl
index 46885f24e..093f5f5b1 100644
--- a/src/cunumeric/index/wrap_template.inl
+++ b/src/cunumeric/index/wrap_template.inl
@@ -60,16 +60,30 @@ struct WrapImpl {
     assert(volume_in != 0);
 #endif
 
-    WrapImplBody<KIND, DIM>()(out, pitches_out, out_rect, pitches_in, input_rect, dense);
+    if (args.has_input) {
+      auto in_rect = args.in.shape<1>();
+      auto in      = args.in.read_accessor<int64_t, 1>(in_rect);  // the input should always be of integer type
+#ifdef DEBUG_CUNUMERIC
+      assert(in_rect == out_rect);
+#endif
+      WrapImplBody<KIND, DIM>()(out, pitches_out, out_rect, pitches_in, input_rect, dense, in);
+
+    } else {
+      bool tmp = false;
+      WrapImplBody<KIND, DIM>()(out, pitches_out, out_rect, pitches_in, input_rect, dense, tmp);
+    }  // else
   }
 };
 
 template <VariantKind KIND>
 static void wrap_template(TaskContext& context)
 {
-  auto shape = context.scalars()[0].value<DomainPoint>();
-  int dim    = shape.dim;
-  WrapArgs args{context.outputs()[0], shape};
+  auto shape      = context.scalars()[0].value<DomainPoint>();
+  int dim         = shape.dim;
+  bool has_input  = context.scalars()[1].value<bool>();
+  Array tmp_array = Array();
+  WrapArgs args{
+    context.outputs()[0], shape, has_input, has_input ? context.inputs()[0] : tmp_array};
   dim_dispatch(dim, WrapImpl<KIND>{}, args);
 }
diff --git a/src/cunumeric/index/zip.cu b/src/cunumeric/index/zip.cu
index 8bdfcd3f0..82d162126 100644
--- a/src/cunumeric/index/zip.cu
+++ b/src/cunumeric/index/zip.cu
@@ -28,15 +28,15 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
              const Buffer<AccessorRO<int64_t, DIM>, 1> index_arrays,
              const Rect<DIM> rect,
              const Pitches<DIM - 1> pitches,
-             size_t volume,
-             DomainPoint shape,
+             const size_t volume,
+             const DomainPoint shape,
              std::index_sequence<Is...>)
 {
   const size_t idx = global_tid_1d();
   if (idx >= volume) return;
   auto p = pitches.unflatten(idx, rect.lo);
   Legion::Point<N> new_point;
-  for (size_t i = 0; i < N; i++) { new_point[i] = compute_idx(index_arrays[i][p], shape[i]); }
+  for (size_t i = 0; i < N; i++) { new_point[i] = compute_idx_cuda(index_arrays[i][p], shape[i]); }
   out[p] = new_point;
 }
 
@@ -45,14 +45,16 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   zip_kernel_dense(Point<N>* out,
                    const Buffer<const int64_t*, 1> index_arrays,
                    const Rect<DIM> rect,
-                   size_t volume,
-                   DomainPoint shape,
+                   const size_t volume,
+                   const DomainPoint shape,
                    std::index_sequence<Is...>)
 {
   const size_t idx = global_tid_1d();
   if (idx >= volume) return;
   Legion::Point<N> new_point;
-  for (size_t i = 0; i < N; i++) { new_point[i] = compute_idx(index_arrays[i][idx], shape[i]); }
+  for (size_t i = 0; i < N; i++) {
+    new_point[i] = compute_idx_cuda(index_arrays[i][idx], shape[i]);
+  }
   out[idx] = new_point;
 }
 
@@ -62,11 +64,11 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
              const Buffer<AccessorRO<int64_t, DIM>, 1> index_arrays,
              const Rect<DIM> rect,
              const Pitches<DIM - 1> pitches,
-             int narrays,
-             size_t volume,
-             int64_t key_dim,
-             int64_t start_index,
-             DomainPoint shape)
+             const int64_t narrays,
+             const size_t volume,
+             const int64_t key_dim,
+             const int64_t start_index,
+             const DomainPoint shape)
 {
   const size_t idx = global_tid_1d();
   if (idx >= volume) return;
@@ -74,7 +76,7 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   Legion::Point<N> new_point;
   for (size_t i = 0; i < start_index; i++) { new_point[i] = p[i]; }
   for (size_t i = 0; i < narrays; i++) {
-    new_point[start_index + i] = compute_idx(index_arrays[i][p], shape[start_index + i]);
+    new_point[start_index + i] = compute_idx_cuda(index_arrays[i][p], shape[start_index + i]);
   }
   for (size_t i = (start_index + narrays); i < N; i++) {
     int64_t j = key_dim + i - narrays;
@@ -83,10 +85,63 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   out[p] = new_point;
 }
 
+template <typename Output, int DIM>
+__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
+  check_kernel(Output out,
+               const Buffer<AccessorRO<int64_t, DIM>, 1> index_arrays,
+               const int64_t volume,
+               const int64_t iters,
+               const Rect<DIM> rect,
+               const Pitches<DIM - 1> pitches,
+               const int64_t narrays,
+               const int64_t start_index,
+               const DomainPoint shape)
+{
+  bool value = false;
+  for (size_t i = 0; i < iters; i++) {
+    const auto idx = (i * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
+    if (idx >= volume) break;
+    auto p = pitches.unflatten(idx, rect.lo);
+    for (size_t n = 0; n < narrays; n++) {
+      const int64_t extent = shape[start_index + n];
+      coord_t index = index_arrays[n][p] < 0 ? index_arrays[n][p] + extent : index_arrays[n][p];
+      bool val      = (index < 0 || index >= extent);
+      SumReduction<bool>::fold<true>(value, val);
+    }  // for n
+  }
+  reduce_output(out, value);
+}
+
 template <int DIM, int N>
 struct ZipImplBody<VariantKind::GPU, DIM, N> {
   using VAL = int64_t;
 
+  void check_out_of_bounds(const Buffer<AccessorRO<VAL, DIM>, 1>& index_arrays,
+                           const int64_t volume,
+                           const Rect<DIM>& rect,
+                           const Pitches<DIM - 1>& pitches,
+                           const int64_t narrays,
+                           const int64_t start_index,
+                           const DomainPoint& shape,
+                           cudaStream_t stream) const
+  {
+    const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+    size_t shmem_size   = THREADS_PER_BLOCK / 32 * sizeof(bool);
+    DeviceScalarReductionBuffer<SumReduction<bool>> out_of_bounds(stream);
+    if (blocks >= MAX_REDUCTION_CTAS) {
+      const size_t iters = (blocks + MAX_REDUCTION_CTAS - 1) / MAX_REDUCTION_CTAS;
+      check_kernel<<<MAX_REDUCTION_CTAS, THREADS_PER_BLOCK, shmem_size, stream>>>(
+        out_of_bounds, index_arrays, volume, iters, rect, pitches, narrays, start_index, shape);
+    } else {
+      check_kernel<<<blocks, THREADS_PER_BLOCK, shmem_size, stream>>>(
+        out_of_bounds, index_arrays, volume, 1, rect, pitches, narrays, start_index, shape);
+    }
+    CHECK_CUDA_STREAM(stream);
+
+    bool res = out_of_bounds.read(stream);
+    if (res) throw legate::TaskException("index is out of bounds in index array");
+  }
+
   template <size_t... Is>
   void operator()(const AccessorWO<Point<N>, DIM>& out,
                   const std::vector<AccessorRO<VAL, DIM>>& index_arrays,
@@ -101,19 +156,23 @@ struct ZipImplBody<VariantKind::GPU, DIM, N> {
     auto stream = get_cached_stream();
     const size_t volume = rect.volume();
     const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+
+    auto index_buf =
+      create_buffer<AccessorRO<VAL, DIM>, 1>(index_arrays.size(), Memory::Kind::Z_COPY_MEM);
+    for (uint32_t idx = 0; idx < index_arrays.size(); ++idx) index_buf[idx] = index_arrays[idx];
+    check_out_of_bounds(
+      index_buf, volume, rect, pitches, index_arrays.size(), start_index, shape, stream);
+
     if (index_arrays.size() == N) {
       if (dense) {
-        auto index_buf =
+        auto index_buf_dense =
           create_buffer<const VAL*, 1>(index_arrays.size(), Memory::Kind::Z_COPY_MEM);
         for (uint32_t idx = 0; idx < index_arrays.size(); ++idx) {
-          index_buf[idx] = index_arrays[idx].ptr(rect);
+          index_buf_dense[idx] = index_arrays[idx].ptr(rect);
         }
         zip_kernel_dense<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
-          out.ptr(rect), index_buf, rect, volume, shape, std::make_index_sequence<N>());
+          out.ptr(rect), index_buf_dense, rect, volume, shape, std::make_index_sequence<N>());
       } else {
-        auto index_buf =
-          create_buffer<AccessorRO<VAL, DIM>, 1>(index_arrays.size(), Memory::Kind::Z_COPY_MEM);
-        for (uint32_t idx = 0; idx < index_arrays.size(); ++idx) index_buf[idx] = index_arrays[idx];
         zip_kernel<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
           out, index_buf, rect, pitches, volume, shape, std::make_index_sequence<N>());
       }
@@ -121,9 +180,6 @@ struct ZipImplBody<VariantKind::GPU, DIM, N> {
 #ifdef DEBUG_CUNUMERIC
       assert(index_arrays.size() < N);
 #endif
-      auto index_buf =
-        create_buffer<AccessorRO<VAL, DIM>, 1>(index_arrays.size(), Memory::Kind::Z_COPY_MEM);
-      for (uint32_t idx = 0; idx < index_arrays.size(); ++idx) index_buf[idx] = index_arrays[idx];
       int num_arrays = index_arrays.size();
       zip_kernel<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
         out, index_buf, rect, pitches, num_arrays, volume, key_dim, start_index, shape);
diff --git a/src/cunumeric/index/zip.h b/src/cunumeric/index/zip.h
index 61a87104c..ffa5941d5 100644
--- a/src/cunumeric/index/zip.h
+++ b/src/cunumeric/index/zip.h
@@ -51,4 +51,10 @@ constexpr coord_t compute_idx(coord_t index, coord_t extent)
   return new_index;
 }
 
+constexpr coord_t compute_idx_cuda(coord_t index, coord_t extent)
+{
+  coord_t new_index = index < 0 ? index + extent : index;
+  return new_index;
+}
+
 }  // namespace cunumeric
diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py
new file mode 100644
index 000000000..1c69a705b
--- /dev/null
+++ b/tests/integration/test_put.py
@@ -0,0 +1,131 @@
+# Copyright 2022 NVIDIA Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pytest
+from legate.core import LEGATE_MAX_DIM
+from utils.generators import mk_seq_array
+
+import cunumeric as num
+
+
+@pytest.mark.parametrize("mode", ("wrap", "clip"))
+def test_scalar(mode):
+    # testing the case when indices is a scalar
+    x = mk_seq_array(np, (3, 4, 5))
+    x_num = mk_seq_array(num, (3, 4, 5))
+    values = mk_seq_array(np, (6,)) * 10
+    values_num = num.array(values)
+
+    np.put(x, 0, values)
+    num.put(x_num, 0, values_num)
+    assert np.array_equal(x_num, x)
+
+    np.put(x, 1, -10, mode)
+    num.put(x_num, 1, -10, mode)
+    assert np.array_equal(x_num, x)
+
+    # checking transformed array
+    y = x[:1]
+    y_num = x_num[:1]
+    np.put(y, 0, values)
+    num.put(y_num, 0, values_num)
+    assert np.array_equal(x_num, x)
+
+    x = np.zeros(1)
+    x_num = num.zeros(1)
+    np.put(x, np.arange(4), np.ones(4), mode="clip")
+    num.put(x_num, num.arange(4), num.ones(4), mode="clip")
+    assert np.array_equal(x_num, x)
+
+    x = np.arange(5)
+    x_num = num.array(x)
+    indices = np.array([1, 4])
+    indices_num = num.array(indices)
+    np.put(x, indices, 10)
+    num.put(x_num, indices_num, 10)
+    assert np.array_equal(x_num, x)
+
+    x = np.zeros(())
+    x_num = num.zeros(())
+    np.put(x, 0, 1)
+    num.put(x_num, 0, 1)
+    assert np.array_equal(x_num, x)
+
+    x = np.zeros(())
+    x_num = num.zeros(())
+    np.put(x, [0], 1)
+    num.put(x_num, [0], 1)
+    assert np.array_equal(x_num, x)
+
+    x = np.zeros(())
+    x_num = num.zeros(())
+    np.put(x, [0], [1])
+    num.put(x_num, [0], [1])
+    assert np.array_equal(x_num, x)
+
+
+def test_indices_type_convert():
+    x = mk_seq_array(np, (3, 4, 5))
+    x_num = mk_seq_array(num, (3, 4, 5))
+    values = mk_seq_array(np, (6,)) * 10
+    values_num = num.array(values)
+    indices = np.array([-2, 2], dtype=np.int32)
+    indices_num = num.array(indices)
+    np.put(x, indices, values)
+    num.put(x_num, indices_num, values_num)
+    assert np.array_equal(x_num, x)
+
+
+@pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1))
+def test_ndim(ndim):
+    shape = (5,) * ndim
+    np_arr = mk_seq_array(np, shape)
+    num_arr = mk_seq_array(num, shape)
+    shape_in = (3,) * ndim
+    np_indices = mk_seq_array(np, shape_in)
+    num_indices = mk_seq_array(num, shape_in)
+    shape_val = (2,) * ndim
+    np_values = mk_seq_array(np, shape_val) * 10
+    num_values = mk_seq_array(num, shape_val) * 10
+
+    np.put(np_arr, np_indices, np_values)
+    num.put(num_arr, num_indices, num_values)
+    assert np.array_equal(np_arr, num_arr)
+
+
+INDICES = ([1, 2, 3, 100], [[2, 1], [3, 100]], [1], [100])
+
+
+@pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1))
+@pytest.mark.parametrize("mode", ("wrap", "clip"))
+@pytest.mark.parametrize("indices", INDICES)
+def test_ndim_mode(ndim, mode, indices):
+    shape = (5,) * ndim
+    np_arr = mk_seq_array(np, shape)
+    num_arr = mk_seq_array(num, shape)
+    shape_val = (2,) * ndim
+    np_values = mk_seq_array(np, shape_val) * 10
+    num_values = mk_seq_array(num, shape_val) * 10
+
+    np.put(np_arr, indices, np_values, mode=mode)
+    num.put(num_arr, indices, num_values, mode=mode)
+    assert np.array_equal(np_arr, num_arr)
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(sys.argv))
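
The suite above exercises "wrap" and "clip"; a possible companion test for the default "raise" path is sketched below (hypothetical, not part of this patch). It assumes the bounds check added in wrap.cc/wrap.cu surfaces on the Python side as an IndexError, matching the task.throws_exception(IndexError) declaration in deferred.py and NumPy's own behavior:

    def test_raise_out_of_bounds():
        x = np.arange(5)
        x_num = num.array(x)
        # an index past the flattened size must be rejected in "raise" mode
        with pytest.raises(IndexError):
            np.put(x, 100, 1)
        with pytest.raises(IndexError):
            num.put(x_num, 100, 1)
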