diff --git a/dpctl/tensor/_utility_functions.py b/dpctl/tensor/_utility_functions.py index 69a1a200df..aab32327f2 100644 --- a/dpctl/tensor/_utility_functions.py +++ b/dpctl/tensor/_utility_functions.py @@ -36,10 +36,13 @@ def _boolean_reduction(x, axis, keepdims, func): res_usm_type = x.usm_type wait_list = [] + # always allocate the temporary as + # int32 and usm-device to ensure that atomic updates + # are supported res_tmp = dpt.empty( res_shape, dtype=dpt.int32, - usm_type=res_usm_type, + usm_type="device", sycl_queue=exec_q, ) hev0, ev0 = func(