def _get_data_details(self, remote_path: str) -> Dict[str, Any]:
    """Request the details of one data item from the "data/details" OpenAPI.

    Arguments:
        remote_path: The remote path sent as the "remotePath" query parameter.

    Returns:
        The first element of the "dataDetails" list in the response body.

    Raises:
        ResourceNotExistError: When the "dataDetails" list in the response is empty.

    """
    query: Dict[str, Any] = {"segmentName": self._name, "remotePath": remote_path}
    query.update(self._status.get_status_info())
    if config.is_internal:
        query["isInternal"] = True

    response = self._client.open_api_do("GET", "data/details", self._dataset_id, params=query)
    try:
        first_details = response.json()["dataDetails"][0]
    except IndexError as error:
        # An empty list means nothing matched the requested remote path.
        raise ResourceNotExistError(resource="data", identification=remote_path) from error
    else:
        return first_details  # type: ignore[no-any-return]
def _get_mask_url(self, mask_type: str, remote_path: str) -> str:
    """Request the url of one mask file from the "masks/urls" OpenAPI.

    Arguments:
        mask_type: The mask type sent as the "maskType" query parameter.
        remote_path: The remote path sent as the "remotePath" query parameter.

    Returns:
        The "url" field of the first element of the "urls" list in the response body.

    Raises:
        ResourceNotExistError: When the "urls" list in the response is empty.

    """
    params: Dict[str, Any] = {
        "segmentName": self._name,
        "maskType": mask_type,
        "remotePath": remote_path,
    }
    params.update(self._status.get_status_info())

    if config.is_internal:
        params["isInternal"] = True

    response = self._client.open_api_do("GET", "masks/urls", self._dataset_id, params=params)
    try:
        mask_url = response.json()["urls"][0]["url"]
    except IndexError as error:
        # The f-prefix is required here; without it the message would contain the
        # literal text "{mask_type}" instead of the actual mask type.
        raise ResourceNotExistError(
            resource=f"{mask_type} of data", identification=remote_path
        ) from error

    return mask_url  # type: ignore[no-any-return]


def get_data(self, remote_path: str) -> RemoteData:
    """Get required Data object from a dataset segment.

    Arguments:
        remote_path: The remote path of the required data.

    Returns:
        :class:`~tensorbay.dataset.data.RemoteData`.

    Raises:
        ResourceNotExistError: When the required data does not exist.

    """
    # An empty path cannot identify any data, so fail before sending a request.
    if not remote_path:
        raise ResourceNotExistError(resource="data", identification=remote_path)

    data_details = self._get_data_details(remote_path)
    data = RemoteData.from_response_body(
        data_details,
        url=URL(data_details["url"], lambda: self._get_url(remote_path)),
        cache_path=self._cache_path,
    )
    label = data.label

    for key in _MASK_KEYS:
        mask = getattr(label, key, None)
        if mask:
            # The mask type and the path are bound as lambda defaults so that each
            # mask keeps its own values rather than sharing the loop variable.
            mask.url = URL.from_getter(
                lambda k=key.upper(), r=remote_path: self._get_mask_url(k, r),
                lambda k=key.upper(), r=remote_path: (  # type: ignore[misc, arg-type]
                    self._get_mask_url(k, r)
                ),
            )
            mask.cache_path = os.path.join(self._cache_path, key, mask.path)

    return data
# Attribute definitions referenced by MASK_CATALOG_CONTENTS.
CATALOG_ATTRBUTES = [
    {"name": "gender", "enum": ["male", "female"]},
    {"name": "occluded", "type": "integer", "minimum": 1, "maximum": 5},
]

# Catalog contents used for the mask label types: two categories with consecutive ids.
MASK_CATALOG_CONTENTS = {
    "categories": [
        {"name": category, "description": "This is an exmaple of test", "categoryId": index}
        for index, category in enumerate(("cat", "dog"))
    ],
    "attributes": CATALOG_ATTRBUTES,
}

# Catalog contents for BOX2D: categories "01" .. "15" plus four attributes.
BOX2D_CATALOG_CONTENTS = {
    "categories": [{"name": f"{number:02d}"} for number in range(1, 16)],
    "attributes": [
        {"name": "Vertical angle", "enum": [-90, -60, -30, -15, 0, 15, 30, 60, 90]},
        # -90 .. 90 in steps of 15.
        {"name": "Horizontal angle", "enum": list(range(-90, 91, 15))},
        {"name": "Serie", "enum": [1, 2]},
        {"name": "Number", "type": "integer", "minimum": 0, "maximum": 92},
    ],
}

# Catalog holding only the BOX2D entry.
BOX2D_CATALOG = {"BOX2D": BOX2D_CATALOG_CONTENTS}
# Full catalog: the BOX2D contents plus the same mask contents for all three mask types.
CATALOG = {
    "BOX2D": BOX2D_CATALOG_CONTENTS,
    "SEMANTIC_MASK": MASK_CATALOG_CONTENTS,
    "INSTANCE_MASK": MASK_CATALOG_CONTENTS,
    "PANOPTIC_MASK": MASK_CATALOG_CONTENTS,
}

# Response-body shaped labels used to build the expected remote masks.
SEMANTIC_MASK_LABEL = {
    "remotePath": "hello.png",
    "info": [
        {"categoryId": 0, "attributes": {"occluded": True}},
        {"categoryId": 1, "attributes": {"occluded": False}},
    ],
}
INSTANCE_MASK_LABEL = {
    "remotePath": "hello.png",
    "info": [
        {"instanceId": 0, "attributes": {"occluded": True}},
        {"instanceId": 1, "attributes": {"occluded": False}},
    ],
}
PANOPTIC_MASK_LABEL = {
    "remotePath": "hello.png",
    "info": [
        {"instanceId": 100, "categoryId": 0, "attributes": {"occluded": True}},
        {"instanceId": 101, "categoryId": 1, "attributes": {"occluded": False}},
    ],
}


@pytest.fixture
def mask_file(tmp_path):
    """Write an 8x6 array of random 0/1 values to ``tmp_path / "hello.png"``.

    Returns:
        The path of the written file.

    """
    local_path = tmp_path / "hello.png"
    # The upper bound of randint is exclusive, so 2 is needed to get both 0 and 1;
    # randint(0, 1, ...) would always produce an all-zero mask.
    mask = np.random.randint(0, 2, 48).reshape(8, 6)
    mask.dump(local_path)
    return local_path


class TestData:
    def test_get_data(self, accesskey, url, tmp_path, mask_file):
        """Upload ten labeled data with masks and read each one back with ``get_data``."""
        gas_client = GAS(access_key=accesskey, url=url)
        dataset_name = get_dataset_name()
        dataset_client = gas_client.create_dataset(dataset_name)

        dataset_client.create_draft("draft-1")
        dataset_client.upload_catalog(Catalog.loads(CATALOG))
        segment_client = dataset_client.get_or_create_segment("segment1")
        path = tmp_path / "sub"
        path.mkdir()

        # Upload data with label
        for i in range(10):
            local_path = path / f"hello{i}.txt"
            local_path.write_text(f"CONTENT{i}")
            data = Data(local_path=str(local_path))
            data.label = Label.loads(LABEL)

            semantic_mask = SemanticMask(str(mask_file))
            semantic_mask.all_attributes = {0: {"occluded": True}, 1: {"occluded": False}}
            data.label.semantic_mask = semantic_mask

            instance_mask = InstanceMask(str(mask_file))
            instance_mask.all_attributes = {0: {"occluded": True}, 1: {"occluded": False}}
            data.label.instance_mask = instance_mask

            panoptic_mask = PanopticMask(str(mask_file))
            panoptic_mask.all_category_ids = {100: 0, 101: 1}
            data.label.panoptic_mask = panoptic_mask
            segment_client.upload_data(data)

        # The expected values do not depend on the loop index, so build them once.
        expected_box2d = Label.loads(LABEL).box2d
        expected_semantic_mask = RemoteSemanticMask.from_response_body(SEMANTIC_MASK_LABEL)
        expected_instance_mask = RemoteInstanceMask.from_response_body(INSTANCE_MASK_LABEL)
        expected_panoptic_mask = RemotePanopticMask.from_response_body(PANOPTIC_MASK_LABEL)

        for i in range(10):
            data = segment_client.get_data(f"hello{i}.txt")
            assert data.path == f"hello{i}.txt"
            assert data.label.box2d == expected_box2d

            # Each mask path is expected to be the data path with a ".png" extension.
            stem = os.path.splitext(data.path)[0]
            remote_semantic_mask = data.label.semantic_mask
            assert remote_semantic_mask.path == f"{stem}.png"
            assert remote_semantic_mask.all_attributes == expected_semantic_mask.all_attributes

            remote_instance_mask = data.label.instance_mask
            assert remote_instance_mask.path == f"{stem}.png"
            assert remote_instance_mask.all_attributes == expected_instance_mask.all_attributes

            remote_panoptic_mask = data.label.panoptic_mask
            assert remote_panoptic_mask.path == f"{stem}.png"
            assert (
                remote_panoptic_mask.all_category_ids == expected_panoptic_mask.all_category_ids
            )

        gas_client.delete_dataset(dataset_name)
accesskey, url, tmp_path): dataset_name = get_dataset_name() dataset_client = gas_client.create_dataset(dataset_name) dataset_client.create_draft("draft-1") - dataset_client.upload_catalog(Catalog.loads(CATALOG)) + dataset_client.upload_catalog(Catalog.loads(BOX2D_CATALOG)) segment_client = dataset_client.get_or_create_segment("segment1") path = tmp_path / "sub" @@ -205,7 +308,7 @@ def test_delete_frame(self, accesskey, url, tmp_path): dataset_name = get_dataset_name() dataset_client = gas_client.create_dataset(dataset_name, is_fusion=True) dataset_client.create_draft("draft-1") - dataset_client.upload_catalog(Catalog.loads(CATALOG)) + dataset_client.upload_catalog(Catalog.loads(BOX2D_CATALOG)) segment_client = dataset_client.get_or_create_segment("segment1") segment_client.upload_sensor(Sensor.loads(LIDAR_DATA))