diff --git a/docs/numpy.ipynb b/docs/numpy.ipynb index 02f8f3fce..646a0b3c2 100644 --- a/docs/numpy.ipynb +++ b/docs/numpy.ipynb @@ -284,7 +284,7 @@ "\n", "The following [ufuncs](http://docs.scipy.org/doc/numpy/reference/ufuncs.html) can be applied to a Quantity object:\n", "\n", - "- **Math operations**: `add`, `subtract`, `multiply`, `divide`, `logaddexp`, `logaddexp2`, `true_divide`, `floor_divide`, `negative`, `remainder`, `mod`, `fmod`, `absolute`, `rint`, `sign`, `conj`, `exp`, `exp2`, `log`, `log2`, `log10`, `expm1`, `log1p`, `sqrt`, `square`, `reciprocal`\n", + "- **Math operations**: `add`, `subtract`, `multiply`, `divide`, `logaddexp`, `logaddexp2`, `true_divide`, `floor_divide`, `negative`, `remainder`, `mod`, `fmod`, `absolute`, `rint`, `sign`, `conj`, `exp`, `exp2`, `log`, `log2`, `log10`, `expm1`, `log1p`, `sqrt`, `square`, `cbrt`, `reciprocal`\n", "- **Trigonometric functions**: `sin`, `cos`, `tan`, `arcsin`, `arccos`, `arctan`, `arctan2`, `hypot`, `sinh`, `cosh`, `tanh`, `arcsinh`, `arccosh`, `arctanh`\n", "- **Comparison functions**: `greater`, `greater_equal`, `less`, `less_equal`, `not_equal`, `equal`\n", "- **Floating functions**: `isreal`, `iscomplex`, `isfinite`, `isinf`, `isnan`, `signbit`, `copysign`, `nextafter`, `modf`, `ldexp`, `frexp`, `fmod`, `floor`, `ceil`, `trunc`\n", @@ -301,7 +301,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "['alen', 'amax', 'amin', 'append', 'argmax', 'argmin', 'argsort', 'around', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'average', 'block', 'broadcast_to', 'clip', 'column_stack', 'compress', 'concatenate', 'copy', 'copyto', 'count_nonzero', 'cross', 'cumprod', 'cumproduct', 'cumsum', 'diagonal', 'diff', 'dot', 'dstack', 'ediff1d', 'einsum', 'empty_like', 'expand_dims', 'fix', 'flip', 'full_like', 'gradient', 'hstack', 'insert', 'interp', 'isclose', 'iscomplex', 'isin', 'isreal', 'linspace', 'mean', 'median', 'meshgrid', 'moveaxis', 'nan_to_num', 'nanargmax', 'nanargmin', 'nancumprod', 'nancumsum', 'nanmax', 'nanmean', 'nanmedian', 'nanmin', 'nanpercentile', 'nanstd', 'nansum', 'nanvar', 'ndim', 'nonzero', 'ones_like', 'pad', 'percentile', 'ptp', 'ravel', 'resize', 'result_type', 'rollaxis', 'rot90', 'round_', 'searchsorted', 'shape', 'size', 'sort', 'squeeze', 'stack', 'std', 'sum', 'swapaxes', 'tile', 'transpose', 'trapz', 'trim_zeros', 'unwrap', 'var', 'vstack', 'where', 'zeros_like']\n" + "['alen', 'all', 'amax', 'amin', 'any', 'append', 'argmax', 'argmin', 'argsort', 'around', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'average', 'block', 'broadcast_to', 'clip', 'column_stack', 'compress', 'concatenate', 'copy', 'copyto', 'count_nonzero', 'cross', 'cumprod', 'cumproduct', 'cumsum', 'diagonal', 'diff', 'dot', 'dstack', 'ediff1d', 'einsum', 'empty_like', 'expand_dims', 'fix', 'flip', 'full_like', 'gradient', 'hstack', 'insert', 'interp', 'isclose', 'iscomplex', 'isin', 'isreal', 'linalg.solve', 'linspace', 'mean', 'median', 'meshgrid', 'moveaxis', 'nan_to_num', 'nanargmax', 'nanargmin', 'nancumprod', 'nancumsum', 'nanmax', 'nanmean', 'nanmedian', 'nanmin', 'nanpercentile', 'nanstd', 'nansum', 'nanvar', 'ndim', 'nonzero', 'ones_like', 'pad', 'percentile', 'ptp', 'ravel', 'resize', 'result_type', 'rollaxis', 'rot90', 'round_', 'searchsorted', 'shape', 'size', 'sort', 'squeeze', 'stack', 'std', 'sum', 'swapaxes', 'tile', 'transpose', 'trapz', 'trim_zeros', 'unwrap', 'var', 'vstack', 'where', 'zeros_like']\n" ] } ], diff --git a/pint/numpy_func.py b/pint/numpy_func.py index 1a9221caa..059ffaa6e 100644 --- a/pint/numpy_func.py +++ b/pint/numpy_func.py @@ -194,6 +194,8 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None): result_unit = first_input_units ** 2 elif unit_op == "sqrt": result_unit = first_input_units ** 0.5 + elif unit_op == "cbrt": + result_unit = first_input_units ** (1 / 3) elif unit_op == "reciprocal": result_unit = first_input_units ** -1 elif unit_op == "size": @@ -255,7 +257,11 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None): if np is None: return - func = getattr(np, func_str) + # Handle functions in submodules + func_str_split = func_str.split(".") + func = getattr(np, func_str_split[0]) + for func_str_piece in func_str_split[1:]: + func = getattr(func, func_str_piece) @implements(func_str, func_type) def implementation(*args, **kwargs): @@ -295,6 +301,7 @@ def implementation(*args, **kwargs): "variance", "square", "sqrt", + "cbrt", "reciprocal", "size", ]: @@ -408,6 +415,7 @@ def implementation(*args, **kwargs): "divide": "div", "floor_divide": "div", "sqrt": "sqrt", + "cbrt": "cbrt", "square": "square", "reciprocal": "reciprocal", "std": "sum", @@ -665,6 +673,24 @@ def _recursive_convert(arg, unit): ) +@implements("any", "function") +def _any(a, *args, **kwargs): + # Only valid when multiplicative unit/no offset + if a._is_multiplicative: + return np.any(a._magnitude, *args, **kwargs) + else: + raise ValueError("Boolean value of Quantity with offset unit is ambiguous.") + + +@implements("all", "function") +def _all(a, *args, **kwargs): + # Only valid when multiplicative unit/no offset + if a._is_multiplicative: + return np.all(a._magnitude, *args, **kwargs) + else: + raise ValueError("Boolean value of Quantity with offset unit is ambiguous.") + + # Implement simple matching-unit or stripped-unit functions based on signature @@ -836,6 +862,8 @@ def implementation(a, *args, **kwargs): implement_func("function", func_str, input_units=None, output_unit="delta") for func_str in ["gradient"]: implement_func("function", func_str, input_units=None, output_unit="delta,div") +for func_str in ["linalg.solve"]: + implement_func("function", func_str, input_units=None, output_unit="div") for func_str in ["var", "nanvar"]: implement_func("function", func_str, input_units=None, output_unit="variance") @@ -846,11 +874,15 @@ def numpy_wrap(func_type, func, args, kwargs, types): if func_type == "function": handled = HANDLED_FUNCTIONS + # Need to handle functions in submodules + name = ".".join(func.__module__.split(".")[1:] + [func.__name__]) elif func_type == "ufunc": handled = HANDLED_UFUNCS + # ufuncs do not have func.__module__ + name = func.__name__ else: raise ValueError("Invalid func_type {}".format(func_type)) - if func.__name__ not in handled or any(is_upcast_type(t) for t in types): + if name not in handled or any(is_upcast_type(t) for t in types): return NotImplemented - return handled[func.__name__](*args, **kwargs) + return handled[name](*args, **kwargs) diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py index 3f2bce18b..0c9d09022 100644 --- a/pint/testsuite/test_numpy.py +++ b/pint/testsuite/test_numpy.py @@ -374,6 +374,13 @@ def test_einsum(self): np.array([30, 80, 130, 180, 230]) * self.ureg.m ** 2, ) + @helpers.requires_array_function_protocol() + def test_solve(self): + self.assertQuantityAlmostEqual( + np.linalg.solve(self.q, [[3], [7]] * self.ureg.s), + self.Q_([[1], [1]], "m / s"), + ) + # Arithmetic operations def test_addition_with_scalar(self): a = np.array([0, 1, 2]) @@ -414,6 +421,14 @@ def test_power(self): ) self.assertNDArrayEqual(arr ** self.Q_(2), np.array([0, 1, 4])) + def test_sqrt(self): + q = self.Q_(100, "m**2") + self.assertQuantityEqual(np.sqrt(q), self.Q_(10, "m")) + + def test_cbrt(self): + q = self.Q_(1000, "m**3") + self.assertQuantityEqual(np.cbrt(q), self.Q_(10, "m")) + @unittest.expectedFailure @helpers.requires_numpy() def test_exponentiation_array_exp_2(self): @@ -537,6 +552,18 @@ def test_nonzero_numpy_func(self): q = [1, 0, 5, 6, 0, 9] * self.ureg.m self.assertNDArrayEqual(np.nonzero(q)[0], [0, 2, 3, 5]) + @helpers.requires_array_function_protocol() + def test_any_numpy_func(self): + q = [0, 1] * self.ureg.m + self.assertTrue(np.any(q)) + self.assertRaises(ValueError, np.any, self.q_temperature) + + @helpers.requires_array_function_protocol() + def test_all_numpy_func(self): + q = [0, 1] * self.ureg.m + self.assertFalse(np.all(q)) + self.assertRaises(ValueError, np.all, self.q_temperature) + @helpers.requires_array_function_protocol() def test_count_nonzero_numpy_func(self): q = [1, 0, 5, 6, 0, 9] * self.ureg.m