Skip to content

Commit

Permalink
Update and move tests based on Eric's feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ntlind committed Sep 28, 2023
1 parent d833342 commit fa8e48a
Showing 1 changed file with 0 additions and 95 deletions.
95 changes: 0 additions & 95 deletions client/unit-tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,98 +1,3 @@
from velour.enums import TaskType
from velour.integrations.coco import _merge_annotations
from velour.schemas import Label
from copy import deepcopy


def test__merge_annotations():
"""Check that we get the correct annotation set after merging semantic segmentions"""

initial_annotations = [
dict(
task_type=TaskType.SEMANTIC_SEGMENTATION,
labels=set([Label(key="k1", value="v1"), Label(key="k2", value="v2")]),
mask=[True, False, False, False],
),
dict(
task_type=TaskType.SEMANTIC_SEGMENTATION,
labels=set([Label(key="k1", value="v1"), Label(key="k3", value="v3")]),
mask=[False, False, True, False],
),
dict(
task_type=TaskType.SEMANTIC_SEGMENTATION,
labels=set(
[
Label(key="k1", value="v1"),
Label(key="k2", value="v2"),
Label(key="k4", value="v4"),
]
),
mask=[False, False, False, True],
),
dict(
task_type=TaskType.INSTANCE_SEGMENTATION,
labels=set([Label(key="k1", value="v1"), Label(key="k3", value="v3")]),
mask=[False, True, False, False],
),
]

expected = [
dict(
task_type=TaskType.SEMANTIC_SEGMENTATION,
labels=set([Label(key="k3", value="v3")]),
mask=[False, False, True, False],
),
dict(
task_type=TaskType.SEMANTIC_SEGMENTATION,
labels=set([Label(key="k4", value="v4")]),
mask=[False, False, False, True],
),
dict(
task_type=TaskType.INSTANCE_SEGMENTATION,
labels=set(
[
Label(key="k1", value="v1"),
Label(key="k3", value="v3"),
]
),
mask=[False, True, False, False],
),
dict(
task_type=TaskType.SEMANTIC_SEGMENTATION,
labels=set([Label(key="k1", value="v1")]),
mask=[True, False, True, True],
),
dict(
task_type=TaskType.SEMANTIC_SEGMENTATION,
labels=set([Label(key="k2", value="v2")]),
mask=[True, False, False, True],
),
]

label_map = {
Label(key="k1", value="v1"): [0, 1, 2],
Label(key="k2", value="v2"): [0, 2],
Label(key="k3", value="v3"): [1],
Label(key="k4", value="v4"): [2],
}

merged_annotations = _merge_annotations(
annotation_list=initial_annotations.copy(), label_map=label_map
)

for i, v in enumerate(merged_annotations):
assert (
merged_annotations[i]["labels"] == expected[i]["labels"]
), "Labels didn't merge as expected"
assert sum(merged_annotations[i]["mask"]) == sum(
expected[i]["mask"]
), "Masks didn't merge as expected"


if __name__ == "__main__":
test__merge_annotations()


# NOTE: Will probably reimplemnt in the future under `velour.client.Evaluation`

# from velour.client import Model
Expand Down

0 comments on commit fa8e48a

Please sign in to comment.