diff --git a/cunumeric/_ufunc/ufunc.py b/cunumeric/_ufunc/ufunc.py index 47ae475b4..d7ad95366 100644 --- a/cunumeric/_ufunc/ufunc.py +++ b/cunumeric/_ufunc/ufunc.py @@ -264,7 +264,10 @@ def _maybe_convert_output_to_cunumeric_ndarray( raise TypeError("return arrays must be of ArrayType") def _prepare_operands( - self, *args: Any, out: Union[ndarray, None], where: bool = True + self, + *args: Any, + out: Union[ndarray, tuple[ndarray, ...], None], + where: bool = True, ) -> tuple[ Sequence[ndarray], Sequence[Union[ndarray, None]], @@ -295,6 +298,8 @@ def _prepare_operands( computed_out = (None,) * self.nout elif not isinstance(out, tuple): computed_out = (out,) + else: + computed_out = out outputs = tuple( self._maybe_convert_output_to_cunumeric_ndarray(arr) @@ -469,7 +474,7 @@ def _resolve_dtype( def __call__( self, *args: Any, - out: Union[ndarray, None] = None, + out: Union[ndarray, tuple[ndarray, ...], None] = None, where: bool = True, casting: CastingKind = "same_kind", order: str = "K", diff --git a/tests/integration/test_floating.py b/tests/integration/test_floating.py index 837c50a06..9018dba14 100644 --- a/tests/integration/test_floating.py +++ b/tests/integration/test_floating.py @@ -141,4 +141,4 @@ def test_typing_unary(fun, dtype, shape): if __name__ == "__main__": import sys - pytest.main(sys.argv) + sys.exit(pytest.main(sys.argv))