Skip to content

Commit

Permalink
Don't raise TypeError from generated equality method (#4287)
Browse files Browse the repository at this point in the history
* Don't raise TypeError in derived equality method

* Add newsfragment
  • Loading branch information
jatoben authored Jun 26, 2024
1 parent 2e2d440 commit 7c2f5e8
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 9 deletions.
1 change: 1 addition & 0 deletions newsfragments/4287.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return `NotImplemented` from generated equality method when comparing different types.
10 changes: 7 additions & 3 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1844,9 +1844,13 @@ fn pyclass_richcmp(
op: #pyo3_path::pyclass::CompareOp
) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
let self_val = self;
let other = &*#pyo3_path::types::PyAnyMethods::downcast::<Self>(other)?.borrow();
match op {
#arms
if let Ok(other) = #pyo3_path::types::PyAnyMethods::downcast::<Self>(other) {
let other = &*other.borrow();
match op {
#arms
}
} else {
::std::result::Result::Ok(py.NotImplemented())
}
}
};
Expand Down
13 changes: 13 additions & 0 deletions pytests/src/comparisons.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ impl EqDefaultNe {
}
}

#[pyclass(eq)]
#[derive(PartialEq, Eq)]
struct EqDerived(i64);

#[pymethods]
impl EqDerived {
#[new]
fn new(value: i64) -> Self {
Self(value)
}
}

#[pyclass]
struct Ordered(i64);

Expand Down Expand Up @@ -104,6 +116,7 @@ impl OrderedDefaultNe {
pub fn comparisons(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Eq>()?;
m.add_class::<EqDefaultNe>()?;
m.add_class::<EqDerived>()?;
m.add_class::<Ordered>()?;
m.add_class::<OrderedDefaultNe>()?;
Ok(())
Expand Down
33 changes: 27 additions & 6 deletions pytests/tests/test_comparisons.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,37 @@
from typing import Type, Union

import pytest
from pyo3_pytests.comparisons import Eq, EqDefaultNe, Ordered, OrderedDefaultNe
from pyo3_pytests.comparisons import (
Eq,
EqDefaultNe,
EqDerived,
Ordered,
OrderedDefaultNe,
)
from typing_extensions import Self


class PyEq:
def __init__(self, x: int) -> None:
self.x = x

def __eq__(self, other: Self) -> bool:
return self.x == other.x
def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self.x == other.x
else:
return NotImplemented

def __ne__(self, other: Self) -> bool:
return self.x != other.x
if isinstance(other, self.__class__):
return self.x != other.x
else:
return NotImplemented


@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python"))
def test_eq(ty: Type[Union[Eq, PyEq]]):
@pytest.mark.parametrize(
"ty", (Eq, EqDerived, PyEq), ids=("rust", "rust-derived", "python")
)
def test_eq(ty: Type[Union[Eq, EqDerived, PyEq]]):
a = ty(0)
b = ty(0)
c = ty(1)
Expand All @@ -32,6 +46,13 @@ def test_eq(ty: Type[Union[Eq, PyEq]]):
assert b != c
assert not (b == c)

assert not a == 0
assert a != 0
assert not b == 0
assert b != 1
assert not c == 1
assert c != 1

with pytest.raises(TypeError):
assert a <= b

Expand Down

0 comments on commit 7c2f5e8

Please sign in to comment.