diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index ae29bdd9..a73b67fa 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -45,7 +45,7 @@ def dot(a, b, out=None): conv_factor = _get_conversion_factor(out, prod_units) np.dot._implementation(a.ndview, b.ndview, out=out.ndview) - if not out.units == prod_units: + if out.units != prod_units: out[:] *= conv_factor return out @@ -69,7 +69,7 @@ def outer(a, b, out=None): conv_factor = _get_conversion_factor(out, prod_units) np.outer._implementation(a.ndview, b.ndview, out=out.ndview) - if not out.units == prod_units: + if out.units != prod_units: out[:] *= conv_factor return out @@ -87,7 +87,7 @@ def matmul(a, b, out=None, **kwargs): conv_factor = _get_conversion_factor(out, prod_units) np.matmul._implementation(a.ndview, b.ndview, out=out.ndview) - if not out.units == prod_units: + if out.units != prod_units: out[:] *= conv_factor return out @@ -203,17 +203,37 @@ def _validate_units_consistency(arrs): @implements(np.concatenate) def concatenate(arrs, /, axis=0, out=None, dtype=None, casting="same_kind"): _validate_units_consistency(arrs) - if NUMPY_VERSION >= Version("1.20"): - v = np.concatenate._implementation( - [_.ndview for _ in arrs], axis=axis, out=out, dtype=dtype, casting=casting - ) + ret_units = arrs[0].units + if out is None: + if NUMPY_VERSION >= Version("1.20"): + v = np.concatenate._implementation( + [_.ndview for _ in arrs], axis=axis, dtype=dtype, casting=casting + ) + else: + v = np.concatenate._implementation( + [_.ndview for _ in arrs], + axis=axis, + ) + out = v * ret_units else: - v = np.concatenate._implementation( - [_.ndview for _ in arrs], - axis=axis, - out=out, - ) - return v * arrs[0].units + cf = _get_conversion_factor(out, ret_units) + if NUMPY_VERSION >= Version("1.20"): + np.concatenate._implementation( + [_.ndview for _ in arrs], + axis=axis, + out=out.ndview, + dtype=dtype, + casting=casting, + ) + else: + np.concatenate._implementation( + [_.ndview for _ in arrs], + axis=axis, + out=out.ndview, + ) + if out.units != ret_units: + out[:] *= cf + return out @implements(np.cross) @@ -271,7 +291,13 @@ def hstack(tup, /): @implements(np.stack) def stack(arrays, /, axis=0, out=None): _validate_units_consistency(arrays) - return ( - np.stack._implementation([_.ndview for _ in arrays], axis=axis, out=out) - * arrays[0].units - ) + ret_units = arrays[0].units + if out is None: + return ( + np.stack._implementation([_.ndview for _ in arrays], axis=axis) * ret_units + ) + cf = _get_conversion_factor(out, ret_units) + np.stack._implementation([_.ndview for _ in arrays], axis=axis, out=out.ndview) + if out.units != ret_units: + out[:] *= cf + return out