diff --git a/newsfragments/3712.added.md b/newsfragments/3712.added.md new file mode 100644 index 00000000000..d7390f77c14 --- /dev/null +++ b/newsfragments/3712.added.md @@ -0,0 +1 @@ +Added methods to `PyAnyMethods` for binary operators (`add`, `sub`, etc.) diff --git a/src/tests/common.rs b/src/tests/common.rs index 89b4a83fa1d..f2082437693 100644 --- a/src/tests/common.rs +++ b/src/tests/common.rs @@ -23,6 +23,13 @@ mod inner { }; } + #[macro_export] + macro_rules! assert_py_eq { + ($val:expr, $expected:expr) => { + assert!($val.eq($expected).unwrap()); + }; + } + #[macro_export] macro_rules! py_expect_exception { // Case1: idents & no err_msg diff --git a/src/types/any.rs b/src/types/any.rs index edc74879cb6..102099c5cd2 100644 --- a/src/types/any.rs +++ b/src/types/any.rs @@ -1208,6 +1208,58 @@ pub trait PyAnyMethods<'py> { where O: ToPyObject; + /// Computes `self + other`. + fn add(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self - other`. + fn sub(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self * other`. + fn mul(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self / other`. + fn div(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self << other`. + fn lshift(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self >> other`. + fn rshift(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self ** other % modulus` (`pow(self, other, modulus)`). + /// `py.None()` may be passed for the `modulus`. + fn pow(&self, other: O1, modulus: O2) -> PyResult> + where + O1: ToPyObject, + O2: ToPyObject; + + /// Computes `self & other`. + fn bitand(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self | other`. + fn bitor(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self ^ other`. + fn bitxor(&self, other: O) -> PyResult> + where + O: ToPyObject; + /// Determines whether this object appears callable. /// /// This is equivalent to Python's [`callable()`][1] function. @@ -1680,6 +1732,26 @@ pub trait PyAnyMethods<'py> { fn py_super(&self) -> PyResult>; } +macro_rules! implement_binop { + ($name:ident, $c_api:ident, $op:expr) => { + #[doc = concat!("Computes `self ", $op, " other`.")] + fn $name(&self, other: O) -> PyResult> + where + O: ToPyObject, + { + fn inner<'py>( + any: &Bound<'py, PyAny>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + unsafe { ffi::$c_api(any.as_ptr(), other.as_ptr()).assume_owned_or_err(any.py()) } + } + + let py = self.py(); + inner(self, other.to_object(py).into_bound(py)) + } + }; +} + impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> { #[inline] fn is(&self, other: &T) -> bool { @@ -1855,6 +1927,42 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> { .and_then(|any| any.is_truthy()) } + implement_binop!(add, PyNumber_Add, "+"); + implement_binop!(sub, PyNumber_Subtract, "-"); + implement_binop!(mul, PyNumber_Multiply, "*"); + implement_binop!(div, PyNumber_TrueDivide, "/"); + implement_binop!(lshift, PyNumber_Lshift, "<<"); + implement_binop!(rshift, PyNumber_Rshift, ">>"); + implement_binop!(bitand, PyNumber_And, "&"); + implement_binop!(bitor, PyNumber_Or, "|"); + implement_binop!(bitxor, PyNumber_Xor, "^"); + + /// Computes `self ** other % modulus` (`pow(self, other, modulus)`). + /// `py.None()` may be passed for the `modulus`. + fn pow(&self, other: O1, modulus: O2) -> PyResult> + where + O1: ToPyObject, + O2: ToPyObject, + { + fn inner<'py>( + any: &Bound<'py, PyAny>, + other: Bound<'_, PyAny>, + modulus: Bound<'_, PyAny>, + ) -> PyResult> { + unsafe { + ffi::PyNumber_Power(any.as_ptr(), other.as_ptr(), modulus.as_ptr()) + .assume_owned_or_err(any.py()) + } + } + + let py = self.py(); + inner( + self, + other.to_object(py).into_bound(py), + modulus.to_object(py).into_bound(py), + ) + } + fn is_callable(&self) -> bool { unsafe { ffi::PyCallable_Check(self.as_ptr()) != 0 } } diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 86078080176..456d21a3b62 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -178,6 +178,10 @@ impl BinaryArithmetic { format!("BA * {:?}", rhs) } + fn __truediv__(&self, rhs: &PyAny) -> String { + format!("BA / {:?}", rhs) + } + fn __lshift__(&self, rhs: &PyAny) -> String { format!("BA << {:?}", rhs) } @@ -233,6 +237,18 @@ fn binary_arithmetic() { py_expect_exception!(py, c, "1 ** c", PyTypeError); py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'"); + + let c: Bound<'_, PyAny> = c.extract().unwrap(); + assert_py_eq!(c.add(&c).unwrap(), "BA + BA"); + assert_py_eq!(c.sub(&c).unwrap(), "BA - BA"); + assert_py_eq!(c.mul(&c).unwrap(), "BA * BA"); + assert_py_eq!(c.div(&c).unwrap(), "BA / BA"); + assert_py_eq!(c.lshift(&c).unwrap(), "BA << BA"); + assert_py_eq!(c.rshift(&c).unwrap(), "BA >> BA"); + assert_py_eq!(c.bitand(&c).unwrap(), "BA & BA"); + assert_py_eq!(c.bitor(&c).unwrap(), "BA | BA"); + assert_py_eq!(c.bitxor(&c).unwrap(), "BA ^ BA"); + assert_py_eq!(c.pow(&c, py.None()).unwrap(), "BA ** BA (mod: None)"); }); }