Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

have comparisons return bool when both objects are values #611

Merged
merged 11 commits into from
Jul 5, 2024
14 changes: 11 additions & 3 deletions client/unit-tests/coretypes/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from valor import Annotation, Dataset, Filter, Label, Model
from valor.schemas import And, Polygon
from valor.schemas import And, Eq, Gt, Lt, Polygon


@pytest.fixture
Expand Down Expand Up @@ -36,14 +36,22 @@ def test_complex_filter(
geojson: Dict[str, Union[str, List[List[Tuple[float, float]]]]],
polygon: Polygon,
):
# check expression types (this also makes pyright pass)
model_name_eq_x = Model.name == "x"
assert isinstance(model_name_eq_x, Eq)
annotation_raster_area_gt = Annotation.raster.area > 100
assert isinstance(annotation_raster_area_gt, Gt)
annotation_raster_area_lt = Annotation.raster.area < 500
assert isinstance(annotation_raster_area_lt, Lt)

filter_from_constraints = Filter(
annotations=And(
Dataset.name.in_(["a", "b", "c"]),
(Model.name == "x") | Model.name.in_(["y", "z"]),
model_name_eq_x | Model.name.in_(["y", "z"]),
Label.score > 0.75,
Annotation.polygon.area > 1000,
Annotation.polygon.area < 5000,
(Annotation.raster.area > 100) & (Annotation.raster.area < 500),
annotation_raster_area_gt & annotation_raster_area_lt,
Dataset.metadata["some_str"] == "foobar",
Dataset.metadata["some_float"] >= 0.123,
Dataset.metadata["some_datetime"] > datetime.timedelta(days=1),
Expand Down
67 changes: 58 additions & 9 deletions client/unit-tests/schemas/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,21 @@ def test_scored_label():
assert l1.value == "value"

# test member fn `__eq__`
assert (s1 == s2).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert (s1 == s6).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 == s3).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 == s4).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 == s5).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert s1 == s2 # type: ignore - resolved to 'Bool' as both sides are values
assert s1 == s6 # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 == s3) # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 == s4) # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 == s5) # type: ignore - resolved to 'Bool' as both sides are values
czaloom marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(TypeError):
assert s1 == 123
with pytest.raises(TypeError):
assert s1 == "123"

# test member fn `__ne__`
assert not (s1 != s2).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert (s1 != s3).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert (s1 != s4).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert (s1 != s5).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 != s2) # type: ignore - resolved to 'Bool' as both sides are values
assert s1 != s3 # type: ignore - resolved to 'Bool' as both sides are values
assert s1 != s4 # type: ignore - resolved to 'Bool' as both sides are values
assert s1 != s5 # type: ignore - resolved to 'Bool' as both sides are values
with pytest.raises(TypeError):
assert s1 != 123
with pytest.raises(TypeError):
Expand All @@ -76,3 +76,52 @@ def test_scored_label():
assert s1.__hash__() != s3.__hash__()
assert s1.__hash__() != s4.__hash__()
assert s1.__hash__() != s5.__hash__()


def test_label_equality():
label1 = Label(key="test", value="value")
label2 = Label(key="test", value="value")
label3 = Label(key="test", value="other")
label4 = Label(key="other", value="value")

eq1 = label1 == label2
assert type(eq1) == bool
assert eq1

eq2 = label1 == label3
assert type(eq2) == bool
assert not eq2

eq3 = label1 == label4
assert type(eq3) == bool
assert not eq3


def test_label_score():
label1 = Label(key="test", value="value", score=0.5)
label2 = Label(key="test", value="value", score=0.5)
label3 = Label(key="test", value="value", score=0.1)

b1 = label1.score == label2.score
assert type(b1) == bool
assert b1

b2 = label1.score > label3.score
assert type(b2) == bool
assert b2

b3 = label1.score < label3.score
assert type(b3) == bool
assert not b3

b4 = label1.score >= label2.score
assert type(b4) == bool
assert b4

b5 = label1.score != label3.score
assert type(b5) == bool
assert b5

b6 = label1.score != label2.score
assert type(b6) == bool
assert not b6
27 changes: 14 additions & 13 deletions client/unit-tests/symbolic/collections/test_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def test_list():
assert variable[0].get_value() == 0.1

# test comparison symbol -> value
assert (symbol == [0.1, 0.2, 0.3]).to_dict() == {
eq = symbol == [0.1, 0.2, 0.3]
assert isinstance(eq, Eq)
assert eq.to_dict() == {
"op": "eq",
"lhs": {
"name": "list[float]",
Expand All @@ -217,7 +219,9 @@ def test_list():
}

# test comparison symbol -> valued variable
assert (symbol == variable).to_dict() == {
eq = symbol == variable
assert isinstance(eq, Eq)
assert eq.to_dict() == {
"op": "eq",
"lhs": {
"name": "list[float]",
Expand All @@ -236,10 +240,7 @@ def test_list():
]

# test comparison between valued variable and value
assert (variable == [0.1, 0.2, 0.3]).to_dict() == {
"type": "boolean",
"value": True,
}
assert variable == [0.1, 0.2, 0.3]

# test setting list to non-list type
with pytest.raises(TypeError):
Expand Down Expand Up @@ -303,9 +304,9 @@ def test_dictionary_value():
assert (
DictionaryValue.symbolic(name="a", key="b").is_not_none()
).to_dict()["op"] == "isnotnull"
assert (DictionaryValue.symbolic(name="a", key="b").area == 0).to_dict()[
"op"
] == "eq"
eq = DictionaryValue.symbolic(name="a", key="b") == 0
assert isinstance(eq, Eq)
assert eq.to_dict()["op"] == "eq"

# test router with Variable type
assert (DictionaryValue.symbolic(name="a", key="b") == Float(0)).to_dict()[
Expand Down Expand Up @@ -368,12 +369,12 @@ def test_dictionary():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() # type: ignore - issue #604
assert not v1.is_not_none() # type: ignore - issue #604
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert not v2.is_none() # type: ignore - issue #604
assert v2.is_not_none() # type: ignore - issue #604

# test encoding
assert {
Expand Down
44 changes: 22 additions & 22 deletions client/unit-tests/symbolic/types/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def _test_resolvable(
truth = a.__getattribute__(op)(b)

# test variable -> builtin against truth
assert A.__getattribute__(op)(b).get_value() is truth
assert A.__getattribute__(op)(b) is truth
# test variable -> variable against truth
assert A.__getattribute__(op)(B).get_value() is truth
assert A.__getattribute__(op)(B) is truth
# test dictionary generation
dictA = A.to_dict()
assert A.get_value() == a
Expand Down Expand Up @@ -166,12 +166,12 @@ def test_score():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test unsupported methods
for op in [
Expand Down Expand Up @@ -221,12 +221,12 @@ def test_tasktypeenum():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test encoding
_test_encoding(
Expand Down Expand Up @@ -273,12 +273,12 @@ def test_box():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test unsupported methods
for op in [
Expand Down Expand Up @@ -338,12 +338,12 @@ def test_raster():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test 'from_numpy' classmethod
assert Raster.from_numpy(bitmask1).to_dict() == Raster(value).to_dict()
Expand Down Expand Up @@ -420,12 +420,12 @@ def test_embedding():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test unsupported methods
for op in [
Expand Down
Loading
Loading