diff --git a/CHANGELOG.md b/CHANGELOG.md index fe182a0f027..15f6e1b0048 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add FFI definitions `Py_fstring_input`, `sendfunc`, and `_PyErr_StackItem`. [#2423](https://github.com/PyO3/pyo3/pull/2423) - Add `PyDateTime::new_with_fold`, `PyTime::new_with_fold`, `PyTime::get_fold`, `PyDateTime::get_fold` for PyPy. [#2428](https://github.com/PyO3/pyo3/pull/2428) - Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383) +- Add `CompareOp::convert` to easily implement `__richcmp__` as the result of a + Rust `std::cmp::Ordering` comparison. [#2460](https://github.com/PyO3/pyo3/pull/2460) ### Changed diff --git a/guide/src/class/object.md b/guide/src/class/object.md index ade91da8463..12a8d2ccc2e 100644 --- a/guide/src/class/object.md +++ b/guide/src/class/object.md @@ -128,15 +128,15 @@ impl Number { Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`, `__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`. This method will be called with a value of `CompareOp` depending on the operation. - + ```rust use pyo3::class::basic::CompareOp; # use pyo3::prelude::*; -# +# # #[pyclass] # struct Number(i32); -# +# #[pymethods] impl Number { fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { @@ -152,6 +152,28 @@ impl Number { } ``` +If you obtain the result by comparing two Rust values, as in this example, you +can take a shortcut using `CompareOp::convert`: + +```rust +use pyo3::class::basic::CompareOp; + +# use pyo3::prelude::*; +# +# #[pyclass] +# struct Number(i32); +# +#[pymethods] +impl Number { + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.convert(self.0.cmp(&other.0)) + } +} +``` + +It converts the `std::cmp::Ordering` obtained from Rust's `Ord` class to the +required result for the given `CompareOp`. + ### Truthyness We'll consider `Number` to be `True` if it is nonzero: @@ -229,4 +251,4 @@ fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> { [`Hash`]: https://doc.rust-lang.org/std/hash/trait.Hash.html [`Hasher`]: https://doc.rust-lang.org/std/hash/trait.Hasher.html [`DefaultHasher`]: https://doc.rust-lang.org/std/collections/hash_map/struct.DefaultHasher.html -[SipHash]: https://en.wikipedia.org/wiki/SipHash \ No newline at end of file +[SipHash]: https://en.wikipedia.org/wiki/SipHash diff --git a/guide/src/class/protocols.md b/guide/src/class/protocols.md index 716f45d51f6..54d66c9ccdd 100644 --- a/guide/src/class/protocols.md +++ b/guide/src/class/protocols.md @@ -70,8 +70,11 @@ given signatures should be interpreted as follows:
Return type The return type will normally be `PyResult`, but any Python object can be returned. - If the `object` is not of the type specified in the signature, the generated code will - automatically `return NotImplemented`. + If the second argument `object` is not of the type specified in the + signature, the generated code will automatically `return NotImplemented`. + + You can use [`CompareOp::convert`] to adapt a Rust `std::cmp::Ordering` result + to the requested comparison.
- `__getattr__(, object) -> object` @@ -611,3 +614,4 @@ For details, look at the `#[pymethods]` regarding GC methods. [`PySequenceProtocol`]: {{#PYO3_DOCS_URL}}/pyo3/class/sequence/trait.PySequenceProtocol.html [`PyIterProtocol`]: {{#PYO3_DOCS_URL}}/pyo3/class/iter/trait.PyIterProtocol.html [`PySequence`]: {{#PYO3_DOCS_URL}}/pyo3/types/struct.PySequence.html +[`CompareOp::convert`]: {{#PYO3_DOCS_URL}}/pyo3/pyclass/enum.CompareOp.html#method.convert diff --git a/src/pyclass.rs b/src/pyclass.rs index 5596079ffd9..c543eebdc54 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -11,6 +11,7 @@ use crate::{ IntoPy, IntoPyPointer, PyCell, PyErr, PyMethodDefType, PyObject, PyResult, PyTypeInfo, Python, }; use std::{ + cmp::Ordering, convert::TryInto, ffi::{CStr, CString}, os::raw::{c_char, c_int, c_uint, c_void}, @@ -457,6 +458,7 @@ pub enum CompareOp { } impl CompareOp { + /// Conversion from the C enum. pub fn from_raw(op: c_int) -> Option { match op { ffi::Py_LT => Some(CompareOp::Lt), @@ -468,6 +470,38 @@ impl CompareOp { _ => None, } } + + /// Converts a Rust [`std::cmp::Ordering`] into the correct Boolean result for + /// this ordering query. + /// + /// Usage example: + /// + /// ```rust + /// # use pyo3::prelude::*; + /// # use pyo3::class::basic::CompareOp; + /// + /// #[pyclass] + /// struct Size { + /// size: usize + /// } + /// + /// #[pymethods] + /// impl Size { + /// fn __richcmp__(&self, other: &Size, op: CompareOp) -> bool { + /// op.convert(self.size.cmp(&other.size)) + /// } + /// } + /// ``` + pub fn convert(&self, result: Ordering) -> bool { + match self { + CompareOp::Eq => result == Ordering::Equal, + CompareOp::Ne => result != Ordering::Equal, + CompareOp::Lt => result == Ordering::Less, + CompareOp::Le => result != Ordering::Greater, + CompareOp::Gt => result == Ordering::Greater, + CompareOp::Ge => result != Ordering::Less, + } + } } /// Output of `__next__` which can either `yield` the next value in the iteration, or @@ -578,3 +612,15 @@ pub(crate) unsafe extern "C" fn no_constructor_defined( )) }) } + +mod tests { + #[test] + fn test_compare_op_convert() { + use super::CompareOp; + use std::cmp::Ordering; + + assert!(CompareOp::Eq.convert(Ordering::Equal)); + assert!(!CompareOp::Eq.convert(Ordering::Less)); + assert!(CompareOp::Ge.convert(Ordering::Greater)); + } +}