Skip to content

Commit

Permalink
have comparisons return bool when both objects are values (#611)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekorman authored Jul 5, 2024
1 parent 0694e28 commit 5eff5be
Show file tree
Hide file tree
Showing 8 changed files with 244 additions and 178 deletions.
14 changes: 11 additions & 3 deletions client/unit-tests/coretypes/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from valor import Annotation, Dataset, Filter, Label, Model
from valor.schemas import And, Polygon
from valor.schemas import And, Eq, Gt, Lt, Polygon


@pytest.fixture
Expand Down Expand Up @@ -36,14 +36,22 @@ def test_complex_filter(
geojson: Dict[str, Union[str, List[List[Tuple[float, float]]]]],
polygon: Polygon,
):
# check expression types (this also makes pyright pass)
model_name_eq_x = Model.name == "x"
assert isinstance(model_name_eq_x, Eq)
annotation_raster_area_gt = Annotation.raster.area > 100
assert isinstance(annotation_raster_area_gt, Gt)
annotation_raster_area_lt = Annotation.raster.area < 500
assert isinstance(annotation_raster_area_lt, Lt)

filter_from_constraints = Filter(
annotations=And(
Dataset.name.in_(["a", "b", "c"]),
(Model.name == "x") | Model.name.in_(["y", "z"]),
model_name_eq_x | Model.name.in_(["y", "z"]),
Label.score > 0.75,
Annotation.polygon.area > 1000,
Annotation.polygon.area < 5000,
(Annotation.raster.area > 100) & (Annotation.raster.area < 500),
annotation_raster_area_gt & annotation_raster_area_lt,
Dataset.metadata["some_str"] == "foobar",
Dataset.metadata["some_float"] >= 0.123,
Dataset.metadata["some_datetime"] > datetime.timedelta(days=1),
Expand Down
67 changes: 58 additions & 9 deletions client/unit-tests/schemas/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,21 @@ def test_scored_label():
assert l1.value == "value"

# test member fn `__eq__`
assert (s1 == s2).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert (s1 == s6).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 == s3).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 == s4).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 == s5).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert s1 == s2
assert s1 == s6
assert not (s1 == s3)
assert not (s1 == s4)
assert not (s1 == s5)
with pytest.raises(TypeError):
assert s1 == 123
with pytest.raises(TypeError):
assert s1 == "123"

# test member fn `__ne__`
assert not (s1 != s2).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert (s1 != s3).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert (s1 != s4).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert (s1 != s5).get_value() # type: ignore - resolved to 'Bool' as both sides are values
assert not (s1 != s2)
assert s1 != s3
assert s1 != s4
assert s1 != s5
with pytest.raises(TypeError):
assert s1 != 123
with pytest.raises(TypeError):
Expand All @@ -76,3 +76,52 @@ def test_scored_label():
assert s1.__hash__() != s3.__hash__()
assert s1.__hash__() != s4.__hash__()
assert s1.__hash__() != s5.__hash__()


def test_label_equality():
label1 = Label(key="test", value="value")
label2 = Label(key="test", value="value")
label3 = Label(key="test", value="other")
label4 = Label(key="other", value="value")

eq1 = label1 == label2
assert type(eq1) == bool
assert eq1

eq2 = label1 == label3
assert type(eq2) == bool
assert not eq2

eq3 = label1 == label4
assert type(eq3) == bool
assert not eq3


def test_label_score():
label1 = Label(key="test", value="value", score=0.5)
label2 = Label(key="test", value="value", score=0.5)
label3 = Label(key="test", value="value", score=0.1)

b1 = label1.score == label2.score
assert type(b1) == bool
assert b1

b2 = label1.score > label3.score
assert type(b2) == bool
assert b2

b3 = label1.score < label3.score
assert type(b3) == bool
assert not b3

b4 = label1.score >= label2.score
assert type(b4) == bool
assert b4

b5 = label1.score != label3.score
assert type(b5) == bool
assert b5

b6 = label1.score != label2.score
assert type(b6) == bool
assert not b6
27 changes: 14 additions & 13 deletions client/unit-tests/symbolic/collections/test_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def test_list():
assert variable[0].get_value() == 0.1

# test comparison symbol -> value
assert (symbol == [0.1, 0.2, 0.3]).to_dict() == {
eq = symbol == [0.1, 0.2, 0.3]
assert isinstance(eq, Eq)
assert eq.to_dict() == {
"op": "eq",
"lhs": {
"name": "list[float]",
Expand All @@ -217,7 +219,9 @@ def test_list():
}

# test comparison symbol -> valued variable
assert (symbol == variable).to_dict() == {
eq = symbol == variable
assert isinstance(eq, Eq)
assert eq.to_dict() == {
"op": "eq",
"lhs": {
"name": "list[float]",
Expand All @@ -236,10 +240,7 @@ def test_list():
]

# test comparison between valued variable and value
assert (variable == [0.1, 0.2, 0.3]).to_dict() == {
"type": "boolean",
"value": True,
}
assert variable == [0.1, 0.2, 0.3]

# test setting list to non-list type
with pytest.raises(TypeError):
Expand Down Expand Up @@ -303,9 +304,9 @@ def test_dictionary_value():
assert (
DictionaryValue.symbolic(name="a", key="b").is_not_none()
).to_dict()["op"] == "isnotnull"
assert (DictionaryValue.symbolic(name="a", key="b").area == 0).to_dict()[
"op"
] == "eq"
eq = DictionaryValue.symbolic(name="a", key="b") == 0
assert isinstance(eq, Eq)
assert eq.to_dict()["op"] == "eq"

# test router with Variable type
assert (DictionaryValue.symbolic(name="a", key="b") == Float(0)).to_dict()[
Expand Down Expand Up @@ -368,12 +369,12 @@ def test_dictionary():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() # type: ignore - issue #604
assert not v1.is_not_none() # type: ignore - issue #604
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert not v2.is_none() # type: ignore - issue #604
assert v2.is_not_none() # type: ignore - issue #604

# test encoding
assert {
Expand Down
44 changes: 22 additions & 22 deletions client/unit-tests/symbolic/types/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def _test_resolvable(
truth = a.__getattribute__(op)(b)

# test variable -> builtin against truth
assert A.__getattribute__(op)(b).get_value() is truth
assert A.__getattribute__(op)(b) is truth
# test variable -> variable against truth
assert A.__getattribute__(op)(B).get_value() is truth
assert A.__getattribute__(op)(B) is truth
# test dictionary generation
dictA = A.to_dict()
assert A.get_value() == a
Expand Down Expand Up @@ -166,12 +166,12 @@ def test_score():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test unsupported methods
for op in [
Expand Down Expand Up @@ -221,12 +221,12 @@ def test_tasktypeenum():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test encoding
_test_encoding(
Expand Down Expand Up @@ -273,12 +273,12 @@ def test_box():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test unsupported methods
for op in [
Expand Down Expand Up @@ -338,12 +338,12 @@ def test_raster():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test 'from_numpy' classmethod
assert Raster.from_numpy(bitmask1).to_dict() == Raster(value).to_dict()
Expand Down Expand Up @@ -420,12 +420,12 @@ def test_embedding():
# test nullable
v1 = objcls.nullable(None)
assert v1.get_value() is None
assert v1.is_none().get_value() is True # type: ignore - issue #604
assert v1.is_not_none().get_value() is False # type: ignore - issue #604
assert v1.is_none() is True
assert v1.is_not_none() is False
v2 = objcls.nullable(permutations[0][0])
assert v2.get_value() is not None
assert v2.is_none().get_value() is False # type: ignore - issue #604
assert v2.is_not_none().get_value() is True # type: ignore - issue #604
assert v2.is_none() is False
assert v2.is_not_none() is True

# test unsupported methods
for op in [
Expand Down
Loading

0 comments on commit 5eff5be

Please sign in to comment.