From d4b6d6967b0f2b4aac270f308d4ceca317381e6d Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Fri, 9 Sep 2022 10:12:47 -0600 Subject: [PATCH 01/33] towards implementing put --- cunumeric/array.py | 42 +++++++++++ cunumeric/deferred.py | 8 +-- cunumeric/module.py | 32 +++++++++ tests/integration/test_put.py | 127 ++++++++++++++++++++++++++++++++++ 4 files changed, 205 insertions(+), 4 deletions(-) create mode 100644 tests/integration/test_put.py diff --git a/cunumeric/array.py b/cunumeric/array.py index 9097429f8..bd463e4c0 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2465,6 +2465,48 @@ def diagonal( raise ValueError("Either axis1/axis2 or axes must be supplied") return self._diag_helper(offset=offset, axes=axes, extract=extract) + @add_boilerplate("indices", "values") + def put( + self, indices: ndarray, values: ndarray, mode: bool = "raise" + ) -> None: + """ + Set storage-indexed locations to corresponding values. + + See Also + -------- + numpy.put + + Availability + -------- + Multiple GPUs, Multiple CPUs + + """ + + if values.size == 0 or indices.size == 0: + return + + if mode not in ("raise", "wrap", "clip"): + raise ValueError("clipmode not understood") + if mode == "wrap": + indices = indices % self.size + elif mode == "clip": + if np.isscalar(indices): + if indices >= self.size: + indices = self.size - 1 + if indices < 0: + indices = 0 + else: + indices = indices.clip(0, self.size - 1) + + # call _wrap on the values if they need to be wrapped + if values.ndim > 1 or values.size < indices.size: + values = values._wrap(indices.size) + + if indices.ndim == 1 and self.ndim == 1: + self._thunk.set_item(indices._thunk, values._thunk) + else: + self._thunk.set_item(indices._thunk, values._thunk, put=True) + @add_boilerplate() def trace( self, diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index c089d90af..83c5e22ae 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -518,7 +518,7 @@ def _slice_store(k: slice, store: Store, dim: int) -> tuple[slice, Store]: return k, store def _create_indexing_array( - self, key: Any, is_set: bool = False + self, key: Any, is_set: bool = False, is_put: bool = False ) -> tuple[bool, Any, Any, Any]: store = self.base rhs = self @@ -859,17 +859,17 @@ def get_item(self, key: Any) -> NumPyThunk: return result @auto_convert([2]) - def set_item(self, key: Any, rhs: Any) -> None: + def set_item(self, key: Any, rhs: Any, is_put: bool = False) -> None: assert self.dtype == rhs.dtype # Check to see if this is advanced indexing or not - if is_advanced_indexing(key): + if is_advanced_indexing(key) or (is_put and key.ndim != self.ndim): # Create the indexing array ( copy_needed, lhs, index_array, self, - ) = self._create_indexing_array(key, True) + ) = self._create_indexing_array(key, True, is_put) if rhs.shape != index_array.shape: rhs_tmp = rhs._broadcast(index_array.base.shape) diff --git a/cunumeric/module.py b/cunumeric/module.py index 5981e06d7..3ebe09b7b 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -3446,6 +3446,38 @@ def diagonal( ) +@add_boilerplate("a", "indices", "values") +def put( + a: ndarray, indices: ndarray, values: ndarray, mode: str = "raise" +) -> ndarray: + """ + Set storage-indexed locations to corresponding values. + + Parameters + ---------- + a : array_like + Array to put data into + indices : array_like + Target indices, interpreted as integers. + values : array_like + Values to place in `a` at target indices. + mode : {'raise', 'wrap', 'clip'}, optional + Specifies how out-of-bounds indices will behave. + 'raise' : raise an error. + 'wrap' : wrap around. + 'clip' : clip to the range. + + See Also + -------- + numpy.put + + Availability + -------- + Multiple GPUs, Multiple CPUs + """ + return a.put(indices=indices, values=values, mode=mode) + + @add_boilerplate("a", "val") def fill_diagonal(a: ndarray, val: ndarray, wrap: bool = False) -> None: """ diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py new file mode 100644 index 000000000..d01b4ec2f --- /dev/null +++ b/tests/integration/test_put.py @@ -0,0 +1,127 @@ +# Copyright 2022 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from utils.generators import mk_seq_array + +import cunumeric as num + +# from legate.core import LEGATE_MAX_DIM + + +def test_1d(): + x_np = np.arange(10) + x_num = num.arange(10) + i_np = np.array([1, 3, 5]) + i_num = num.array(i_np) + v_np = np.array([100]) + v_num = num.array(v_np) + np.put(x_np, i_np, v_np) + num.put(x_num, i_num, v_num) + assert np.array_equal(x_num, x_np) + + +def test_raise(): + + x = mk_seq_array(np, (3, 4, 5)) + x_num = mk_seq_array(num, (3, 4, 5)) + indices = mk_seq_array(np, (8,)) + indices_num = num.array(indices) + values = mk_seq_array(np, (6,)) * 10 + values_num = num.array(values) + + np.put(x, indices, values) + num.put(x_num, indices_num, values_num) + assert np.array_equal(x_num, x) + + +# def test_modes(): +# +# x = mk_seq_array(np, (3, 4, 5)) +# x_num = mk_seq_array(num, (3, 4, 5)) +# indices = mk_seq_array(np, (8,))*2 +# indices_num = num.array(indices) +# values = mk_seq_array(np, (6,))*10 +# values_num = num.array(values) +# +# np.put(x, indices, values, mode="clip") +# num.put(x_num, indices_num, values, mode="clip") +# assert np.array_equal(x_num, x) +# +# np.put(x, indices, values, mode="wrap") +# num.put(x_num, indices_num, values, mode="wrap") +# assert np.array_equal(x_num, x) +# +# def test_scalar(): +# # testing the case when indices is a scalar +# x = mk_seq_array(np, (3, 4, 5)) +# x_num = mk_seq_array(num, (3, 4, 5)) +# values = mk_seq_array(np, (6,))*10 +# values_num = num.array(values) +# +# np.put(x, 0, values) +# num.put(x_num, 0, values) +# assert np.array_equal(x_num, x) +# +# np.put(x, 1, -10, mode) +# num.put(x_num, 1, -10, mode) +# assert np.array_equal(x_num, x) +# +# +# def test_nd_indices(): +# x = mk_seq_array(np, (15)) +# x_num = mk_seq_array(num, (15)) +# indices = mk_seq_array(np, (3,2))*2 +# indices_num = num.array(indices) +# values = mk_seq_array(np, (2,2))*10 +# values_num = num.array(values) +# +# np.put(x, indices, values, mode) +# num.put(x_num, indices_num, values, mode) +# assert np.array_equal(x_num, x) +# +# @pytest.mark.parametrize("ndim", range(LEGATE_MAX_DIM + 1)) +# def test_ndim(ndim): +# shape = (5,) * ndim +# np_arr = mk_seq_array(np, shape) +# num_arr = mk_seq_array(num, shape) +# np_indices = mk_seq_array(np, (4,)) +# num_indices = mk_seq_array(num, (4,)) +# np_values = mk_seq_array(np, (2,))*10 +# num_values = mk_seq_array(num, (2,))*10 +# +# np.put(np_arr, np_indices, np_values) +# num.put(num_arr, num_indices, num_values) +# assert np.array_equal(np_arr, num_arr) +# +# np_indices = mk_seq_array(np, (8,)) +# num_indices = mk_seq_array(num, (8,)) +# np.put(np_arr, np_indices, np_values, mode="wrap") +# num.put(num_arr, num_indices,num_values, mode="wrap") +# assert np.array_equal(np_arr, num_arr) +# +# np_arr = mk_seq_array(np, shape) +# num_arr = mk_seq_array(num, shape) +# np_arr.put(np_indices,np_values, mode="clip") +# num_arr.put(num_indices, num_values, mode="clip") +# assert np.array_equal(np_arr, num_arr) +# +# return + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(sys.argv)) From 4dd090aa15d32685f5e7a7d8d7ad5b25f571e842 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 12 Sep 2022 14:42:43 -0600 Subject: [PATCH 02/33] fixing errors and improving test --- cunumeric/array.py | 27 +++++- cunumeric/deferred.py | 51 +++++++++- cunumeric/eager.py | 6 ++ cunumeric/thunk.py | 4 + src/cunumeric/index/wrap.cc | 29 ++++++ src/cunumeric/index/wrap.cu | 57 +++++++++++ src/cunumeric/index/wrap.h | 2 + src/cunumeric/index/wrap_omp.cc | 28 ++++++ src/cunumeric/index/wrap_template.inl | 29 +++++- tests/integration/test_put.py | 135 ++++++++------------------ 10 files changed, 262 insertions(+), 106 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index bd463e4c0..4c3c3350d 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2498,14 +2498,31 @@ def put( else: indices = indices.clip(0, self.size - 1) + if indices.dtype != np.int64: + runtime.warn( + "converting index array to int64 type", + category=RuntimeWarning, + ) + indices = indices.astype(np.int64) + + if indices.ndim > 1: + indices = indices.ravel() + + # in case size of the indices larger than size of the array + # some entries in the array will be updated 2 times. + # and Legate doesn't guarantee the order in which values + # are updated + if indices.size > self.size: + runtime.warn( + "size of indices is larger than source array which", + " might result in undefined behaviour", + category=RuntimeWarning, + ) # call _wrap on the values if they need to be wrapped - if values.ndim > 1 or values.size < indices.size: + if values.ndim != indices.ndim or values.size != indices.size: values = values._wrap(indices.size) - if indices.ndim == 1 and self.ndim == 1: - self._thunk.set_item(indices._thunk, values._thunk) - else: - self._thunk.set_item(indices._thunk, values._thunk, put=True) + self._thunk.put(indices._thunk, values._thunk) @add_boilerplate() def trace( diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 83c5e22ae..5df12918e 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -518,7 +518,7 @@ def _slice_store(k: slice, store: Store, dim: int) -> tuple[slice, Store]: return k, store def _create_indexing_array( - self, key: Any, is_set: bool = False, is_put: bool = False + self, key: Any, is_set: bool = False ) -> tuple[bool, Any, Any, Any]: store = self.base rhs = self @@ -859,17 +859,17 @@ def get_item(self, key: Any) -> NumPyThunk: return result @auto_convert([2]) - def set_item(self, key: Any, rhs: Any, is_put: bool = False) -> None: + def set_item(self, key: Any, rhs: Any) -> None: assert self.dtype == rhs.dtype # Check to see if this is advanced indexing or not - if is_advanced_indexing(key) or (is_put and key.ndim != self.ndim): + if is_advanced_indexing(key): # Create the indexing array ( copy_needed, lhs, index_array, self, - ) = self._create_indexing_array(key, True, is_put) + ) = self._create_indexing_array(key, True) if rhs.shape != index_array.shape: rhs_tmp = rhs._broadcast(index_array.base.shape) @@ -1657,6 +1657,48 @@ def _diag_helper( task.execute() + @auto_convert([1, 2]) + def put(self, indices: Any, values: Any) -> None: + if indices.base.kind == Future or indices.base.transformed: + indices = indices._convert_future_to_regionfield() + if values.base.kind == Future or values.base.transformed: + values = values._convert_future_to_regionfield() + if self.base.kind == Future or self.base.transformed: + self = self._convert_future_to_regionfield() + + assert indices.size == values.size + + # first, we create indirect array with PointN type that + # (indices.size,) shape and is used to copy data from values + # to the target ND array (self) + N = self.ndim + pointN_dtype = self.runtime.get_point_type(N) + indirect = cast( + DeferredArray, + self.runtime.create_empty_thunk( + shape=indices.shape, + dtype=pointN_dtype, + inputs=[indices], + ), + ) + + task = self.context.create_task(CuNumericOpCode.WRAP) + task.add_output(indirect.base) + task.add_scalar_arg(self.shape, (ty.int64,)) + task.add_scalar_arg(True, bool) # has_input + task.add_input(indices.base) + task.add_alignment(indices.base, indirect.base) + task.execute() + if indirect.base.kind == Future or indirect.base.transformed: + indirect = indirect._convert_future_to_regionfield() + + copy = self.context.create_copy() + copy.set_target_indirect_out_of_range(False) + copy.add_input(values.base) + copy.add_target_indirect(indirect.base) + copy.add_output(self.base) + copy.execute() + # Create an identity array with the ones offset from the diagonal by k def eye(self, k: int) -> None: assert self.ndim == 2 # Only 2-D arrays should be here @@ -3341,6 +3383,7 @@ def _wrap(self, src: Any, new_len: int) -> None: task = self.context.create_task(CuNumericOpCode.WRAP) task.add_output(indirect.base) task.add_scalar_arg(src.shape, (ty.int64,)) + task.add_scalar_arg(False, bool) # has_input task.execute() copy = self.context.create_copy() diff --git a/cunumeric/eager.py b/cunumeric/eager.py index c6f19fb6b..9e2e7e198 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -620,6 +620,12 @@ def _diag_helper( axes = tuple(range(ndims - naxes, ndims)) self.array = diagonal_reference(rhs.array, axes) + def put(self, indices: Any, values: Any) -> None: + if self.deferred is not None: + self.deferred.put(indices, values) + else: + np.put(self.array, indices.array, values.array) + def eye(self, k: int) -> None: if self.deferred is not None: self.deferred.eye(k) diff --git a/cunumeric/thunk.py b/cunumeric/thunk.py index 6230905af..2d79fe7e3 100644 --- a/cunumeric/thunk.py +++ b/cunumeric/thunk.py @@ -197,6 +197,10 @@ def _diag_helper( ) -> None: ... + @abstractmethod + def put(self, indices: Any, values: Any) -> None: + ... + @abstractmethod def eye(self, k: int) -> None: ... diff --git a/src/cunumeric/index/wrap.cc b/src/cunumeric/index/wrap.cc index 33dfcfe4b..b055cabc4 100644 --- a/src/cunumeric/index/wrap.cc +++ b/src/cunumeric/index/wrap.cc @@ -51,6 +51,35 @@ struct WrapImplBody { } } // else } + + // the version when input is specified + void operator()(const AccessorWO, 1>& out, + const Pitches<0>& pitches_out, + const Rect<1>& out_rect, + const Pitches& pitches_in, + const Rect& in_rect, + const bool dense, + const AccessorRO& indices) const + { + const int64_t start = out_rect.lo[0]; + const int64_t end = out_rect.hi[0]; + if (dense) { + int64_t out_idx = 0; + auto outptr = out.ptr(out_rect); + for (int64_t i = start; i <= end; i++) { + const int64_t input_idx = indices[i]; + auto point = pitches_in.unflatten(input_idx, in_rect.lo); + outptr[out_idx] = point; + out_idx++; + } + } else { + for (int64_t i = start; i <= end; i++) { + const int64_t input_idx = indices[i]; + auto point = pitches_in.unflatten(input_idx, in_rect.lo); + out[i] = point; + } + } // else + } }; /*static*/ void WrapTask::cpu_variant(TaskContext& context) diff --git a/src/cunumeric/index/wrap.cu b/src/cunumeric/index/wrap.cu index 0f118eadf..c88a6cfff 100644 --- a/src/cunumeric/index/wrap.cu +++ b/src/cunumeric/index/wrap.cu @@ -58,6 +58,41 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[idx] = p; } +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + wrap_kernel(const AccessorWO, 1> out, + const int64_t start, + const int64_t volume, + const Pitches<0> pitches_out, + const Point<1> out_lo, + const Pitches pitches_in, + const Point in_lo, + const AccessorRO indices) +{ + const auto idx = global_tid_1d(); + if (idx >= volume) return; + auto out_p = pitches_out.unflatten(idx, out_lo); + const int64_t input_idx = indices[out_p]; + auto p = pitches_in.unflatten(input_idx, in_lo); + out[out_p] = p; +} + +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + wrap_kernel_dense(Point* out, + const int64_t start, + const int64_t volume, + const Pitches pitches_in, + const Point in_lo, + const AccessorRO indices) +{ + const auto idx = global_tid_1d(); + if (idx >= volume) return; + const int64_t input_idx = indices[idx]; + auto p = pitches_in.unflatten(input_idx, in_lo); + out[idx] = p; +} + template struct WrapImplBody { void operator()(const AccessorWO, 1>& out, @@ -82,6 +117,28 @@ struct WrapImplBody { } CHECK_CUDA_STREAM(stream); } + void operator()(const AccessorWO, 1>& out, + const Pitches<0>& pitches_out, + const Rect<1>& out_rect, + const Pitches& pitches_in, + const Rect& in_rect, + const bool dense, + const AccessorRO& indices) const + { + auto stream = get_cached_stream(); + const int64_t start = out_rect.lo[0]; + const int64_t volume = out_rect.volume(); + const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + if (dense) { + auto outptr = out.ptr(out_rect); + wrap_kernel_dense<<>>( + outptr, start, volume, pitches_in, in_rect.lo, indices); + } else { + wrap_kernel<<>>( + out, start, volume, pitches_out, out_rect.lo, pitches_in, in_rect.lo, indices); + } + CHECK_CUDA_STREAM(stream); + } }; /*static*/ void WrapTask::gpu_variant(TaskContext& context) diff --git a/src/cunumeric/index/wrap.h b/src/cunumeric/index/wrap.h index 91c3f2326..abab1c1bb 100644 --- a/src/cunumeric/index/wrap.h +++ b/src/cunumeric/index/wrap.h @@ -25,6 +25,8 @@ struct WrapArgs { // copy information from original array to the // `wrapped` one const Legion::DomainPoint shape; // shape of the original array + const bool has_input; + const Array& in; }; class WrapTask : public CuNumericTask { diff --git a/src/cunumeric/index/wrap_omp.cc b/src/cunumeric/index/wrap_omp.cc index f95e9123c..1d8b55a52 100644 --- a/src/cunumeric/index/wrap_omp.cc +++ b/src/cunumeric/index/wrap_omp.cc @@ -51,6 +51,34 @@ struct WrapImplBody { } } // else } + + void operator()(const AccessorWO, 1>& out, + const Pitches<0>& pitches_out, + const Rect<1>& out_rect, + const Pitches& pitches_in, + const Rect& in_rect, + const bool dense, + const AccessorRO& indices) const + { + const int64_t start = out_rect.lo[0]; + const int64_t end = out_rect.hi[0]; + if (dense) { + auto outptr = out.ptr(out_rect); +#pragma omp parallel for schedule(static) + for (int64_t i = start; i <= end; i++) { + const int64_t input_idx = indices[i]; + auto point = pitches_in.unflatten(input_idx, in_rect.lo); + outptr[i - start] = point; + } + } else { +#pragma omp parallel for schedule(static) + for (int64_t i = start; i <= end; i++) { + const int64_t input_idx = indices[i]; + auto point = pitches_in.unflatten(input_idx, in_rect.lo); + out[i] = point; + } + } // else + } }; /*static*/ void WrapTask::omp_variant(TaskContext& context) diff --git a/src/cunumeric/index/wrap_template.inl b/src/cunumeric/index/wrap_template.inl index 46885f24e..60f94139c 100644 --- a/src/cunumeric/index/wrap_template.inl +++ b/src/cunumeric/index/wrap_template.inl @@ -60,17 +60,36 @@ struct WrapImpl { assert(volume_in != 0); #endif - WrapImplBody()(out, pitches_out, out_rect, pitches_in, input_rect, dense); + if (args.has_input) { + auto in_rect = args.in.shape<1>(); + auto in = args.in.read_accessor(in_rect); // input should be always integer type +#ifdef DEBUG_CUNUMERIC + assert(in_rect == out_rect); +#endif +#ifndef LEGION_BOUNDS_CHECKS + dense = dense && in.accessor.is_dense_row_major(out_rect); +#endif + WrapImplBody()(out, pitches_out, out_rect, pitches_in, input_rect, dense, in); + + } else { + WrapImplBody()(out, pitches_out, out_rect, pitches_in, input_rect, dense); + } // else } }; template static void wrap_template(TaskContext& context) { - auto shape = context.scalars()[0].value(); - int dim = shape.dim; - WrapArgs args{context.outputs()[0], shape}; - dim_dispatch(dim, WrapImpl{}, args); + auto shape = context.scalars()[0].value(); + int dim = shape.dim; + bool has_input = context.scalars()[1].value(); + if (has_input) { + WrapArgs args1{context.outputs()[0], shape, has_input, context.inputs()[0]}; + dim_dispatch(dim, WrapImpl{}, args1); + } else { + WrapArgs args2{context.outputs()[0], shape, has_input, Array()}; + dim_dispatch(dim, WrapImpl{}, args2); + } } } // namespace cunumeric diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py index d01b4ec2f..8a307e73b 100644 --- a/tests/integration/test_put.py +++ b/tests/integration/test_put.py @@ -18,108 +18,59 @@ from utils.generators import mk_seq_array import cunumeric as num +from legate.core import LEGATE_MAX_DIM -# from legate.core import LEGATE_MAX_DIM - - -def test_1d(): - x_np = np.arange(10) - x_num = num.arange(10) - i_np = np.array([1, 3, 5]) - i_num = num.array(i_np) - v_np = np.array([100]) - v_num = num.array(v_np) - np.put(x_np, i_np, v_np) - num.put(x_num, i_num, v_num) - assert np.array_equal(x_num, x_np) - - -def test_raise(): +@pytest.mark.parametrize("mode", ("wrap", "clip")) +def test_scalar(mode): + # testing the case when indices is a scalar x = mk_seq_array(np, (3, 4, 5)) x_num = mk_seq_array(num, (3, 4, 5)) - indices = mk_seq_array(np, (8,)) - indices_num = num.array(indices) values = mk_seq_array(np, (6,)) * 10 values_num = num.array(values) - np.put(x, indices, values) - num.put(x_num, indices_num, values_num) + np.put(x, 0, values) + num.put(x_num, 0, values_num) assert np.array_equal(x_num, x) + np.put(x, 1, -10, mode) + num.put(x_num, 1, -10, mode) + assert np.array_equal(x_num, x) + + +@pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1)) +def test_ndim(ndim): + shape = (5,) * ndim + np_arr = mk_seq_array(np, shape) + num_arr = mk_seq_array(num, shape) + np_indices = mk_seq_array(np, (4 * ndim,)) + num_indices = mk_seq_array(num, (4 * ndim,)) + shape_val = (3,) * ndim + np_values = mk_seq_array(np, shape_val) * 10 + num_values = mk_seq_array(num, shape_val) * 10 + + np.put(np_arr, np_indices, np_values) + num.put(num_arr, num_indices, num_values) + assert np.array_equal(np_arr, num_arr) + + +@pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1)) +@pytest.mark.parametrize("mode", ("wrap", "clip")) +def test_ndim_mode(ndim, mode): + shape = (5,) * ndim + np_arr = mk_seq_array(np, shape) + num_arr = mk_seq_array(num, shape) + shape_in = (3,) * ndim + np_indices = mk_seq_array(np, shape_in) * 2 + num_indices = mk_seq_array(num, shape_in) * 2 + shape_val = (2,) * ndim + np_values = mk_seq_array(np, shape_val) * 10 + num_values = mk_seq_array(num, shape_val) * 10 + + np.put(np_arr, np_indices, np_values, mode=mode) + num.put(num_arr, num_indices, num_values, mode=mode) + assert np.array_equal(np_arr, num_arr) -# def test_modes(): -# -# x = mk_seq_array(np, (3, 4, 5)) -# x_num = mk_seq_array(num, (3, 4, 5)) -# indices = mk_seq_array(np, (8,))*2 -# indices_num = num.array(indices) -# values = mk_seq_array(np, (6,))*10 -# values_num = num.array(values) -# -# np.put(x, indices, values, mode="clip") -# num.put(x_num, indices_num, values, mode="clip") -# assert np.array_equal(x_num, x) -# -# np.put(x, indices, values, mode="wrap") -# num.put(x_num, indices_num, values, mode="wrap") -# assert np.array_equal(x_num, x) -# -# def test_scalar(): -# # testing the case when indices is a scalar -# x = mk_seq_array(np, (3, 4, 5)) -# x_num = mk_seq_array(num, (3, 4, 5)) -# values = mk_seq_array(np, (6,))*10 -# values_num = num.array(values) -# -# np.put(x, 0, values) -# num.put(x_num, 0, values) -# assert np.array_equal(x_num, x) -# -# np.put(x, 1, -10, mode) -# num.put(x_num, 1, -10, mode) -# assert np.array_equal(x_num, x) -# -# -# def test_nd_indices(): -# x = mk_seq_array(np, (15)) -# x_num = mk_seq_array(num, (15)) -# indices = mk_seq_array(np, (3,2))*2 -# indices_num = num.array(indices) -# values = mk_seq_array(np, (2,2))*10 -# values_num = num.array(values) -# -# np.put(x, indices, values, mode) -# num.put(x_num, indices_num, values, mode) -# assert np.array_equal(x_num, x) -# -# @pytest.mark.parametrize("ndim", range(LEGATE_MAX_DIM + 1)) -# def test_ndim(ndim): -# shape = (5,) * ndim -# np_arr = mk_seq_array(np, shape) -# num_arr = mk_seq_array(num, shape) -# np_indices = mk_seq_array(np, (4,)) -# num_indices = mk_seq_array(num, (4,)) -# np_values = mk_seq_array(np, (2,))*10 -# num_values = mk_seq_array(num, (2,))*10 -# -# np.put(np_arr, np_indices, np_values) -# num.put(num_arr, num_indices, num_values) -# assert np.array_equal(np_arr, num_arr) -# -# np_indices = mk_seq_array(np, (8,)) -# num_indices = mk_seq_array(num, (8,)) -# np.put(np_arr, np_indices, np_values, mode="wrap") -# num.put(num_arr, num_indices,num_values, mode="wrap") -# assert np.array_equal(np_arr, num_arr) -# -# np_arr = mk_seq_array(np, shape) -# num_arr = mk_seq_array(num, shape) -# np_arr.put(np_indices,np_values, mode="clip") -# num_arr.put(num_indices, num_values, mode="clip") -# assert np.array_equal(np_arr, num_arr) -# -# return if __name__ == "__main__": import sys From fc6ebe701a1274e58cfd733df6688e9882de9b80 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 12 Sep 2022 14:47:21 -0600 Subject: [PATCH 03/33] updating documentation --- docs/cunumeric/source/api/indexing.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/cunumeric/source/api/indexing.rst b/docs/cunumeric/source/api/indexing.rst index 1023ed1d4..1ace111d4 100644 --- a/docs/cunumeric/source/api/indexing.rst +++ b/docs/cunumeric/source/api/indexing.rst @@ -43,5 +43,6 @@ Inserting data into arrays :toctree: generated/ fill_diagonal + put put_along_axis place From 1d87fd72a8febd9c64db6e0069445682edea6cae Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 12 Sep 2022 15:49:18 -0600 Subject: [PATCH 04/33] code clean-up --- src/cunumeric/index/wrap.cc | 43 +++----------- src/cunumeric/index/wrap.cu | 83 +++++---------------------- src/cunumeric/index/wrap.h | 6 ++ src/cunumeric/index/wrap_omp.cc | 40 +++---------- src/cunumeric/index/wrap_template.inl | 3 +- 5 files changed, 37 insertions(+), 138 deletions(-) diff --git a/src/cunumeric/index/wrap.cc b/src/cunumeric/index/wrap.cc index b055cabc4..30f472710 100644 --- a/src/cunumeric/index/wrap.cc +++ b/src/cunumeric/index/wrap.cc @@ -24,57 +24,28 @@ using namespace legate; template struct WrapImplBody { + template void operator()(const AccessorWO, 1>& out, const Pitches<0>& pitches_out, const Rect<1>& out_rect, const Pitches& pitches_in, const Rect& in_rect, - const bool dense) const + const bool dense, + IND& indices) const { const int64_t start = out_rect.lo[0]; const int64_t end = out_rect.hi[0]; const auto in_volume = in_rect.volume(); if (dense) { - int64_t out_idx = 0; - auto outptr = out.ptr(out_rect); - for (int64_t i = start; i <= end; i++) { - const int64_t input_idx = i % in_volume; - auto point = pitches_in.unflatten(input_idx, in_rect.lo); - outptr[out_idx] = point; - out_idx++; - } - } else { - for (int64_t i = start; i <= end; i++) { - const int64_t input_idx = i % in_volume; - auto point = pitches_in.unflatten(input_idx, in_rect.lo); - out[i] = point; - } - } // else - } - - // the version when input is specified - void operator()(const AccessorWO, 1>& out, - const Pitches<0>& pitches_out, - const Rect<1>& out_rect, - const Pitches& pitches_in, - const Rect& in_rect, - const bool dense, - const AccessorRO& indices) const - { - const int64_t start = out_rect.lo[0]; - const int64_t end = out_rect.hi[0]; - if (dense) { - int64_t out_idx = 0; - auto outptr = out.ptr(out_rect); + auto outptr = out.ptr(out_rect); for (int64_t i = start; i <= end; i++) { - const int64_t input_idx = indices[i]; + const int64_t input_idx = compute_idx(i, in_volume, indices); auto point = pitches_in.unflatten(input_idx, in_rect.lo); - outptr[out_idx] = point; - out_idx++; + outptr[i - start] = point; } } else { for (int64_t i = start; i <= end; i++) { - const int64_t input_idx = indices[i]; + const int64_t input_idx = compute_idx(i, in_volume, indices); // i % in_volume; auto point = pitches_in.unflatten(input_idx, in_rect.lo); out[i] = point; } diff --git a/src/cunumeric/index/wrap.cu b/src/cunumeric/index/wrap.cu index c88a6cfff..ae6b4a1d9 100644 --- a/src/cunumeric/index/wrap.cu +++ b/src/cunumeric/index/wrap.cu @@ -23,7 +23,7 @@ namespace cunumeric { using namespace Legion; using namespace legate; -template +template __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) wrap_kernel(const AccessorWO, 1> out, const int64_t start, @@ -32,110 +32,57 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) const Point<1> out_lo, const Pitches pitches_in, const Point in_lo, - const size_t in_volume) + const size_t in_volume, + IND indices) { const auto idx = global_tid_1d(); if (idx >= volume) return; - const int64_t input_idx = (idx + start) % in_volume; + const int64_t input_idx = compute_idx((idx + start), in_volume, indices); auto out_p = pitches_out.unflatten(idx, out_lo); auto p = pitches_in.unflatten(input_idx, in_lo); out[out_p] = p; } -template +template __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) wrap_kernel_dense(Point* out, const int64_t start, const int64_t volume, const Pitches pitches_in, const Point in_lo, - const size_t in_volume) + const size_t in_volume, + IND indices) { const auto idx = global_tid_1d(); if (idx >= volume) return; - const int64_t input_idx = (idx + start) % in_volume; - auto p = pitches_in.unflatten(input_idx, in_lo); - out[idx] = p; -} - -template -__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - wrap_kernel(const AccessorWO, 1> out, - const int64_t start, - const int64_t volume, - const Pitches<0> pitches_out, - const Point<1> out_lo, - const Pitches pitches_in, - const Point in_lo, - const AccessorRO indices) -{ - const auto idx = global_tid_1d(); - if (idx >= volume) return; - auto out_p = pitches_out.unflatten(idx, out_lo); - const int64_t input_idx = indices[out_p]; - auto p = pitches_in.unflatten(input_idx, in_lo); - out[out_p] = p; -} - -template -__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - wrap_kernel_dense(Point* out, - const int64_t start, - const int64_t volume, - const Pitches pitches_in, - const Point in_lo, - const AccessorRO indices) -{ - const auto idx = global_tid_1d(); - if (idx >= volume) return; - const int64_t input_idx = indices[idx]; + const int64_t input_idx = compute_idx((idx + start), in_volume, indices); auto p = pitches_in.unflatten(input_idx, in_lo); out[idx] = p; } template struct WrapImplBody { - void operator()(const AccessorWO, 1>& out, - const Pitches<0>& pitches_out, - const Rect<1>& out_rect, - const Pitches& pitches_in, - const Rect& in_rect, - const bool dense) const - { - auto stream = get_cached_stream(); - const int64_t start = out_rect.lo[0]; - const int64_t volume = out_rect.volume(); - const auto in_volume = in_rect.volume(); - const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; - if (dense) { - auto outptr = out.ptr(out_rect); - wrap_kernel_dense<<>>( - outptr, start, volume, pitches_in, in_rect.lo, in_volume); - } else { - wrap_kernel<<>>( - out, start, volume, pitches_out, out_rect.lo, pitches_in, in_rect.lo, in_volume); - } - CHECK_CUDA_STREAM(stream); - } + template void operator()(const AccessorWO, 1>& out, const Pitches<0>& pitches_out, const Rect<1>& out_rect, const Pitches& pitches_in, const Rect& in_rect, const bool dense, - const AccessorRO& indices) const + IND& indices) const { auto stream = get_cached_stream(); const int64_t start = out_rect.lo[0]; const int64_t volume = out_rect.volume(); + const auto in_volume = in_rect.volume(); const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; if (dense) { auto outptr = out.ptr(out_rect); - wrap_kernel_dense<<>>( - outptr, start, volume, pitches_in, in_rect.lo, indices); + wrap_kernel_dense<<>>( + outptr, start, volume, pitches_in, in_rect.lo, in_volume, indices); } else { - wrap_kernel<<>>( - out, start, volume, pitches_out, out_rect.lo, pitches_in, in_rect.lo, indices); + wrap_kernel<<>>( + out, start, volume, pitches_out, out_rect.lo, pitches_in, in_rect.lo, in_volume, indices); } CHECK_CUDA_STREAM(stream); } diff --git a/src/cunumeric/index/wrap.h b/src/cunumeric/index/wrap.h index abab1c1bb..962ca2d4a 100644 --- a/src/cunumeric/index/wrap.h +++ b/src/cunumeric/index/wrap.h @@ -43,4 +43,10 @@ class WrapTask : public CuNumericTask { #endif }; +__CUDA_HD__ int64_t compute_idx(int64_t i, int64_t volume, bool&) { return i % volume; } + +__CUDA_HD__ int64_t compute_idx(int64_t i, int64_t, const legate::AccessorRO& indices) +{ + return indices[i]; +} } // namespace cunumeric diff --git a/src/cunumeric/index/wrap_omp.cc b/src/cunumeric/index/wrap_omp.cc index 1d8b55a52..afdf0c7ab 100644 --- a/src/cunumeric/index/wrap_omp.cc +++ b/src/cunumeric/index/wrap_omp.cc @@ -24,56 +24,30 @@ using namespace legate; template struct WrapImplBody { + template void operator()(const AccessorWO, 1>& out, const Pitches<0>& pitches_out, const Rect<1>& out_rect, const Pitches& pitches_in, const Rect& in_rect, - const bool dense) const + const bool dense, + IND& indices) const { const int64_t start = out_rect.lo[0]; const int64_t end = out_rect.hi[0]; const auto in_volume = in_rect.volume(); if (dense) { auto outptr = out.ptr(out_rect); -#pragma omp parallel for schedule(static) - for (int64_t i = start; i <= end; i++) { - const int64_t input_idx = i % in_volume; - auto point = pitches_in.unflatten(input_idx, in_rect.lo); - outptr[i - start] = point; - } - } else { -#pragma omp parallel for schedule(static) - for (int64_t i = start; i <= end; i++) { - const int64_t input_idx = i % in_volume; - auto point = pitches_in.unflatten(input_idx, in_rect.lo); - out[i] = point; - } - } // else - } - - void operator()(const AccessorWO, 1>& out, - const Pitches<0>& pitches_out, - const Rect<1>& out_rect, - const Pitches& pitches_in, - const Rect& in_rect, - const bool dense, - const AccessorRO& indices) const - { - const int64_t start = out_rect.lo[0]; - const int64_t end = out_rect.hi[0]; - if (dense) { - auto outptr = out.ptr(out_rect); -#pragma omp parallel for schedule(static) + //#pragma omp parallel for schedule(static) for (int64_t i = start; i <= end; i++) { - const int64_t input_idx = indices[i]; + const int64_t input_idx = compute_idx(i, in_volume, indices); auto point = pitches_in.unflatten(input_idx, in_rect.lo); outptr[i - start] = point; } } else { -#pragma omp parallel for schedule(static) for (int64_t i = start; i <= end; i++) { - const int64_t input_idx = indices[i]; + //#pragma omp parallel for schedule(static) + const int64_t input_idx = compute_idx(i, in_volume, indices); auto point = pitches_in.unflatten(input_idx, in_rect.lo); out[i] = point; } diff --git a/src/cunumeric/index/wrap_template.inl b/src/cunumeric/index/wrap_template.inl index 60f94139c..42c84c9db 100644 --- a/src/cunumeric/index/wrap_template.inl +++ b/src/cunumeric/index/wrap_template.inl @@ -72,7 +72,8 @@ struct WrapImpl { WrapImplBody()(out, pitches_out, out_rect, pitches_in, input_rect, dense, in); } else { - WrapImplBody()(out, pitches_out, out_rect, pitches_in, input_rect, dense); + bool tmp = false; + WrapImplBody()(out, pitches_out, out_rect, pitches_in, input_rect, dense, tmp); } // else } }; From 15db53e2fcc52824cacbfebde1eb5125e70057d4 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 12 Sep 2022 22:12:25 -0600 Subject: [PATCH 05/33] adding missig pragams for openmp --- src/cunumeric/index/wrap_omp.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cunumeric/index/wrap_omp.cc b/src/cunumeric/index/wrap_omp.cc index afdf0c7ab..05ee2ef7c 100644 --- a/src/cunumeric/index/wrap_omp.cc +++ b/src/cunumeric/index/wrap_omp.cc @@ -38,15 +38,15 @@ struct WrapImplBody { const auto in_volume = in_rect.volume(); if (dense) { auto outptr = out.ptr(out_rect); - //#pragma omp parallel for schedule(static) +#pragma omp parallel for schedule(static) for (int64_t i = start; i <= end; i++) { const int64_t input_idx = compute_idx(i, in_volume, indices); auto point = pitches_in.unflatten(input_idx, in_rect.lo); outptr[i - start] = point; } } else { +#pragma omp parallel for schedule(static) for (int64_t i = start; i <= end; i++) { - //#pragma omp parallel for schedule(static) const int64_t input_idx = compute_idx(i, in_volume, indices); auto point = pitches_in.unflatten(input_idx, in_rect.lo); out[i] = point; From d01e17c940c4d43472649aa6b47be77904d2c49c Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 13 Sep 2022 09:32:55 -0600 Subject: [PATCH 06/33] fixing compile-time errors --- src/cunumeric/index/wrap.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/cunumeric/index/wrap.h b/src/cunumeric/index/wrap.h index 962ca2d4a..7436ac4e8 100644 --- a/src/cunumeric/index/wrap.h +++ b/src/cunumeric/index/wrap.h @@ -43,9 +43,11 @@ class WrapTask : public CuNumericTask { #endif }; -__CUDA_HD__ int64_t compute_idx(int64_t i, int64_t volume, bool&) { return i % volume; } +__CUDA_HD__ static int64_t compute_idx(int64_t i, int64_t volume, bool&) { return i % volume; } -__CUDA_HD__ int64_t compute_idx(int64_t i, int64_t, const legate::AccessorRO& indices) +__CUDA_HD__ static int64_t compute_idx(int64_t i, + int64_t, + const legate::AccessorRO& indices) { return indices[i]; } From 48d1920e07e5228d9144fb1a0aaa3dd214538bdf Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 13 Sep 2022 11:06:50 -0600 Subject: [PATCH 07/33] fixing mypy errors --- cunumeric/array.py | 2 +- cunumeric/module.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 4c3c3350d..d1243ceb7 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2467,7 +2467,7 @@ def diagonal( @add_boilerplate("indices", "values") def put( - self, indices: ndarray, values: ndarray, mode: bool = "raise" + self, indices: ndarray, values: ndarray, mode: str = "raise" ) -> None: """ Set storage-indexed locations to corresponding values. diff --git a/cunumeric/module.py b/cunumeric/module.py index 3ebe09b7b..28f9e51e2 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -3449,7 +3449,7 @@ def diagonal( @add_boilerplate("a", "indices", "values") def put( a: ndarray, indices: ndarray, values: ndarray, mode: str = "raise" -) -> ndarray: +) -> None: """ Set storage-indexed locations to corresponding values. @@ -3475,7 +3475,7 @@ def put( -------- Multiple GPUs, Multiple CPUs """ - return a.put(indices=indices, values=values, mode=mode) + a.put(indices=indices, values=values, mode=mode) @add_boilerplate("a", "val") From 458f1a2fec093e44374a8c8801d8a7ad4640e110 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 13 Sep 2022 13:52:47 -0600 Subject: [PATCH 08/33] fixing issue whith converting futures for put + modifying tests --- cunumeric/array.py | 10 ++-------- cunumeric/deferred.py | 22 +++++++++++++++++----- tests/integration/test_put.py | 20 +++++++++++--------- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index d1243ceb7..74b0551a0 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2490,13 +2490,7 @@ def put( if mode == "wrap": indices = indices % self.size elif mode == "clip": - if np.isscalar(indices): - if indices >= self.size: - indices = self.size - 1 - if indices < 0: - indices = 0 - else: - indices = indices.clip(0, self.size - 1) + indices = indices.clip(0, self.size - 1) if indices.dtype != np.int64: runtime.warn( @@ -2514,7 +2508,7 @@ def put( # are updated if indices.size > self.size: runtime.warn( - "size of indices is larger than source array which", + "size of indices is larger than source array which" " might result in undefined behaviour", category=RuntimeWarning, ) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 5df12918e..53cb6bced 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -781,10 +781,16 @@ def _broadcast(self, shape: NdShape) -> Any: return result - def _convert_future_to_regionfield(self) -> DeferredArray: + def _convert_future_to_regionfield( + self, future: bool = False + ) -> DeferredArray: + if future: + shape = (1,) + else: + shape = self.shape store = self.context.create_store( self.dtype, - shape=self.shape, + shape=shape, optimize_scalar=False, ) thunk_copy = DeferredArray( @@ -1660,11 +1666,17 @@ def _diag_helper( @auto_convert([1, 2]) def put(self, indices: Any, values: Any) -> None: if indices.base.kind == Future or indices.base.transformed: - indices = indices._convert_future_to_regionfield() + indices = indices._convert_future_to_regionfield( + indices.base.kind == Future + ) if values.base.kind == Future or values.base.transformed: - values = values._convert_future_to_regionfield() + values = values._convert_future_to_regionfield( + values.base.kind == Future + ) if self.base.kind == Future or self.base.transformed: - self = self._convert_future_to_regionfield() + self = self._convert_future_to_regionfield( + self.base.kind == Future + ) assert indices.size == values.size diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py index 8a307e73b..e82a08e02 100644 --- a/tests/integration/test_put.py +++ b/tests/integration/test_put.py @@ -43,9 +43,10 @@ def test_ndim(ndim): shape = (5,) * ndim np_arr = mk_seq_array(np, shape) num_arr = mk_seq_array(num, shape) - np_indices = mk_seq_array(np, (4 * ndim,)) - num_indices = mk_seq_array(num, (4 * ndim,)) - shape_val = (3,) * ndim + shape_in = (3,) * ndim + np_indices = mk_seq_array(np, shape_in) + num_indices = mk_seq_array(num, shape_in) + shape_val = (2,) * ndim np_values = mk_seq_array(np, shape_val) * 10 num_values = mk_seq_array(num, shape_val) * 10 @@ -54,21 +55,22 @@ def test_ndim(ndim): assert np.array_equal(np_arr, num_arr) +INDICES = ([1, 2, 3, 100], [[2, 1], [3, 100]], [1], [100]) + + @pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1)) @pytest.mark.parametrize("mode", ("wrap", "clip")) -def test_ndim_mode(ndim, mode): +@pytest.mark.parametrize("indices", INDICES) +def test_ndim_mode(ndim, mode, indices): shape = (5,) * ndim np_arr = mk_seq_array(np, shape) num_arr = mk_seq_array(num, shape) - shape_in = (3,) * ndim - np_indices = mk_seq_array(np, shape_in) * 2 - num_indices = mk_seq_array(num, shape_in) * 2 shape_val = (2,) * ndim np_values = mk_seq_array(np, shape_val) * 10 num_values = mk_seq_array(num, shape_val) * 10 - np.put(np_arr, np_indices, np_values, mode=mode) - num.put(num_arr, num_indices, num_values, mode=mode) + np.put(np_arr, indices, np_values, mode=mode) + num.put(num_arr, indices, num_values, mode=mode) assert np.array_equal(np_arr, num_arr) From e67d0f31802144f6fef79f18e043ac82beed32d1 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 13 Sep 2022 13:57:44 -0600 Subject: [PATCH 09/33] adding check for repeated entires in indices array --- cunumeric/array.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 74b0551a0..c2f26a46d 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2502,11 +2502,10 @@ def put( if indices.ndim > 1: indices = indices.ravel() - # in case size of the indices larger than size of the array - # some entries in the array will be updated 2 times. - # and Legate doesn't guarantee the order in which values + # in case if there are repeated entries in the indices array + # Legate doesn't guarantee the order in which values # are updated - if indices.size > self.size: + if indices.size > indices.unique().size: runtime.warn( "size of indices is larger than source array which" " might result in undefined behaviour", From 5b8a301c78898b1579fe0209fb9df24bd0a5568a Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 13 Sep 2022 15:36:13 -0600 Subject: [PATCH 10/33] fixing mypy errors --- cunumeric/deferred.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 53cb6bced..d285d6aea 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -785,7 +785,7 @@ def _convert_future_to_regionfield( self, future: bool = False ) -> DeferredArray: if future: - shape = (1,) + shape: NdShape = (1,) else: shape = self.shape store = self.context.create_store( From dd5f0a3129e6afbb01c979d9ab326f2bf49c00ba Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 19 Sep 2022 09:37:04 -0600 Subject: [PATCH 11/33] Update error message for wrong clip mode Co-authored-by: Bryan Van de Ven --- cunumeric/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index c2f26a46d..2b316e7b5 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2486,7 +2486,9 @@ def put( return if mode not in ("raise", "wrap", "clip"): - raise ValueError("clipmode not understood") + raise ValueError(f"clipmode must be one of 'clip', 'raise', or 'wrap' (got {mode})") + +") if mode == "wrap": indices = indices % self.size elif mode == "clip": From ca965ed3c50e6d35a508b97c6f3ebfdd9f22758d Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 19 Sep 2022 09:37:38 -0600 Subject: [PATCH 12/33] update warning message Co-authored-by: Bryan Van de Ven --- cunumeric/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 2b316e7b5..2a007ba5e 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2510,7 +2510,7 @@ def put( if indices.size > indices.unique().size: runtime.warn( "size of indices is larger than source array which" - " might result in undefined behaviour", + " might yield results different from NumPy", category=RuntimeWarning, ) # call _wrap on the values if they need to be wrapped From 6e055733afb10304cdd3cfdbd9ec1f63517a027c Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 19 Sep 2022 10:04:42 -0600 Subject: [PATCH 13/33] adding _warn_and_convert function --- cunumeric/array.py | 33 +++++++++++++++------------------ tests/integration/test_put.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index c2f26a46d..362ffb5b0 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -920,12 +920,8 @@ def _convert_key(self, key: Any, first: bool = True) -> Any: key = convert_to_cunumeric_ndarray(key) if key.dtype != bool and not np.issubdtype(key.dtype, np.integer): raise TypeError("index arrays should be int or bool type") - if key.dtype != bool and key.dtype != np.int64: - runtime.warn( - "converting index array to int64 type", - category=RuntimeWarning, - ) - key = key.astype(np.int64) + if key.dtype != bool: + key = key._warn_and_convert(np.int64) return key._thunk @@ -2093,12 +2089,8 @@ def compress( raise ValueError( "Dimension mismatch: condition must be a 1D array" ) - if condition.dtype != bool: - runtime.warn( - "converting condition to bool type", - category=RuntimeWarning, - ) - condition = condition.astype(bool) + + condition = condition._warn_and_convert(bool) if axis is None: axis = 0 @@ -2492,12 +2484,7 @@ def put( elif mode == "clip": indices = indices.clip(0, self.size - 1) - if indices.dtype != np.int64: - runtime.warn( - "converting index array to int64 type", - category=RuntimeWarning, - ) - indices = indices.astype(np.int64) + indices = indices._warn_and_convert(np.int64) if indices.ndim > 1: indices = indices.ravel() @@ -3863,6 +3850,16 @@ def _maybe_convert(self, dtype: np.dtype[Any], hints: Any) -> ndarray: copy._thunk.convert(self._thunk) return copy + def _warn_and_convert(self, dtype: np.dtype[Any]) -> ndarray: + if self.dtype != dtype: + runtime.warn( + f"converting array to {dtype} type", + category=RuntimeWarning, + ) + return self.astype(dtype) + else: + return self + # For performing normal/broadcast unary operations @classmethod def _perform_unary_op( diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py index e82a08e02..998a91303 100644 --- a/tests/integration/test_put.py +++ b/tests/integration/test_put.py @@ -38,6 +38,18 @@ def test_scalar(mode): assert np.array_equal(x_num, x) +def test_indices_type_convert(): + x = mk_seq_array(np, (3, 4, 5)) + x_num = mk_seq_array(num, (3, 4, 5)) + values = mk_seq_array(np, (6,)) * 10 + values_num = num.array(values) + indices = np.array([1, 2], dtype=np.int32) + indices_num = num.array(indices) + np.put(x, indices, values) + num.put(x_num, indices_num, values_num) + assert np.array_equal(x_num, x) + + @pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1)) def test_ndim(ndim): shape = (5,) * ndim From 9ead87fdb3dfef70ef7a0cba1adaf0762e0d7e23 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 19 Sep 2022 10:43:07 -0600 Subject: [PATCH 14/33] fixed formatting error --- cunumeric/array.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 00022f728..e905ad89c 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2478,9 +2478,11 @@ def put( return if mode not in ("raise", "wrap", "clip"): - raise ValueError(f"clipmode must be one of 'clip', 'raise', or 'wrap' (got {mode})") + raise ValueError( + "clipmode must be one of 'clip', 'raise', or 'wrap' " + f"(get {mode})" + ) -") if mode == "wrap": indices = indices % self.size elif mode == "clip": From 24cb09684d6bb3598272e2ef4795db8530ab88ef Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 19 Sep 2022 14:11:33 -0600 Subject: [PATCH 15/33] fixing mypy errors --- cunumeric/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index e905ad89c..a3a91a82e 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -921,7 +921,7 @@ def _convert_key(self, key: Any, first: bool = True) -> Any: if key.dtype != bool and not np.issubdtype(key.dtype, np.integer): raise TypeError("index arrays should be int or bool type") if key.dtype != bool: - key = key._warn_and_convert(np.int64) + key = key._warn_and_convert(np.dtype(np.int64)) return key._thunk @@ -2090,7 +2090,7 @@ def compress( "Dimension mismatch: condition must be a 1D array" ) - condition = condition._warn_and_convert(bool) + condition = condition._warn_and_convert(np.dtype(bool)) if axis is None: axis = 0 @@ -2488,7 +2488,7 @@ def put( elif mode == "clip": indices = indices.clip(0, self.size - 1) - indices = indices._warn_and_convert(np.int64) + indices = indices._warn_and_convert(np.dtype(np.int64)) if indices.ndim > 1: indices = indices.ravel() From 41e84063bed009dd746d6a92a91528635743c361 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 27 Sep 2022 15:55:46 -0600 Subject: [PATCH 16/33] addressing PR comments --- cunumeric/array.py | 16 +++++++++------- cunumeric/deferred.py | 4 +--- cunumeric/eager.py | 1 + cunumeric/module.py | 20 +++++++++----------- tests/integration/test_put.py | 9 ++++++++- 5 files changed, 28 insertions(+), 22 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 076ba5ad6..19f5989dc 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2005,7 +2005,7 @@ def choose( if not np.issubdtype(self.dtype, np.integer): raise TypeError("a array should be integer type") if self.dtype != np.int64: - a = a.astype(np.int64) + a = a._warn_and_convert(np.int64) if mode == "raise": if (a < 0).any() | (a >= n).any(): raise ValueError("invalid entry in choice array") @@ -2461,11 +2461,13 @@ def put( self, indices: ndarray, values: ndarray, mode: str = "raise" ) -> None: """ - Set storage-indexed locations to corresponding values. + Replaces specified elements of the array with given values. + + Refer to :func:`cunumeric.put` for full documentation. See Also -------- - numpy.put + cunumeric.put : equivalent function Availability -------- @@ -2478,8 +2480,8 @@ def put( if mode not in ("raise", "wrap", "clip"): raise ValueError( - "clipmode must be one of 'clip', 'raise', or 'wrap' " - f"(get {mode})" + "mode must be one of 'clip', 'raise', or 'wrap' " + f"(got {mode})" ) if mode == "wrap": @@ -3385,9 +3387,9 @@ def searchsorted( ch_dtype = np.find_common_type([a.dtype, v_ndarray.dtype], []) if v_ndarray.dtype != ch_dtype: - v_ndarray = v_ndarray.astype(ch_dtype) + v_ndarray = v_ndarray._warn_and_convert(ch_dtype) if a.dtype != ch_dtype: - a = a.astype(ch_dtype) + a = a._warn_and_convert(ch_dtype) if sorter is not None and a.shape[0] > 1: if sorter.ndim != 1: diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index a34afe417..a323800e6 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -1677,9 +1677,7 @@ def put(self, indices: Any, values: Any) -> None: values.base.kind == Future ) if self.base.kind == Future or self.base.transformed: - self = self._convert_future_to_regionfield( - self.base.kind == Future - ) + self = self.copy(self, deep=True) assert indices.size == values.size diff --git a/cunumeric/eager.py b/cunumeric/eager.py index df7be0ae2..1d733e4d2 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -620,6 +620,7 @@ def _diag_helper( self.array = diagonal_reference(rhs.array, axes) def put(self, indices: Any, values: Any) -> None: + self.check_eager_args(indices, values) if self.deferred is not None: self.deferred.put(indices, values) else: diff --git a/cunumeric/module.py b/cunumeric/module.py index 0a03ac516..d02ca2489 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -2387,11 +2387,7 @@ def repeat(a: ndarray, repeats: Any, axis: Optional[int] = None) -> ndarray: else: # repeats should be integer type if repeats.dtype != np.int64: - runtime.warn( - "converting repeats to an integer type", - category=RuntimeWarning, - ) - repeats = repeats.astype(np.int64) + repeats = repeats._warn_and_convert(np.int64) if repeats.shape[0] != array.shape[axis]: raise ValueError("incorrect shape of repeats array") result = array._thunk.repeat( @@ -2509,7 +2505,7 @@ def place(arr: ndarray, mask: ndarray, vals: ndarray) -> None: if mask_reshape.dtype == bool: arr._thunk.set_item(mask_reshape._thunk, vals_resized._thunk) else: - bool_mask = mask_reshape.astype(bool) + bool_mask = mask_reshape._warn_and_convert(bool) arr._thunk.set_item(bool_mask._thunk, vals_resized._thunk) @@ -2551,7 +2547,7 @@ def extract(condition: ndarray, arr: ndarray) -> ndarray: if condition_reshape.dtype == bool: thunk = arr._thunk.get_item(condition_reshape._thunk) else: - bool_condition = condition_reshape.astype(bool) + bool_condition = condition_reshape._warn_and_convert(bool) thunk = arr._thunk.get_item(bool_condition._thunk) return ndarray(shape=thunk.shape, thunk=thunk) @@ -3454,7 +3450,8 @@ def put( a: ndarray, indices: ndarray, values: ndarray, mode: str = "raise" ) -> None: """ - Set storage-indexed locations to corresponding values. + Replaces specified elements of an array with given values. + The indexing works as if the target array is first flattened. Parameters ---------- @@ -3463,7 +3460,8 @@ def put( indices : array_like Target indices, interpreted as integers. values : array_like - Values to place in `a` at target indices. + Values to place in `a` at target indices. If values array is shorter + than indices, it will be repeated as necessary. mode : {'raise', 'wrap', 'clip'}, optional Specifies how out-of-bounds indices will behave. 'raise' : raise an error. @@ -5446,7 +5444,7 @@ def convolve(a: ndarray, v: ndarray, mode: ConvolveMode = "full") -> ndarray: v, a = a, v if a.dtype != v.dtype: - v = v.astype(a.dtype) + v = v._warn_and_convert(a.dtype) out = ndarray( shape=a.shape, dtype=a.dtype, @@ -6170,7 +6168,7 @@ def bincount( if weights.dtype.kind == "c": raise ValueError("weights must be convertible to float64") # Make sure the weights are float64 - weights = weights.astype(np.float64) + weights = weights._warn_and_convert(np.float64) if x.dtype.kind != "i" and x.dtype.kind != "u": raise TypeError("input array for bincount must be integer type") if minlength < 0: diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py index 998a91303..0dd4312b2 100644 --- a/tests/integration/test_put.py +++ b/tests/integration/test_put.py @@ -15,10 +15,10 @@ import numpy as np import pytest +from legate.core import LEGATE_MAX_DIM from utils.generators import mk_seq_array import cunumeric as num -from legate.core import LEGATE_MAX_DIM @pytest.mark.parametrize("mode", ("wrap", "clip")) @@ -37,6 +37,13 @@ def test_scalar(mode): num.put(x_num, 1, -10, mode) assert np.array_equal(x_num, x) + # checking transformed array + y = x[:1] + y_num = x_num[:1] + np.put(y, 0, values) + num.put(y_num, 0, values_num) + assert np.array_equal(x_num, x) + def test_indices_type_convert(): x = mk_seq_array(np, (3, 4, 5)) From 4814b4cd98fe9108f47e2b8783b42a3c25b85b78 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Tue, 27 Sep 2022 16:07:45 -0700 Subject: [PATCH 17/33] Avoid emitting new warnings --- cunumeric/array.py | 6 +++--- cunumeric/module.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 19f5989dc..eeccc5738 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2005,7 +2005,7 @@ def choose( if not np.issubdtype(self.dtype, np.integer): raise TypeError("a array should be integer type") if self.dtype != np.int64: - a = a._warn_and_convert(np.int64) + a = a.astype(np.int64) if mode == "raise": if (a < 0).any() | (a >= n).any(): raise ValueError("invalid entry in choice array") @@ -3387,9 +3387,9 @@ def searchsorted( ch_dtype = np.find_common_type([a.dtype, v_ndarray.dtype], []) if v_ndarray.dtype != ch_dtype: - v_ndarray = v_ndarray._warn_and_convert(ch_dtype) + v_ndarray = v_ndarray.astype(ch_dtype) if a.dtype != ch_dtype: - a = a._warn_and_convert(ch_dtype) + a = a.astype(ch_dtype) if sorter is not None and a.shape[0] > 1: if sorter.ndim != 1: diff --git a/cunumeric/module.py b/cunumeric/module.py index d02ca2489..0d43555fc 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -2505,7 +2505,7 @@ def place(arr: ndarray, mask: ndarray, vals: ndarray) -> None: if mask_reshape.dtype == bool: arr._thunk.set_item(mask_reshape._thunk, vals_resized._thunk) else: - bool_mask = mask_reshape._warn_and_convert(bool) + bool_mask = mask_reshape.astype(bool) arr._thunk.set_item(bool_mask._thunk, vals_resized._thunk) @@ -2547,7 +2547,7 @@ def extract(condition: ndarray, arr: ndarray) -> ndarray: if condition_reshape.dtype == bool: thunk = arr._thunk.get_item(condition_reshape._thunk) else: - bool_condition = condition_reshape._warn_and_convert(bool) + bool_condition = condition_reshape.astype(bool) thunk = arr._thunk.get_item(bool_condition._thunk) return ndarray(shape=thunk.shape, thunk=thunk) @@ -5444,7 +5444,7 @@ def convolve(a: ndarray, v: ndarray, mode: ConvolveMode = "full") -> ndarray: v, a = a, v if a.dtype != v.dtype: - v = v._warn_and_convert(a.dtype) + v = v.astype(a.dtype) out = ndarray( shape=a.shape, dtype=a.dtype, @@ -6168,7 +6168,7 @@ def bincount( if weights.dtype.kind == "c": raise ValueError("weights must be convertible to float64") # Make sure the weights are float64 - weights = weights._warn_and_convert(np.float64) + weights = weights.astype(np.float64) if x.dtype.kind != "i" and x.dtype.kind != "u": raise TypeError("input array for bincount must be integer type") if minlength < 0: From 5f64cadcf2b92a033baf52a33e3e4ee892ce1143 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 27 Sep 2022 17:29:56 -0600 Subject: [PATCH 18/33] fixing logic for PUT in the case of transformed arrays --- cunumeric/deferred.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index a323800e6..65025f4d0 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -1677,14 +1677,18 @@ def put(self, indices: Any, values: Any) -> None: values.base.kind == Future ) if self.base.kind == Future or self.base.transformed: - self = self.copy(self, deep=True) + self_tmp = self._convert_future_to_regionfield( + self.base.kind == Future + ) + else: + self_tmp = self assert indices.size == values.size # first, we create indirect array with PointN type that # (indices.size,) shape and is used to copy data from values # to the target ND array (self) - N = self.ndim + N = self_tmp.ndim pointN_dtype = self.runtime.get_point_type(N) indirect = cast( DeferredArray, @@ -1697,7 +1701,7 @@ def put(self, indices: Any, values: Any) -> None: task = self.context.create_task(CuNumericOpCode.WRAP) task.add_output(indirect.base) - task.add_scalar_arg(self.shape, (ty.int64,)) + task.add_scalar_arg(self_tmp.shape, (ty.int64,)) task.add_scalar_arg(True, bool) # has_input task.add_input(indices.base) task.add_alignment(indices.base, indirect.base) @@ -1709,9 +1713,12 @@ def put(self, indices: Any, values: Any) -> None: copy.set_target_indirect_out_of_range(False) copy.add_input(values.base) copy.add_target_indirect(indirect.base) - copy.add_output(self.base) + copy.add_output(self_tmp.base) copy.execute() + if self_tmp is not self: + self.copy(self_tmp, deep=True) + # Create an identity array with the ones offset from the diagonal by k def eye(self, k: int) -> None: assert self.ndim == 2 # Only 2-D arrays should be here From d5c4414524200575931167d0a22374700b941f97 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Tue, 27 Sep 2022 18:41:39 -0700 Subject: [PATCH 19/33] _warn_and_convert checks the target type already --- cunumeric/module.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cunumeric/module.py b/cunumeric/module.py index 0d43555fc..dbc048b7e 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -2386,8 +2386,7 @@ def repeat(a: ndarray, repeats: Any, axis: Optional[int] = None) -> ndarray: # repeats is an array else: # repeats should be integer type - if repeats.dtype != np.int64: - repeats = repeats._warn_and_convert(np.int64) + repeats = repeats._warn_and_convert(np.int64) if repeats.shape[0] != array.shape[axis]: raise ValueError("incorrect shape of repeats array") result = array._thunk.repeat( From 9a8a3ca3f8ad8382f6a92e72729e57831ce4537b Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 27 Sep 2022 21:03:49 -0600 Subject: [PATCH 20/33] addressing PR comments --- cunumeric/array.py | 9 --------- cunumeric/module.py | 4 ++++ src/cunumeric/index/wrap.cc | 2 +- src/cunumeric/index/wrap.cu | 6 +++--- src/cunumeric/index/wrap.h | 11 +++++++---- src/cunumeric/index/wrap_omp.cc | 2 +- src/cunumeric/index/wrap_template.inl | 20 +++++++------------- tests/integration/test_put.py | 4 ++++ 8 files changed, 27 insertions(+), 31 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index eeccc5738..3755e4bd2 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2494,15 +2494,6 @@ def put( if indices.ndim > 1: indices = indices.ravel() - # in case if there are repeated entries in the indices array - # Legate doesn't guarantee the order in which values - # are updated - if indices.size > indices.unique().size: - runtime.warn( - "size of indices is larger than source array which" - " might yield results different from NumPy", - category=RuntimeWarning, - ) # call _wrap on the values if they need to be wrapped if values.ndim != indices.ndim or values.size != indices.size: values = values._wrap(indices.size) diff --git a/cunumeric/module.py b/cunumeric/module.py index dbc048b7e..d3b0d0c54 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -3458,6 +3458,10 @@ def put( Array to put data into indices : array_like Target indices, interpreted as integers. + WARNING: in case if there are repeated entries in the + indices array, Legate doesn't guarantee the order in + which values are updated + values : array_like Values to place in `a` at target indices. If values array is shorter than indices, it will be repeated as necessary. diff --git a/src/cunumeric/index/wrap.cc b/src/cunumeric/index/wrap.cc index 30f472710..87d362556 100644 --- a/src/cunumeric/index/wrap.cc +++ b/src/cunumeric/index/wrap.cc @@ -31,7 +31,7 @@ struct WrapImplBody { const Pitches& pitches_in, const Rect& in_rect, const bool dense, - IND& indices) const + const IND& indices) const { const int64_t start = out_rect.lo[0]; const int64_t end = out_rect.hi[0]; diff --git a/src/cunumeric/index/wrap.cu b/src/cunumeric/index/wrap.cu index ae6b4a1d9..e999477c8 100644 --- a/src/cunumeric/index/wrap.cu +++ b/src/cunumeric/index/wrap.cu @@ -33,7 +33,7 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) const Pitches pitches_in, const Point in_lo, const size_t in_volume, - IND indices) + const IND indices) { const auto idx = global_tid_1d(); if (idx >= volume) return; @@ -51,7 +51,7 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) const Pitches pitches_in, const Point in_lo, const size_t in_volume, - IND indices) + const IND indices) { const auto idx = global_tid_1d(); if (idx >= volume) return; @@ -69,7 +69,7 @@ struct WrapImplBody { const Pitches& pitches_in, const Rect& in_rect, const bool dense, - IND& indices) const + const IND& indices) const { auto stream = get_cached_stream(); const int64_t start = out_rect.lo[0]; diff --git a/src/cunumeric/index/wrap.h b/src/cunumeric/index/wrap.h index 7436ac4e8..495f549b8 100644 --- a/src/cunumeric/index/wrap.h +++ b/src/cunumeric/index/wrap.h @@ -26,7 +26,7 @@ struct WrapArgs { // `wrapped` one const Legion::DomainPoint shape; // shape of the original array const bool has_input; - const Array& in; + const Array& in = Array(); }; class WrapTask : public CuNumericTask { @@ -43,10 +43,13 @@ class WrapTask : public CuNumericTask { #endif }; -__CUDA_HD__ static int64_t compute_idx(int64_t i, int64_t volume, bool&) { return i % volume; } +__CUDA_HD__ static int64_t compute_idx(const int64_t i, const int64_t volume, const bool&) +{ + return i % volume; +} -__CUDA_HD__ static int64_t compute_idx(int64_t i, - int64_t, +__CUDA_HD__ static int64_t compute_idx(const int64_t i, + const int64_t, const legate::AccessorRO& indices) { return indices[i]; diff --git a/src/cunumeric/index/wrap_omp.cc b/src/cunumeric/index/wrap_omp.cc index 05ee2ef7c..185e003f8 100644 --- a/src/cunumeric/index/wrap_omp.cc +++ b/src/cunumeric/index/wrap_omp.cc @@ -31,7 +31,7 @@ struct WrapImplBody { const Pitches& pitches_in, const Rect& in_rect, const bool dense, - IND& indices) const + const IND& indices) const { const int64_t start = out_rect.lo[0]; const int64_t end = out_rect.hi[0]; diff --git a/src/cunumeric/index/wrap_template.inl b/src/cunumeric/index/wrap_template.inl index 42c84c9db..093f5f5b1 100644 --- a/src/cunumeric/index/wrap_template.inl +++ b/src/cunumeric/index/wrap_template.inl @@ -65,9 +65,6 @@ struct WrapImpl { auto in = args.in.read_accessor(in_rect); // input should be always integer type #ifdef DEBUG_CUNUMERIC assert(in_rect == out_rect); -#endif -#ifndef LEGION_BOUNDS_CHECKS - dense = dense && in.accessor.is_dense_row_major(out_rect); #endif WrapImplBody()(out, pitches_out, out_rect, pitches_in, input_rect, dense, in); @@ -81,16 +78,13 @@ struct WrapImpl { template static void wrap_template(TaskContext& context) { - auto shape = context.scalars()[0].value(); - int dim = shape.dim; - bool has_input = context.scalars()[1].value(); - if (has_input) { - WrapArgs args1{context.outputs()[0], shape, has_input, context.inputs()[0]}; - dim_dispatch(dim, WrapImpl{}, args1); - } else { - WrapArgs args2{context.outputs()[0], shape, has_input, Array()}; - dim_dispatch(dim, WrapImpl{}, args2); - } + auto shape = context.scalars()[0].value(); + int dim = shape.dim; + bool has_input = context.scalars()[1].value(); + Array tmp_array = Array(); + WrapArgs args{ + context.outputs()[0], shape, has_input, has_input ? context.inputs()[0] : tmp_array}; + dim_dispatch(dim, WrapImpl{}, args); } } // namespace cunumeric diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py index 0dd4312b2..5d5ceb438 100644 --- a/tests/integration/test_put.py +++ b/tests/integration/test_put.py @@ -44,6 +44,10 @@ def test_scalar(mode): num.put(y_num, 0, values_num) assert np.array_equal(x_num, x) + x = num.zeros(1) + num.put(x, num.arange(4), num.ones(4), mode="clip") + print(x) # prints [0] + def test_indices_type_convert(): x = mk_seq_array(np, (3, 4, 5)) From c91042a582a31a6ee7eaf221a8f29ae92e50d835 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Tue, 27 Sep 2022 20:26:49 -0700 Subject: [PATCH 21/33] Typo --- cunumeric/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cunumeric/module.py b/cunumeric/module.py index d3b0d0c54..6eb88580e 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -3458,9 +3458,9 @@ def put( Array to put data into indices : array_like Target indices, interpreted as integers. - WARNING: in case if there are repeated entries in the + WARNING: In case there are repeated entries in the indices array, Legate doesn't guarantee the order in - which values are updated + which values are updated. values : array_like Values to place in `a` at target indices. If values array is shorter From 5d6c9fa049c767ca95c544849d7d5d90f9ef0cd3 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Wed, 28 Sep 2022 10:47:37 -0600 Subject: [PATCH 22/33] adding check for out-of-the-bounds indices --- cunumeric/deferred.py | 1 + src/cunumeric/index/wrap.cc | 4 +++- src/cunumeric/index/wrap.h | 13 +++++++++++++ src/cunumeric/index/wrap_omp.cc | 2 ++ 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 65025f4d0..7bf10168c 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -1705,6 +1705,7 @@ def put(self, indices: Any, values: Any) -> None: task.add_scalar_arg(True, bool) # has_input task.add_input(indices.base) task.add_alignment(indices.base, indirect.base) + task.throws_exception(IndexError) task.execute() if indirect.base.kind == Future or indirect.base.transformed: indirect = indirect._convert_future_to_regionfield() diff --git a/src/cunumeric/index/wrap.cc b/src/cunumeric/index/wrap.cc index 87d362556..a5483cbdd 100644 --- a/src/cunumeric/index/wrap.cc +++ b/src/cunumeric/index/wrap.cc @@ -39,13 +39,15 @@ struct WrapImplBody { if (dense) { auto outptr = out.ptr(out_rect); for (int64_t i = start; i <= end; i++) { + check_idx(i, in_volume, indices); const int64_t input_idx = compute_idx(i, in_volume, indices); auto point = pitches_in.unflatten(input_idx, in_rect.lo); outptr[i - start] = point; } } else { for (int64_t i = start; i <= end; i++) { - const int64_t input_idx = compute_idx(i, in_volume, indices); // i % in_volume; + check_idx(i, in_volume, indices); + const int64_t input_idx = compute_idx(i, in_volume, indices); auto point = pitches_in.unflatten(input_idx, in_rect.lo); out[i] = point; } diff --git a/src/cunumeric/index/wrap.h b/src/cunumeric/index/wrap.h index 495f549b8..aa91c827d 100644 --- a/src/cunumeric/index/wrap.h +++ b/src/cunumeric/index/wrap.h @@ -54,4 +54,17 @@ __CUDA_HD__ static int64_t compute_idx(const int64_t i, { return indices[i]; } + +static void check_idx(const int64_t i, + const int64_t volume, + const legate::AccessorRO& indices) +{ + int64_t index = indices[i]; + if (index < 0 || index >= volume) + throw legate::TaskException("index is out of bounds in index array"); +} +static void check_idx(const int64_t i, const int64_t volume, const bool&) +{ + // don't do anything when wrapping indices +} } // namespace cunumeric diff --git a/src/cunumeric/index/wrap_omp.cc b/src/cunumeric/index/wrap_omp.cc index 185e003f8..531592df9 100644 --- a/src/cunumeric/index/wrap_omp.cc +++ b/src/cunumeric/index/wrap_omp.cc @@ -40,6 +40,7 @@ struct WrapImplBody { auto outptr = out.ptr(out_rect); #pragma omp parallel for schedule(static) for (int64_t i = start; i <= end; i++) { + check_idx(i, in_volume, indices); const int64_t input_idx = compute_idx(i, in_volume, indices); auto point = pitches_in.unflatten(input_idx, in_rect.lo); outptr[i - start] = point; @@ -47,6 +48,7 @@ struct WrapImplBody { } else { #pragma omp parallel for schedule(static) for (int64_t i = start; i <= end; i++) { + check_idx(i, in_volume, indices); const int64_t input_idx = compute_idx(i, in_volume, indices); auto point = pitches_in.unflatten(input_idx, in_rect.lo); out[i] = point; From 4ffae77f97d7a6d60f43a6f3cdd0882d319d81ae Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Wed, 28 Sep 2022 11:07:58 -0600 Subject: [PATCH 23/33] fixing the case when scalar walue needs to be wrapped --- cunumeric/deferred.py | 5 +++-- tests/integration/test_put.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 7bf10168c..0a8ea22f6 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -3394,9 +3394,10 @@ def unpackbits( task.execute() @auto_convert([1]) - def _wrap(self, src: Any, new_len: int) -> None: + def _wrap(self, src: DeferredArray, new_len: int) -> None: + src = self.runtime.to_deferred_array(src) if src.base.kind == Future or src.base.transformed: - src = src._convert_future_to_regionfield() + src = src._convert_future_to_regionfield(src.base.kind == Future) # first, we create indirect array with PointN type that # (len,) shape and is used to copy data from original array diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py index 5d5ceb438..80cfb089a 100644 --- a/tests/integration/test_put.py +++ b/tests/integration/test_put.py @@ -44,9 +44,19 @@ def test_scalar(mode): num.put(y_num, 0, values_num) assert np.array_equal(x_num, x) - x = num.zeros(1) - num.put(x, num.arange(4), num.ones(4), mode="clip") - print(x) # prints [0] + x = np.zeros(1) + x_num = num.zeros(1) + np.put(x, np.arange(4), np.ones(4), mode="clip") + num.put(x_num, num.arange(4), num.ones(4), mode="clip") + assert np.array_equal(x_num, x) + + x = np.arange(5) + x_num = num.array(x) + indices = np.array([1, 4]) + indices_num = num.array(indices) + np.put(x, indices, 10) + num.put(x_num, indices_num, 10) + assert np.array_equal(x_num, x) def test_indices_type_convert(): From f8af1dece8fa82501a929bc5f1ce74821d691f73 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Thu, 6 Oct 2022 14:36:48 -0600 Subject: [PATCH 24/33] adding bounds check to the cuda kernel --- src/cunumeric/index/wrap.cu | 56 +++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/cunumeric/index/wrap.cu b/src/cunumeric/index/wrap.cu index e999477c8..1e60d1367 100644 --- a/src/cunumeric/index/wrap.cu +++ b/src/cunumeric/index/wrap.cu @@ -23,6 +23,26 @@ namespace cunumeric { using namespace Legion; using namespace legate; +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + check_kernel(Output out, + const AccessorRO indices, + const int64_t start, + const int64_t volume, + const int64_t in_volume, + const int64_t iters) +{ + bool value = false; + for (size_t i = 0; i < iters; i++) { + const auto idx = (i * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= volume) break; + auto index = indices[idx + start]; + bool val = (index < 0 || index >= in_volume); + SumReduction::fold(value, val); + } + reduce_output(out, value); +} + template __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) wrap_kernel(const AccessorWO, 1> out, @@ -60,6 +80,39 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[idx] = p; } +// don't do anything when indices is a boolean +void check_out_of_bounds(const bool& indices, + const int64_t start, + const int64_t volume, + const int64_t volume_in, + cudaStream_t stream) +{ +} + +void check_out_of_bounds(const AccessorRO& indices, + const int64_t start, + const int64_t volume, + const int64_t volume_in, + cudaStream_t stream) +{ + const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(int64_t); + DeviceScalarReductionBuffer> out_of_bounds(stream); + + if (blocks >= MAX_REDUCTION_CTAS) { + const size_t iters = (blocks + MAX_REDUCTION_CTAS - 1) / MAX_REDUCTION_CTAS; + check_kernel<<>>( + out_of_bounds, indices, start, volume, volume_in, iters); + } else { + check_kernel<<>>( + out_of_bounds, indices, start, volume, volume_in, 1); + } + CHECK_CUDA_STREAM(stream); + + bool res = out_of_bounds.read(stream); + if (res) throw legate::TaskException("index is out of bounds in index array"); +} + template struct WrapImplBody { template @@ -76,6 +129,9 @@ struct WrapImplBody { const int64_t volume = out_rect.volume(); const auto in_volume = in_rect.volume(); const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + + check_out_of_bounds(indices, start, volume, in_volume, stream); + if (dense) { auto outptr = out.ptr(out_rect); wrap_kernel_dense<<>>( From f5b92d3c2b9cd6e5c93129e2a157f94101fca3a1 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Thu, 6 Oct 2022 14:40:19 -0600 Subject: [PATCH 25/33] changing name of a bool variable in _convert_future_to_regionfield method --- cunumeric/deferred.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 0a8ea22f6..82c40c2ae 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -782,9 +782,9 @@ def _broadcast(self, shape: NdShape) -> Any: return result def _convert_future_to_regionfield( - self, future: bool = False + self, change_shape: bool = False ) -> DeferredArray: - if future: + if change_shape: shape: NdShape = (1,) else: shape = self.shape From c893ee5b92d621c1962bbb96cc21559ffc5d42ff Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Fri, 7 Oct 2022 13:28:31 -0600 Subject: [PATCH 26/33] fixing the cases for scalar lhs in put operation --- cunumeric/array.py | 11 ++++++++++- cunumeric/deferred.py | 36 ++++++++++++++++++++--------------- tests/integration/test_put.py | 18 ++++++++++++++++++ 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 3755e4bd2..ce8e92b0b 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2475,7 +2475,7 @@ def put( """ - if values.size == 0 or indices.size == 0: + if values.size == 0 or indices.size == 0 or self.size == 0: return if mode not in ("raise", "wrap", "clip"): @@ -2490,10 +2490,19 @@ def put( indices = indices.clip(0, self.size - 1) indices = indices._warn_and_convert(np.dtype(np.int64)) + values = values._warn_and_convert(self.dtype) if indices.ndim > 1: indices = indices.ravel() + if self.shape == (): + if values.shape == (): + v = values + else: + v = values[0] + self._thunk.copy(v._thunk, deep=False) + return + # call _wrap on the values if they need to be wrapped if values.ndim != indices.ndim or values.size != indices.size: values = values._wrap(indices.size) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 82c40c2ae..0d377df24 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -784,7 +784,7 @@ def _broadcast(self, shape: NdShape) -> Any: def _convert_future_to_regionfield( self, change_shape: bool = False ) -> DeferredArray: - if change_shape: + if change_shape and self.shape == (): shape: NdShape = (1,) else: shape = self.shape @@ -1668,18 +1668,17 @@ def _diag_helper( @auto_convert([1, 2]) def put(self, indices: Any, values: Any) -> None: + if indices.base.kind == Future or indices.base.transformed: - indices = indices._convert_future_to_regionfield( - indices.base.kind == Future - ) + change_shape = indices.base.kind == Future + indices = indices._convert_future_to_regionfield(change_shape) if values.base.kind == Future or values.base.transformed: - values = values._convert_future_to_regionfield( - values.base.kind == Future - ) + change_shape = values.base.kind == Future + values = values._convert_future_to_regionfield(change_shape) + if self.base.kind == Future or self.base.transformed: - self_tmp = self._convert_future_to_regionfield( - self.base.kind == Future - ) + change_shape = self.base.kind == Future + self_tmp = self._convert_future_to_regionfield(change_shape) else: self_tmp = self @@ -1699,15 +1698,16 @@ def put(self, indices: Any, values: Any) -> None: ), ) + shape = self_tmp.shape task = self.context.create_task(CuNumericOpCode.WRAP) task.add_output(indirect.base) - task.add_scalar_arg(self_tmp.shape, (ty.int64,)) + task.add_scalar_arg(shape, (ty.int64,)) task.add_scalar_arg(True, bool) # has_input task.add_input(indices.base) task.add_alignment(indices.base, indirect.base) task.throws_exception(IndexError) task.execute() - if indirect.base.kind == Future or indirect.base.transformed: + if indirect.base.kind == Future: indirect = indirect._convert_future_to_regionfield() copy = self.context.create_copy() @@ -2937,8 +2937,13 @@ def unary_op( args: Any, multiout: Optional[Any] = None, ) -> None: - lhs = self.base - rhs = src._broadcast(lhs.shape) + + if self.shape == () and self.size == src.size: + lhs = self._broadcast(src.shape) + rhs = src.base + else: + lhs = self.base + rhs = src._broadcast(lhs.shape) task = self.context.create_auto_task(CuNumericOpCode.UNARY_OP) task.add_output(lhs) @@ -3397,7 +3402,8 @@ def unpackbits( def _wrap(self, src: DeferredArray, new_len: int) -> None: src = self.runtime.to_deferred_array(src) if src.base.kind == Future or src.base.transformed: - src = src._convert_future_to_regionfield(src.base.kind == Future) + change_shape = src.base.kind == Future + src = src._convert_future_to_regionfield(change_shape) # first, we create indirect array with PointN type that # (len,) shape and is used to copy data from original array diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py index 80cfb089a..0cd6268c7 100644 --- a/tests/integration/test_put.py +++ b/tests/integration/test_put.py @@ -58,6 +58,24 @@ def test_scalar(mode): num.put(x_num, indices_num, 10) assert np.array_equal(x_num, x) + x = np.zeros(()) + x_num = num.zeros(()) + np.put(x, 0, 1) + num.put(x_num, 0, 1) + assert np.array_equal(x_num, x) + + x = np.zeros(()) + x_num = num.zeros(()) + np.put(x, [0], 1) + num.put(x_num, [0], 1) + assert np.array_equal(x_num, x) + + x = np.zeros(()) + x_num = num.zeros(()) + np.put(x, [0], [1]) + num.put(x_num, [0], [1]) + assert np.array_equal(x_num, x) + def test_indices_type_convert(): x = mk_seq_array(np, (3, 4, 5)) From 9edae6294aabbfb5a183e4bd41d32cdfef288663 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Fri, 7 Oct 2022 15:05:38 -0600 Subject: [PATCH 27/33] fixing out of the bounds check for ZIP cuda kernel --- src/cunumeric/index/zip.cu | 98 ++++++++++++++++++++++++++++++-------- src/cunumeric/index/zip.h | 6 +++ 2 files changed, 83 insertions(+), 21 deletions(-) diff --git a/src/cunumeric/index/zip.cu b/src/cunumeric/index/zip.cu index 8bdfcd3f0..a245a9f00 100644 --- a/src/cunumeric/index/zip.cu +++ b/src/cunumeric/index/zip.cu @@ -28,15 +28,15 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) const Buffer, 1> index_arrays, const Rect rect, const Pitches pitches, - size_t volume, - DomainPoint shape, + const size_t volume, + const DomainPoint shape, std::index_sequence) { const size_t idx = global_tid_1d(); if (idx >= volume) return; auto p = pitches.unflatten(idx, rect.lo); Legion::Point new_point; - for (size_t i = 0; i < N; i++) { new_point[i] = compute_idx(index_arrays[i][p], shape[i]); } + for (size_t i = 0; i < N; i++) { new_point[i] = compute_idx_cuda(index_arrays[i][p], shape[i]); } out[p] = new_point; } @@ -45,14 +45,16 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) zip_kernel_dense(Point* out, const Buffer index_arrays, const Rect rect, - size_t volume, - DomainPoint shape, + const size_t volume, + const DomainPoint shape, std::index_sequence) { const size_t idx = global_tid_1d(); if (idx >= volume) return; Legion::Point new_point; - for (size_t i = 0; i < N; i++) { new_point[i] = compute_idx(index_arrays[i][idx], shape[i]); } + for (size_t i = 0; i < N; i++) { + new_point[i] = compute_idx_cuda(index_arrays[i][idx], shape[i]); + } out[idx] = new_point; } @@ -62,11 +64,11 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) const Buffer, 1> index_arrays, const Rect rect, const Pitches pitches, - int narrays, - size_t volume, - int64_t key_dim, - int64_t start_index, - DomainPoint shape) + const int64_t narrays, + const size_t volume, + const int64_t key_dim, + const int64_t start_index, + const DomainPoint shape) { const size_t idx = global_tid_1d(); if (idx >= volume) return; @@ -74,7 +76,7 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) Legion::Point new_point; for (size_t i = 0; i < start_index; i++) { new_point[i] = p[i]; } for (size_t i = 0; i < narrays; i++) { - new_point[start_index + i] = compute_idx(index_arrays[i][p], shape[start_index + i]); + new_point[start_index + i] = compute_idx_cuda(index_arrays[i][p], shape[start_index + i]); } for (size_t i = (start_index + narrays); i < N; i++) { int64_t j = key_dim + i - narrays; @@ -83,10 +85,63 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[p] = new_point; } +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + check_kernel(Output out, + const Buffer, 1> index_arrays, + const int64_t volume, + const int64_t iters, + const Rect rect, + const Pitches pitches, + const int64_t narrays, + const int64_t start_index, + const DomainPoint shape) +{ + bool value = false; + for (size_t i = 0; i < iters; i++) { + const auto idx = (i * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= volume) break; + auto p = pitches.unflatten(idx, rect.lo); + for (size_t n = 0; n < narrays; n++) { + const int64_t extent = shape[start_index + n]; + coord_t index = index_arrays[n][p] < 0 ? index_arrays[n][p] + extent : index_arrays[n][p]; + bool val = (index < 0 || index > extent); + SumReduction::fold(value, val); + } // for n + } + reduce_output(out, value); +} + template struct ZipImplBody { using VAL = int64_t; + void check_out_of_bounds(const Buffer, 1>& index_arrays, + const int64_t volume, + const Rect& rect, + const Pitches& pitches, + const int64_t narrays, + const int64_t start_index, + const DomainPoint& shape, + cudaStream_t stream) const + { + const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(int64_t); + DeviceScalarReductionBuffer> out_of_bounds(stream); + if (blocks >= MAX_REDUCTION_CTAS) { + const size_t iters = (blocks + MAX_REDUCTION_CTAS - 1) / MAX_REDUCTION_CTAS; + check_kernel<<>>( + out_of_bounds, index_arrays, volume, iters, rect, pitches, narrays, start_index, shape); + } else { + check_kernel<<>>( + out_of_bounds, index_arrays, volume, 1, rect, pitches, narrays, start_index, shape); + } + CHECK_CUDA_STREAM(stream); + + bool res = out_of_bounds.read(stream); + if (res) throw legate::TaskException("index is out of bounds in index array"); + } + template void operator()(const AccessorWO, DIM>& out, const std::vector>& index_arrays, @@ -101,19 +156,23 @@ struct ZipImplBody { auto stream = get_cached_stream(); const size_t volume = rect.volume(); const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + + auto index_buf = + create_buffer, 1>(index_arrays.size(), Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < index_arrays.size(); ++idx) index_buf[idx] = index_arrays[idx]; + check_out_of_bounds( + index_buf, volume, rect, pitches, index_arrays.size(), start_index, shape, stream); + if (index_arrays.size() == N) { if (dense) { - auto index_buf = + auto index_buf_dense = create_buffer(index_arrays.size(), Memory::Kind::Z_COPY_MEM); for (uint32_t idx = 0; idx < index_arrays.size(); ++idx) { - index_buf[idx] = index_arrays[idx].ptr(rect); + index_buf_dense[idx] = index_arrays[idx].ptr(rect); } zip_kernel_dense<<>>( - out.ptr(rect), index_buf, rect, volume, shape, std::make_index_sequence()); + out.ptr(rect), index_buf_dense, rect, volume, shape, std::make_index_sequence()); } else { - auto index_buf = - create_buffer, 1>(index_arrays.size(), Memory::Kind::Z_COPY_MEM); - for (uint32_t idx = 0; idx < index_arrays.size(); ++idx) index_buf[idx] = index_arrays[idx]; zip_kernel<<>>( out, index_buf, rect, pitches, volume, shape, std::make_index_sequence()); } @@ -121,9 +180,6 @@ struct ZipImplBody { #ifdef DEBUG_CUNUMERIC assert(index_arrays.size() < N); #endif - auto index_buf = - create_buffer, 1>(index_arrays.size(), Memory::Kind::Z_COPY_MEM); - for (uint32_t idx = 0; idx < index_arrays.size(); ++idx) index_buf[idx] = index_arrays[idx]; int num_arrays = index_arrays.size(); zip_kernel<<>>( out, index_buf, rect, pitches, num_arrays, volume, key_dim, start_index, shape); diff --git a/src/cunumeric/index/zip.h b/src/cunumeric/index/zip.h index 61a87104c..ffa5941d5 100644 --- a/src/cunumeric/index/zip.h +++ b/src/cunumeric/index/zip.h @@ -51,4 +51,10 @@ constexpr coord_t compute_idx(coord_t index, coord_t extent) return new_index; } +constexpr coord_t compute_idx_cuda(coord_t index, coord_t extent) +{ + coord_t new_index = index < 0 ? index + extent : index; + return new_index; +} + } // namespace cunumeric From 7f59b15d37e815b7036c37db3d85b1e7bc73afb2 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Fri, 7 Oct 2022 15:18:46 -0600 Subject: [PATCH 28/33] fixing logic for negative indices --- src/cunumeric/index/wrap.cu | 5 +++-- src/cunumeric/index/wrap.h | 9 ++++++--- tests/integration/test_put.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/cunumeric/index/wrap.cu b/src/cunumeric/index/wrap.cu index 1e60d1367..47ae084b3 100644 --- a/src/cunumeric/index/wrap.cu +++ b/src/cunumeric/index/wrap.cu @@ -36,8 +36,9 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) for (size_t i = 0; i < iters; i++) { const auto idx = (i * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x; if (idx >= volume) break; - auto index = indices[idx + start]; - bool val = (index < 0 || index >= in_volume); + auto index_tmp = indices[idx + start]; + int64_t index = index_tmp < 0 ? index_tmp + in_volume : index_tmp; + bool val = (index < 0 || index >= in_volume); SumReduction::fold(value, val); } reduce_output(out, value); diff --git a/src/cunumeric/index/wrap.h b/src/cunumeric/index/wrap.h index aa91c827d..181a9b97c 100644 --- a/src/cunumeric/index/wrap.h +++ b/src/cunumeric/index/wrap.h @@ -49,17 +49,20 @@ __CUDA_HD__ static int64_t compute_idx(const int64_t i, const int64_t volume, co } __CUDA_HD__ static int64_t compute_idx(const int64_t i, - const int64_t, + const int64_t volume, const legate::AccessorRO& indices) { - return indices[i]; + int64_t idx = indices[i]; + int64_t index = idx < 0 ? idx + volume : idx; + return index; } static void check_idx(const int64_t i, const int64_t volume, const legate::AccessorRO& indices) { - int64_t index = indices[i]; + int64_t idx = indices[i]; + int64_t index = idx < 0 ? idx + volume : idx; if (index < 0 || index >= volume) throw legate::TaskException("index is out of bounds in index array"); } diff --git a/tests/integration/test_put.py b/tests/integration/test_put.py index 0cd6268c7..1c69a705b 100644 --- a/tests/integration/test_put.py +++ b/tests/integration/test_put.py @@ -82,7 +82,7 @@ def test_indices_type_convert(): x_num = mk_seq_array(num, (3, 4, 5)) values = mk_seq_array(np, (6,)) * 10 values_num = num.array(values) - indices = np.array([1, 2], dtype=np.int32) + indices = np.array([-2, 2], dtype=np.int32) indices_num = num.array(indices) np.put(x, indices, values) num.put(x_num, indices_num, values_num) From 67aa0adda15977af9a43e222619d88f1fb4c8582 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Thu, 13 Oct 2022 16:50:14 -0700 Subject: [PATCH 29/33] Update a leftover use of auto_convert --- cunumeric/deferred.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 1be5a2702..85a0fd1f5 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -1685,7 +1685,7 @@ def _diag_helper( task.execute() - @auto_convert([1, 2]) + @auto_convert("indices", "values") def put(self, indices: Any, values: Any) -> None: if indices.base.kind == Future or indices.base.transformed: From 88cca6dfc80d5a1b8f2f0ebeb75998f8dff3bc65 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Thu, 13 Oct 2022 23:07:42 -0600 Subject: [PATCH 30/33] addressing PR comments --- cunumeric/array.py | 2 ++ src/cunumeric/index/advanced_indexing.cu | 2 +- src/cunumeric/index/zip.cu | 4 ++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 5ff7eeda9..c03f0542c 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2508,6 +2508,8 @@ def put( indices = indices.ravel() if self.shape == (): + if indices.any() is True: + raise ValueError("Indices out of bounds") if values.shape == (): v = values else: diff --git a/src/cunumeric/index/advanced_indexing.cu b/src/cunumeric/index/advanced_indexing.cu index fde5590fd..706c8749d 100644 --- a/src/cunumeric/index/advanced_indexing.cu +++ b/src/cunumeric/index/advanced_indexing.cu @@ -94,7 +94,7 @@ struct AdvancedIndexingImplBody { const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; - size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(int64_t); + size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(bool); if (blocks >= MAX_REDUCTION_CTAS) { const size_t iters = (blocks + MAX_REDUCTION_CTAS - 1) / MAX_REDUCTION_CTAS; diff --git a/src/cunumeric/index/zip.cu b/src/cunumeric/index/zip.cu index a245a9f00..82d162126 100644 --- a/src/cunumeric/index/zip.cu +++ b/src/cunumeric/index/zip.cu @@ -105,7 +105,7 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) for (size_t n = 0; n < narrays; n++) { const int64_t extent = shape[start_index + n]; coord_t index = index_arrays[n][p] < 0 ? index_arrays[n][p] + extent : index_arrays[n][p]; - bool val = (index < 0 || index > extent); + bool val = (index < 0 || index >= extent); SumReduction::fold(value, val); } // for n } @@ -126,7 +126,7 @@ struct ZipImplBody { cudaStream_t stream) const { const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; - size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(int64_t); + size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(bool); DeviceScalarReductionBuffer> out_of_bounds(stream); if (blocks >= MAX_REDUCTION_CTAS) { const size_t iters = (blocks + MAX_REDUCTION_CTAS - 1) / MAX_REDUCTION_CTAS; From 92cf417586d18ee54d0fc68258419d6d482298a8 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 17 Oct 2022 14:02:17 -0600 Subject: [PATCH 31/33] fixing logic for the bounds check --- cunumeric/array.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index c03f0542c..c6791f91a 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2508,8 +2508,9 @@ def put( indices = indices.ravel() if self.shape == (): - if indices.any() is True: - raise ValueError("Indices out of bounds") + if mode == "raise": + if indices.min() < 0 or indices.max() > 0: + raise ValueError("Indices out of bounds") if values.shape == (): v = values else: From e489d565255c4e6d3d1ce0424b6e5fe44613e6bd Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 17 Oct 2022 22:17:27 -0600 Subject: [PATCH 32/33] addressing PR comments --- cunumeric/array.py | 2 +- src/cunumeric/index/wrap.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index c6791f91a..cd14eda7c 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -2509,7 +2509,7 @@ def put( if self.shape == (): if mode == "raise": - if indices.min() < 0 or indices.max() > 0: + if indices.min() < -1 or indices.max() > 0: raise ValueError("Indices out of bounds") if values.shape == (): v = values diff --git a/src/cunumeric/index/wrap.cu b/src/cunumeric/index/wrap.cu index 47ae084b3..af81073d6 100644 --- a/src/cunumeric/index/wrap.cu +++ b/src/cunumeric/index/wrap.cu @@ -97,7 +97,7 @@ void check_out_of_bounds(const AccessorRO& indices, cudaStream_t stream) { const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; - size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(int64_t); + size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(bool); DeviceScalarReductionBuffer> out_of_bounds(stream); if (blocks >= MAX_REDUCTION_CTAS) { From 58cd174075fc424a5520df940e8e03b310d6c0d8 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 18 Oct 2022 14:32:02 -0600 Subject: [PATCH 33/33] addressing PR comments --- cunumeric/deferred.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 85a0fd1f5..3bb5c4db7 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -2957,12 +2957,8 @@ def unary_op( multiout: Optional[Any] = None, ) -> None: - if self.shape == () and self.size == src.size: - lhs = self._broadcast(src.shape) - rhs = src.base - else: - lhs = self.base - rhs = src._broadcast(lhs.shape) + lhs = self.base + rhs = src._broadcast(lhs.shape) task = self.context.create_auto_task(CuNumericOpCode.UNARY_OP) task.add_output(lhs)