Skip to content

Commit

Permalink
Add any, all, cbrt, linalg.solve
Browse files Browse the repository at this point in the history
  • Loading branch information
jthielen committed Dec 30, 2019
1 parent f7efb18 commit 97e9a86
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/numpy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@
"\n",
"The following [ufuncs](http://docs.scipy.org/doc/numpy/reference/ufuncs.html) can be applied to a Quantity object:\n",
"\n",
"- **Math operations**: `add`, `subtract`, `multiply`, `divide`, `logaddexp`, `logaddexp2`, `true_divide`, `floor_divide`, `negative`, `remainder`, `mod`, `fmod`, `absolute`, `rint`, `sign`, `conj`, `exp`, `exp2`, `log`, `log2`, `log10`, `expm1`, `log1p`, `sqrt`, `square`, `reciprocal`\n",
"- **Math operations**: `add`, `subtract`, `multiply`, `divide`, `logaddexp`, `logaddexp2`, `true_divide`, `floor_divide`, `negative`, `remainder`, `mod`, `fmod`, `absolute`, `rint`, `sign`, `conj`, `exp`, `exp2`, `log`, `log2`, `log10`, `expm1`, `log1p`, `sqrt`, `square`, `cbrt`, `reciprocal`\n",
"- **Trigonometric functions**: `sin`, `cos`, `tan`, `arcsin`, `arccos`, `arctan`, `arctan2`, `hypot`, `sinh`, `cosh`, `tanh`, `arcsinh`, `arccosh`, `arctanh`\n",
"- **Comparison functions**: `greater`, `greater_equal`, `less`, `less_equal`, `not_equal`, `equal`\n",
"- **Floating functions**: `isreal`, `iscomplex`, `isfinite`, `isinf`, `isnan`, `signbit`, `copysign`, `nextafter`, `modf`, `ldexp`, `frexp`, `fmod`, `floor`, `ceil`, `trunc`\n",
Expand All @@ -301,7 +301,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"['alen', 'amax', 'amin', 'append', 'argmax', 'argmin', 'argsort', 'around', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'average', 'block', 'broadcast_to', 'clip', 'column_stack', 'compress', 'concatenate', 'copy', 'copyto', 'count_nonzero', 'cross', 'cumprod', 'cumproduct', 'cumsum', 'diagonal', 'diff', 'dot', 'dstack', 'ediff1d', 'einsum', 'empty_like', 'expand_dims', 'fix', 'flip', 'full_like', 'gradient', 'hstack', 'insert', 'interp', 'isclose', 'iscomplex', 'isin', 'isreal', 'linspace', 'mean', 'median', 'meshgrid', 'moveaxis', 'nan_to_num', 'nanargmax', 'nanargmin', 'nancumprod', 'nancumsum', 'nanmax', 'nanmean', 'nanmedian', 'nanmin', 'nanpercentile', 'nanstd', 'nansum', 'nanvar', 'ndim', 'nonzero', 'ones_like', 'pad', 'percentile', 'ptp', 'ravel', 'resize', 'result_type', 'rollaxis', 'rot90', 'round_', 'searchsorted', 'shape', 'size', 'sort', 'squeeze', 'stack', 'std', 'sum', 'swapaxes', 'tile', 'transpose', 'trapz', 'trim_zeros', 'unwrap', 'var', 'vstack', 'where', 'zeros_like']\n"
"['alen', 'all', 'amax', 'amin', 'any', 'append', 'argmax', 'argmin', 'argsort', 'around', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'average', 'block', 'broadcast_to', 'clip', 'column_stack', 'compress', 'concatenate', 'copy', 'copyto', 'count_nonzero', 'cross', 'cumprod', 'cumproduct', 'cumsum', 'diagonal', 'diff', 'dot', 'dstack', 'ediff1d', 'einsum', 'empty_like', 'expand_dims', 'fix', 'flip', 'full_like', 'gradient', 'hstack', 'insert', 'interp', 'isclose', 'iscomplex', 'isin', 'isreal', 'linalg.solve', 'linspace', 'mean', 'median', 'meshgrid', 'moveaxis', 'nan_to_num', 'nanargmax', 'nanargmin', 'nancumprod', 'nancumsum', 'nanmax', 'nanmean', 'nanmedian', 'nanmin', 'nanpercentile', 'nanstd', 'nansum', 'nanvar', 'ndim', 'nonzero', 'ones_like', 'pad', 'percentile', 'ptp', 'ravel', 'resize', 'result_type', 'rollaxis', 'rot90', 'round_', 'searchsorted', 'shape', 'size', 'sort', 'squeeze', 'stack', 'std', 'sum', 'swapaxes', 'tile', 'transpose', 'trapz', 'trim_zeros', 'unwrap', 'var', 'vstack', 'where', 'zeros_like']\n"
]
}
],
Expand Down
38 changes: 35 additions & 3 deletions pint/numpy_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None):
result_unit = first_input_units ** 2
elif unit_op == "sqrt":
result_unit = first_input_units ** 0.5
elif unit_op == "cbrt":
result_unit = first_input_units ** (1 / 3)
elif unit_op == "reciprocal":
result_unit = first_input_units ** -1
elif unit_op == "size":
Expand Down Expand Up @@ -255,7 +257,11 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None):
if np is None:
return

func = getattr(np, func_str)
# Handle functions in submodules
func_str_split = func_str.split(".")
func = getattr(np, func_str_split[0])
for func_str_piece in func_str_split[1:]:
func = getattr(func, func_str_piece)

@implements(func_str, func_type)
def implementation(*args, **kwargs):
Expand Down Expand Up @@ -295,6 +301,7 @@ def implementation(*args, **kwargs):
"variance",
"square",
"sqrt",
"cbrt",
"reciprocal",
"size",
]:
Expand Down Expand Up @@ -408,6 +415,7 @@ def implementation(*args, **kwargs):
"divide": "div",
"floor_divide": "div",
"sqrt": "sqrt",
"cbrt": "cbrt",
"square": "square",
"reciprocal": "reciprocal",
"std": "sum",
Expand Down Expand Up @@ -665,6 +673,24 @@ def _recursive_convert(arg, unit):
)


@implements("any", "function")
def _any(a, *args, **kwargs):
# Only valid when multiplicative unit/no offset
if a._is_multiplicative:
return np.any(a._magnitude, *args, **kwargs)
else:
raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")


@implements("all", "function")
def _all(a, *args, **kwargs):
# Only valid when multiplicative unit/no offset
if a._is_multiplicative:
return np.all(a._magnitude, *args, **kwargs)
else:
raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")


# Implement simple matching-unit or stripped-unit functions based on signature


Expand Down Expand Up @@ -836,6 +862,8 @@ def implementation(a, *args, **kwargs):
implement_func("function", func_str, input_units=None, output_unit="delta")
for func_str in ["gradient"]:
implement_func("function", func_str, input_units=None, output_unit="delta,div")
for func_str in ["linalg.solve"]:
implement_func("function", func_str, input_units=None, output_unit="div")
for func_str in ["var", "nanvar"]:
implement_func("function", func_str, input_units=None, output_unit="variance")

Expand All @@ -846,11 +874,15 @@ def numpy_wrap(func_type, func, args, kwargs, types):

if func_type == "function":
handled = HANDLED_FUNCTIONS
# Need to handle functions in submodules
name = ".".join(func.__module__.split(".")[1:] + [func.__name__])
elif func_type == "ufunc":
handled = HANDLED_UFUNCS
# ufuncs do not have func.__module__
name = func.__name__
else:
raise ValueError("Invalid func_type {}".format(func_type))

if func.__name__ not in handled or any(is_upcast_type(t) for t in types):
if name not in handled or any(is_upcast_type(t) for t in types):
return NotImplemented
return handled[func.__name__](*args, **kwargs)
return handled[name](*args, **kwargs)
27 changes: 27 additions & 0 deletions pint/testsuite/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,13 @@ def test_einsum(self):
np.array([30, 80, 130, 180, 230]) * self.ureg.m ** 2,
)

@helpers.requires_array_function_protocol()
def test_solve(self):
self.assertQuantityAlmostEqual(
np.linalg.solve(self.q, [[3], [7]] * self.ureg.s),
self.Q_([[1], [1]], "m / s"),
)

# Arithmetic operations
def test_addition_with_scalar(self):
a = np.array([0, 1, 2])
Expand Down Expand Up @@ -414,6 +421,14 @@ def test_power(self):
)
self.assertNDArrayEqual(arr ** self.Q_(2), np.array([0, 1, 4]))

def test_sqrt(self):
q = self.Q_(100, "m**2")
self.assertQuantityEqual(np.sqrt(q), self.Q_(10, "m"))

def test_cbrt(self):
q = self.Q_(1000, "m**3")
self.assertQuantityEqual(np.cbrt(q), self.Q_(10, "m"))

@unittest.expectedFailure
@helpers.requires_numpy()
def test_exponentiation_array_exp_2(self):
Expand Down Expand Up @@ -537,6 +552,18 @@ def test_nonzero_numpy_func(self):
q = [1, 0, 5, 6, 0, 9] * self.ureg.m
self.assertNDArrayEqual(np.nonzero(q)[0], [0, 2, 3, 5])

@helpers.requires_array_function_protocol()
def test_any_numpy_func(self):
q = [0, 1] * self.ureg.m
self.assertTrue(np.any(q))
self.assertRaises(ValueError, np.any, self.q_temperature)

@helpers.requires_array_function_protocol()
def test_all_numpy_func(self):
q = [0, 1] * self.ureg.m
self.assertFalse(np.all(q))
self.assertRaises(ValueError, np.all, self.q_temperature)

@helpers.requires_array_function_protocol()
def test_count_nonzero_numpy_func(self):
q = [1, 0, 5, 6, 0, 9] * self.ureg.m
Expand Down

0 comments on commit 97e9a86

Please sign in to comment.