diff --git a/docs/code/update_dataset.py b/docs/code/update_dataset.py
index 9fe3931e9..9ec5f30f5 100644
--- a/docs/code/update_dataset.py
+++ b/docs/code/update_dataset.py
@@ -38,10 +38,16 @@
 """"""
 
 """Update label / overwrite label"""
+from tensorbay.label import Classification
+
+dataset = Dataset("", gas)
 for segment in dataset:
-    segment_client = dataset_client.get_segment(segment.name)
+    update_data = []
     for data in segment:
-        segment_client.upload_label(data)
+        data.label.classification = Classification("NEW_CATEGORY")  # set new label
+        update_data.append(data)
+    segment_client = dataset_client.get_segment(segment.name)
+    segment_client.upload_label(update_data)
 """"""
 
 """Update label / commit dataset"""
diff --git a/docs/source/quick_start/examples/update_dataset.rst b/docs/source/quick_start/examples/update_dataset.rst
index ba58b94ca..f8000c658 100644
--- a/docs/source/quick_start/examples/update_dataset.rst
+++ b/docs/source/quick_start/examples/update_dataset.rst
@@ -64,7 +64,7 @@ Update the catalog if needed:
    :start-after: """Update label / update catalog"""
    :end-before: """"""
 
-Overwrite previous labels with new label on dataset:
+Overwrite previous labels with new label:
 
 .. literalinclude:: ../../../../docs/code/update_dataset.py
    :language: python
diff --git a/tensorbay/client/segment.py b/tensorbay/client/segment.py
index 63e526bd5..79594e63e 100644
--- a/tensorbay/client/segment.py
+++ b/tensorbay/client/segment.py
@@ -18,6 +18,7 @@
 from tensorbay.client.lazy import LazyPage, PagingList
 from tensorbay.client.status import Status
 from tensorbay.dataset import AuthData, Data, Frame, RemoteData
+from tensorbay.dataset.data import DataBase
 from tensorbay.exception import FrameError, InvalidParamsError, ResourceNotExistError, ResponseError
 from tensorbay.label import Label
 from tensorbay.sensor.sensor import Sensor, Sensors
@@ -328,6 +329,42 @@ def _upload_label(self, data: Union[AuthData, Data]) -> None:
 
         self._client.open_api_do("PUT", "labels", self._dataset_id, json=post_data)
 
+    def _upload_multi_label(self, data: Iterable[DataBase._Type]) -> None:
+        post_data: Dict[str, Any] = {"segmentName": self.name}
+        objects = []
+        for single_data in data:
+            label = single_data.label.dumps()
+            if not label:
+                continue
+
+            remote_path = (
+                single_data.path
+                if isinstance(single_data, RemoteData)
+                else single_data.target_remote_path
+            )
+            objects.append({"remotePath": remote_path, "label": label})
+        post_data["objects"] = objects
+        post_data.update(self._status.get_status_info())
+
+        self._client.open_api_do("PUT", "multi/data/labels", self._dataset_id, json=post_data)
+
+    def upload_label(self, data: Union[DataBase._Type, Iterable[DataBase._Type]]) -> None:
+        """Upload labels of the given data to the draft.
+
+        Arguments:
+            data: A data object or an iterable of data objects whose labels will be uploaded.
+
+        """
+        self._status.check_authority_for_draft()
+
+        if not isinstance(data, Iterable):
+            data = [data]
+
+        for chunked_data in chunked(data, 128):
+            for single_data in chunked_data:
+                self._upload_mask_files(single_data.label)
+            self._upload_multi_label(chunked_data)
+
     @property
     def name(self) -> str:
         """Return the segment name.
@@ -458,18 +495,6 @@ def upload_file(self, local_path: str, target_remote_path: str = "") -> None:
 
         self._synchronize_upload_info((data.get_callback_body(),))
 
-    def upload_label(self, data: Data) -> None:
-        """Upload label with Data object to the draft.
-
-        Arguments:
-            data: The data object which represents the local file to upload.
-
-        """
-        self._status.check_authority_for_draft()
-
-        self._upload_mask_files(data.label)
-        self._upload_label(data)
-
     def upload_data(self, data: Data) -> None:
         """Upload Data object to the draft.
 
diff --git a/tests/test_label.py b/tests/test_label.py
index f882b6d89..9852983af 100644
--- a/tests/test_label.py
+++ b/tests/test_label.py
@@ -236,3 +236,43 @@ def test_upload_dataset_with_mask(self, accesskey, url, tmp_path, mask_file):
         assert remote_panoptic_mask.all_category_ids == panoptic_mask.all_category_ids
 
         gas_client.delete_dataset(dataset_name)
+
+    def test_upload_label(self, accesskey, url, tmp_path):
+        gas_client = GAS(access_key=accesskey, url=url)
+        dataset_name = get_dataset_name()
+        gas_client.create_dataset(dataset_name)
+
+        dataset = Dataset(name=dataset_name)
+        segment = dataset.create_segment("Segment1")
+        # When uploading labels, upload the catalog first.
+        dataset._catalog = Catalog.loads(CATALOG_CONTENTS)
+
+        path = tmp_path / "sub"
+        path.mkdir()
+        local_path = path / "hello.txt"
+        local_path.write_text("CONTENT")
+        data = Data(local_path=str(local_path))
+        data.label = Label.loads(LABEL)
+        segment.append(data)
+
+        dataset_client = gas_client.upload_dataset(dataset)
+        dataset_client.commit("upload dataset with label")
+        dataset = Dataset(dataset_name, gas_client)
+        assert dataset[0][0].label == Label.loads(LABEL)
+
+        dataset_client.create_draft("update label")
+        segment_client = dataset_client.get_segment(segment.name)
+
+        upload_data = []
+        new_label = Label.loads(LABEL)
+        new_label.multi_polygon[0].category = "dog"
+        for data in segment:
+            data.label = new_label
+            upload_data.append(data)
+        segment_client.upload_label(upload_data)
+        dataset_client.commit("update label")
+        dataset = Dataset(dataset_name, gas_client)
+        assert dataset.catalog == Catalog.loads(CATALOG_CONTENTS)
+        assert dataset[0][0].label == new_label
+
+        gas_client.delete_dataset(dataset_name)