From d9f0b8a5426b762c125a1a72b2b0857fcfe0d59b Mon Sep 17 00:00:00 2001 From: "changjun.zhu" Date: Mon, 14 Feb 2022 15:50:56 +0800 Subject: [PATCH] feat(client): design structure and interfaces of basicSearch --- tensorbay/client/dataset.py | 13 +++- tensorbay/client/gas.py | 10 ++- tensorbay/client/job.py | 16 ++++- tensorbay/client/search.py | 97 ++++++++++++++++++++++++++ tensorbay/client/tests/test_dataset.py | 10 +-- tensorbay/client/tests/test_gas.py | 6 +- tensorbay/client/tests/test_version.py | 6 +- tensorbay/client/version.py | 84 +++++++++++++++++++++- tests/test_dataset.py | 8 +-- 9 files changed, 225 insertions(+), 25 deletions(-) create mode 100644 tensorbay/client/search.py diff --git a/tensorbay/client/dataset.py b/tensorbay/client/dataset.py index e274d53bf..51395e7a5 100644 --- a/tensorbay/client/dataset.py +++ b/tensorbay/client/dataset.py @@ -25,7 +25,7 @@ from tensorbay.client.segment import _STRATEGIES, FusionSegmentClient, SegmentClient from tensorbay.client.statistics import Statistics from tensorbay.client.status import Status -from tensorbay.client.version import SquashAndMerge, VersionControlMixin +from tensorbay.client.version import BasicSearch, SquashAndMerge, VersionControlMixin from tensorbay.dataset import AuthData, Data, Frame, FusionSegment, Notes, RemoteData, Segment from tensorbay.exception import ( FrameError, @@ -228,6 +228,17 @@ def squash_and_merge(self) -> SquashAndMerge: """ return SquashAndMerge(self._client, self._dataset_id, self._status, self.get_draft) + @property # type: ignore[misc] + @functools.lru_cache() + def basic_search(self) -> BasicSearch: + """Get class :class:`~tensorbay.client.version.BasicSearch`. + + Returns: + Required :class:`~tensorbay.client.version.BasicSearch`. + + """ + return BasicSearch(self._client, self._dataset_id, self._status) + def enable_cache(self, cache_path: str = "") -> None: """Enable cache when open the remote data of the dataset. diff --git a/tensorbay/client/gas.py b/tensorbay/client/gas.py index b25739db7..d0d927e68 100644 --- a/tensorbay/client/gas.py +++ b/tensorbay/client/gas.py @@ -24,9 +24,7 @@ DatasetClientType = Union[DatasetClient, FusionDatasetClient] logger = logging.getLogger(__name__) - DEFAULT_BRANCH = "main" -DEFAULT_IS_PUBLIC = False class GAS: @@ -338,7 +336,7 @@ def create_dataset( *, config_name: Optional[str] = None, alias: str = "", - is_public: bool = DEFAULT_IS_PUBLIC, + is_public: bool = False, ) -> DatasetClient: ... @@ -350,7 +348,7 @@ def create_dataset( *, config_name: Optional[str] = None, alias: str = "", - is_public: bool = DEFAULT_IS_PUBLIC, + is_public: bool = False, ) -> FusionDatasetClient: ... @@ -362,7 +360,7 @@ def create_dataset( *, config_name: Optional[str] = None, alias: str = "", - is_public: bool = DEFAULT_IS_PUBLIC, + is_public: bool = False, ) -> DatasetClientType: ... @@ -373,7 +371,7 @@ def create_dataset( *, config_name: Optional[str] = None, alias: str = "", - is_public: bool = DEFAULT_IS_PUBLIC, + is_public: bool = False, ) -> DatasetClientType: """Create a TensorBay dataset with given name. diff --git a/tensorbay/client/job.py b/tensorbay/client/job.py index fbccaeb4e..20ae7a76b 100644 --- a/tensorbay/client/job.py +++ b/tensorbay/client/job.py @@ -6,9 +6,10 @@ """Basic structures of asynchronous jobs.""" from time import sleep -from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar +from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union from tensorbay.client.requests import Client +from tensorbay.client.search import FusionSearchResult, SearchResult from tensorbay.client.struct import Draft from tensorbay.utility import AttrsMixin, ReprMixin, ReprType, attr, camel, common_loads @@ -280,3 +281,16 @@ def from_response_body( # type: ignore[override] # pylint: disable=arguments-d job._draft_getter = draft_getter # pylint: disable=protected-access return job + + +class BasicSearchJob(Job): + """This class defines :class:`BasicSearchJob`.""" + + @property + def result(self) -> Optional[Union[SearchResult, FusionSearchResult]]: + """Get the result of the BasicSearchJob. + + Return: + The search result of the BasicSearchJob. + + """ diff --git a/tensorbay/client/search.py b/tensorbay/client/search.py new file mode 100644 index 000000000..77afe290a --- /dev/null +++ b/tensorbay/client/search.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Graviti. Licensed under MIT License. +# + +"""The structure of the search result.""" + +from typing import TYPE_CHECKING, Union + +from tensorbay.client.lazy import PagingList +from tensorbay.client.requests import Client +from tensorbay.client.statistics import Statistics +from tensorbay.dataset import Frame, RemoteData +from tensorbay.sensor.sensor import Sensors +from tensorbay.utility import ReprMixin + +if TYPE_CHECKING: + from tensorbay.client.dataset import DatasetClient, FusionDatasetClient + + +class SearchResultBase(ReprMixin): + """This class defines the structure of the search result. + + Arguments: + job_id: The id of the search job. + search_result_id: The id of the search result. + client: The :class:`~tensorbay.client.requires.Client`. + + """ + + def __init__(self, job_id: str, search_result_id: str, client: Client) -> None: + pass + + def get_label_statistics(self) -> Statistics: + """Get label statistics of the search result. + + Return: + Required :class:`~tensorbay.client.dataset.Statistics`. + + """ + + def create_dataset( + self, name: str, alias: str = "", is_public: bool = False + ) -> Union["DatasetClient", "FusionDatasetClient"]: + """Create a TensorBay dataset based on the search result. + + Arguments: + name: Name of the dataset, unique for a user. + alias: Alias of the dataset, default is "". + is_public: Whether the dataset is a public dataset. + + Return: + The created :class:`~tensorbay.client.dataset.DatasetClient` instance or + :class:`~tensorbay.client.dataset.FusionDatasetClient` instance. + + """ + + +class SearchResult(SearchResultBase): + """This class defines the structure of the search result from normal dataset.""" + + def list_data(self, segment_name: str) -> PagingList[RemoteData]: + """List required data of the segment with given name. + + Arguments: + segment_name: Name of the segment. + + Return: + The PagingList of :class:`~tensorbay.dataset.data.RemoteData`. + + """ + + +class FusionSearchResult(SearchResultBase): + """This class defines the structure of the search result from fusion dataset.""" + + def list_frames(self, segment_name: str) -> PagingList[Frame]: + """List required frames of the segment with given name. + + Arguments: + segment_name: Name of the segment. + + Return: + The PagingList of :class:`~tensorbay.dataset.frame.Frame`. + + """ + + def get_sensors(self, segment_name: str) -> Sensors: + """Return the sensors of the segment with given name. + + Arguments: + segment_name: Name of the segment. + + Return: + The :class:`sensors<~tensorbay.sensor.sensor.Sensors>`instance. + + """ diff --git a/tensorbay/client/tests/test_dataset.py b/tensorbay/client/tests/test_dataset.py index 5e847b3be..114243d8f 100644 --- a/tensorbay/client/tests/test_dataset.py +++ b/tensorbay/client/tests/test_dataset.py @@ -13,7 +13,7 @@ from tensorbay.client import dataset, gas, segment from tensorbay.client.dataset import DatasetClient, FusionDatasetClient from tensorbay.client.diff import DataDiff, SegmentDiff -from tensorbay.client.gas import DEFAULT_BRANCH, DEFAULT_IS_PUBLIC, GAS +from tensorbay.client.gas import DEFAULT_BRANCH, GAS from tensorbay.client.lazy import ReturnGenerator from tensorbay.client.requests import Tqdm from tensorbay.client.segment import FusionSegmentClient, SegmentClient @@ -40,7 +40,7 @@ class TestDatasetClientBase: gas_client, status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID), alias="", - is_public=DEFAULT_IS_PUBLIC, + is_public=False, ) source_dataset_client = DatasetClient( "source_dataset", @@ -48,7 +48,7 @@ class TestDatasetClientBase: gas_client, status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID), alias="", - is_public=DEFAULT_IS_PUBLIC, + is_public=False, ) def test__create_segment(self, mocker): @@ -684,7 +684,7 @@ class TestFusionDatasetClient(TestDatasetClientBase): gas_client, status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID), alias="", - is_public=DEFAULT_IS_PUBLIC, + is_public=False, ) source_dataset_client = FusionDatasetClient( "source_dataset", @@ -692,7 +692,7 @@ class TestFusionDatasetClient(TestDatasetClientBase): gas_client, status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID), alias="", - is_public=DEFAULT_IS_PUBLIC, + is_public=False, ) def test__extract_all_data(self): diff --git a/tensorbay/client/tests/test_gas.py b/tensorbay/client/tests/test_gas.py index 602ec95a4..f06b4fb76 100644 --- a/tensorbay/client/tests/test_gas.py +++ b/tensorbay/client/tests/test_gas.py @@ -10,7 +10,7 @@ from tensorbay.client import gas from tensorbay.client.cloud_storage import CloudClient, StorageConfig from tensorbay.client.dataset import DatasetClient, FusionDatasetClient -from tensorbay.client.gas import DEFAULT_BRANCH, DEFAULT_IS_PUBLIC, GAS +from tensorbay.client.gas import DEFAULT_BRANCH, GAS from tensorbay.client.status import Status from tensorbay.client.struct import ROOT_COMMIT_ID, Draft, User from tensorbay.client.tests.utility import mock_response @@ -110,7 +110,7 @@ def test__get_dataset(self, mocker): "updateTime": 1622693494, "owner": "", "id": "123456", - "isPublic": DEFAULT_IS_PUBLIC, + "isPublic": False, } mocker.patch( f"{gas.__name__}.GAS._list_datasets", @@ -485,7 +485,7 @@ def test_upload_dataset(self, mocker): self.gas_client, status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID), alias="", - is_public=DEFAULT_IS_PUBLIC, + is_public=False, ), ) checkout = mocker.patch(f"{gas.__name__}.DatasetClient.checkout") diff --git a/tensorbay/client/tests/test_version.py b/tensorbay/client/tests/test_version.py index 50aeb02d1..bb7f6fea8 100644 --- a/tensorbay/client/tests/test_version.py +++ b/tensorbay/client/tests/test_version.py @@ -6,7 +6,7 @@ import pytest from tensorbay.client.dataset import DatasetClient -from tensorbay.client.gas import DEFAULT_BRANCH, DEFAULT_IS_PUBLIC, GAS +from tensorbay.client.gas import DEFAULT_BRANCH, GAS from tensorbay.client.job import SquashAndMergeJob from tensorbay.client.lazy import ReturnGenerator from tensorbay.client.status import Status @@ -22,7 +22,7 @@ class TestVersionControlMixin: gas_client, status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID), alias="", - is_public=DEFAULT_IS_PUBLIC, + is_public=False, ) @@ -34,7 +34,7 @@ class TestJobMixin: gas_client, status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID), alias="", - is_public=DEFAULT_IS_PUBLIC, + is_public=False, ) def test__create_job(self, mocker, mock_create_job): diff --git a/tensorbay/client/version.py b/tensorbay/client/version.py index 35e9799d9..a02380380 100644 --- a/tensorbay/client/version.py +++ b/tensorbay/client/version.py @@ -5,9 +5,9 @@ """Related methods of the TensorBay version control.""" -from typing import Any, Callable, Dict, Generator, Optional, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union -from tensorbay.client.job import SquashAndMergeJob +from tensorbay.client.job import BasicSearchJob, SquashAndMergeJob from tensorbay.client.lazy import PagingList from tensorbay.client.requests import Client from tensorbay.client.status import Status @@ -752,3 +752,83 @@ def list_jobs(self, status: Optional[str] = None) -> PagingList[SquashAndMergeJo lambda offset, limit: self._generate_jobs(status, offset, limit), 128, ) + + +class BasicSearch(JobMixin): + """This class defines :class:`BasicSearch`. + + Arguments: + client: The :class:`~tensorbay.client.requests.Client`. + dataset_id: Dataset ID. + status: The version control status of the dataset. + + """ + + _JOB_TYPE = "basicSearch" + + def __init__( + self, + client: Client, + dataset_id: str, + status: Status, + ) -> None: + pass + + def _generate_jobs( + self, + status: Optional[str] = None, + offset: int = 0, + limit: int = 128, + ) -> Generator[BasicSearchJob, None, int]: + pass + + def create_job( + self, + title: str = "", + description: str = "", + *, + conjunction: str, + filters: List[Tuple[Any]], + unit: str = "FILE", + ) -> BasicSearchJob: + """Create a :class:`BasicSearchJob`. + + Arguments: + title: The BasicSearchJob title. + description: The BasicSearchJob description. + conjunction: The logical conjunction between search filters, which includes "AND" and + "OR". + filters: The list of basic search criteria. + unit: The unit of basic search from fusion dataset. There are two options: + + 1. "FILE": get the data that meets search filters; + 2. "FRAME": if at least one data in a frame meets search filters, all data in + the frame will be get. + + Return: + The BasicSearchJob. + + """ + + def get_job(self, job_id: str) -> BasicSearchJob: + """Get a :class:`BasicSearchJob`. + + Arguments: + job_id: The BasicSearchJob id. + + Return: + The BasicSearchJob. + + """ + + def list_jobs(self, status: Optional[str] = None) -> PagingList[BasicSearchJob]: + """List the BasicSearchJob. + + Arguments: + status: The BasicSearchJob status which includes "QUEUING", "PROCESSING", "SUCCESS", + "FAIL", "ABORT" and None. None means all kinds of status. + + Return: + The PagingList of BasicSearchJob. + + """ diff --git a/tests/test_dataset.py b/tests/test_dataset.py index be872d193..e3bbc956e 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -6,7 +6,7 @@ import pytest from tensorbay import GAS -from tensorbay.client.gas import DEFAULT_BRANCH, DEFAULT_IS_PUBLIC +from tensorbay.client.gas import DEFAULT_BRANCH from tensorbay.client.struct import ROOT_COMMIT_ID from tensorbay.exception import ResourceNotExistError, ResponseError, StatusError from tests.utility import get_dataset_name @@ -75,7 +75,7 @@ def test_get_new_dataset(self, accesskey, url): assert dataset_client_get.status.commit_id == ROOT_COMMIT_ID assert dataset_client_get.status.branch_name == DEFAULT_BRANCH assert dataset_client_get.dataset_id == dataset_client.dataset_id - assert dataset_client_get.is_public == DEFAULT_IS_PUBLIC + assert dataset_client_get.is_public == False gas_client.delete_dataset(dataset_name) @@ -92,7 +92,7 @@ def test_get_dataset_to_latest_commit(self, accesskey, url): dataset_client_get = gas_client.get_dataset(dataset_name) assert dataset_client_get.status.commit_id == v2_commit_id assert dataset_client_get.status.branch_name == DEFAULT_BRANCH - assert dataset_client_get.is_public == DEFAULT_IS_PUBLIC + assert dataset_client_get.is_public == False gas_client.delete_dataset(dataset_name) @@ -109,7 +109,7 @@ def test_get_fusion_dataset_to_latest_commit(self, accesskey, url): dataset_client_get = gas_client.get_dataset(dataset_name, is_fusion=True) assert dataset_client_get.status.commit_id == v2_commit_id assert dataset_client_get.status.branch_name == DEFAULT_BRANCH - assert dataset_client_get.is_public == DEFAULT_IS_PUBLIC + assert dataset_client_get.is_public == False gas_client.delete_dataset(dataset_name)