Skip to content

Commit

Permalink
feat(client): design structure and interfaces of basicSearch
Browse files Browse the repository at this point in the history
  • Loading branch information
graczhual committed Mar 3, 2022
1 parent 00aff52 commit d9f0b8a
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 25 deletions.
13 changes: 12 additions & 1 deletion tensorbay/client/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tensorbay.client.segment import _STRATEGIES, FusionSegmentClient, SegmentClient
from tensorbay.client.statistics import Statistics
from tensorbay.client.status import Status
from tensorbay.client.version import SquashAndMerge, VersionControlMixin
from tensorbay.client.version import BasicSearch, SquashAndMerge, VersionControlMixin
from tensorbay.dataset import AuthData, Data, Frame, FusionSegment, Notes, RemoteData, Segment
from tensorbay.exception import (
FrameError,
Expand Down Expand Up @@ -228,6 +228,17 @@ def squash_and_merge(self) -> SquashAndMerge:
"""
return SquashAndMerge(self._client, self._dataset_id, self._status, self.get_draft)

@property # type: ignore[misc]
@functools.lru_cache()
def basic_search(self) -> BasicSearch:
"""Get class :class:`~tensorbay.client.version.BasicSearch`.
Returns:
Required :class:`~tensorbay.client.version.BasicSearch`.
"""
return BasicSearch(self._client, self._dataset_id, self._status)

def enable_cache(self, cache_path: str = "") -> None:
"""Enable cache when open the remote data of the dataset.
Expand Down
10 changes: 4 additions & 6 deletions tensorbay/client/gas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
DatasetClientType = Union[DatasetClient, FusionDatasetClient]

logger = logging.getLogger(__name__)

DEFAULT_BRANCH = "main"
DEFAULT_IS_PUBLIC = False


class GAS:
Expand Down Expand Up @@ -338,7 +336,7 @@ def create_dataset(
*,
config_name: Optional[str] = None,
alias: str = "",
is_public: bool = DEFAULT_IS_PUBLIC,
is_public: bool = False,
) -> DatasetClient:
...

Expand All @@ -350,7 +348,7 @@ def create_dataset(
*,
config_name: Optional[str] = None,
alias: str = "",
is_public: bool = DEFAULT_IS_PUBLIC,
is_public: bool = False,
) -> FusionDatasetClient:
...

Expand All @@ -362,7 +360,7 @@ def create_dataset(
*,
config_name: Optional[str] = None,
alias: str = "",
is_public: bool = DEFAULT_IS_PUBLIC,
is_public: bool = False,
) -> DatasetClientType:
...

Expand All @@ -373,7 +371,7 @@ def create_dataset(
*,
config_name: Optional[str] = None,
alias: str = "",
is_public: bool = DEFAULT_IS_PUBLIC,
is_public: bool = False,
) -> DatasetClientType:
"""Create a TensorBay dataset with given name.
Expand Down
16 changes: 15 additions & 1 deletion tensorbay/client/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
"""Basic structures of asynchronous jobs."""

from time import sleep
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union

from tensorbay.client.requests import Client
from tensorbay.client.search import FusionSearchResult, SearchResult
from tensorbay.client.struct import Draft
from tensorbay.utility import AttrsMixin, ReprMixin, ReprType, attr, camel, common_loads

Expand Down Expand Up @@ -280,3 +281,16 @@ def from_response_body( # type: ignore[override] # pylint: disable=arguments-d
job._draft_getter = draft_getter # pylint: disable=protected-access

return job


class BasicSearchJob(Job):
"""This class defines :class:`BasicSearchJob`."""

@property
def result(self) -> Optional[Union[SearchResult, FusionSearchResult]]:
"""Get the result of the BasicSearchJob.
Return:
The search result of the BasicSearchJob.
"""
97 changes: 97 additions & 0 deletions tensorbay/client/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
#
# Copyright 2021 Graviti. Licensed under MIT License.
#

"""The structure of the search result."""

from typing import TYPE_CHECKING, Union

from tensorbay.client.lazy import PagingList
from tensorbay.client.requests import Client
from tensorbay.client.statistics import Statistics
from tensorbay.dataset import Frame, RemoteData
from tensorbay.sensor.sensor import Sensors
from tensorbay.utility import ReprMixin

if TYPE_CHECKING:
from tensorbay.client.dataset import DatasetClient, FusionDatasetClient


class SearchResultBase(ReprMixin):
"""This class defines the structure of the search result.
Arguments:
job_id: The id of the search job.
search_result_id: The id of the search result.
client: The :class:`~tensorbay.client.requires.Client`.
"""

def __init__(self, job_id: str, search_result_id: str, client: Client) -> None:
pass

def get_label_statistics(self) -> Statistics:
"""Get label statistics of the search result.
Return:
Required :class:`~tensorbay.client.dataset.Statistics`.
"""

def create_dataset(
self, name: str, alias: str = "", is_public: bool = False
) -> Union["DatasetClient", "FusionDatasetClient"]:
"""Create a TensorBay dataset based on the search result.
Arguments:
name: Name of the dataset, unique for a user.
alias: Alias of the dataset, default is "".
is_public: Whether the dataset is a public dataset.
Return:
The created :class:`~tensorbay.client.dataset.DatasetClient` instance or
:class:`~tensorbay.client.dataset.FusionDatasetClient` instance.
"""


class SearchResult(SearchResultBase):
"""This class defines the structure of the search result from normal dataset."""

def list_data(self, segment_name: str) -> PagingList[RemoteData]:
"""List required data of the segment with given name.
Arguments:
segment_name: Name of the segment.
Return:
The PagingList of :class:`~tensorbay.dataset.data.RemoteData`.
"""


class FusionSearchResult(SearchResultBase):
"""This class defines the structure of the search result from fusion dataset."""

def list_frames(self, segment_name: str) -> PagingList[Frame]:
"""List required frames of the segment with given name.
Arguments:
segment_name: Name of the segment.
Return:
The PagingList of :class:`~tensorbay.dataset.frame.Frame`.
"""

def get_sensors(self, segment_name: str) -> Sensors:
"""Return the sensors of the segment with given name.
Arguments:
segment_name: Name of the segment.
Return:
The :class:`sensors<~tensorbay.sensor.sensor.Sensors>`instance.
"""
10 changes: 5 additions & 5 deletions tensorbay/client/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from tensorbay.client import dataset, gas, segment
from tensorbay.client.dataset import DatasetClient, FusionDatasetClient
from tensorbay.client.diff import DataDiff, SegmentDiff
from tensorbay.client.gas import DEFAULT_BRANCH, DEFAULT_IS_PUBLIC, GAS
from tensorbay.client.gas import DEFAULT_BRANCH, GAS
from tensorbay.client.lazy import ReturnGenerator
from tensorbay.client.requests import Tqdm
from tensorbay.client.segment import FusionSegmentClient, SegmentClient
Expand All @@ -40,15 +40,15 @@ class TestDatasetClientBase:
gas_client,
status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID),
alias="",
is_public=DEFAULT_IS_PUBLIC,
is_public=False,
)
source_dataset_client = DatasetClient(
"source_dataset",
"544321",
gas_client,
status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID),
alias="",
is_public=DEFAULT_IS_PUBLIC,
is_public=False,
)

def test__create_segment(self, mocker):
Expand Down Expand Up @@ -684,15 +684,15 @@ class TestFusionDatasetClient(TestDatasetClientBase):
gas_client,
status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID),
alias="",
is_public=DEFAULT_IS_PUBLIC,
is_public=False,
)
source_dataset_client = FusionDatasetClient(
"source_dataset",
"544321",
gas_client,
status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID),
alias="",
is_public=DEFAULT_IS_PUBLIC,
is_public=False,
)

def test__extract_all_data(self):
Expand Down
6 changes: 3 additions & 3 deletions tensorbay/client/tests/test_gas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tensorbay.client import gas
from tensorbay.client.cloud_storage import CloudClient, StorageConfig
from tensorbay.client.dataset import DatasetClient, FusionDatasetClient
from tensorbay.client.gas import DEFAULT_BRANCH, DEFAULT_IS_PUBLIC, GAS
from tensorbay.client.gas import DEFAULT_BRANCH, GAS
from tensorbay.client.status import Status
from tensorbay.client.struct import ROOT_COMMIT_ID, Draft, User
from tensorbay.client.tests.utility import mock_response
Expand Down Expand Up @@ -110,7 +110,7 @@ def test__get_dataset(self, mocker):
"updateTime": 1622693494,
"owner": "",
"id": "123456",
"isPublic": DEFAULT_IS_PUBLIC,
"isPublic": False,
}
mocker.patch(
f"{gas.__name__}.GAS._list_datasets",
Expand Down Expand Up @@ -485,7 +485,7 @@ def test_upload_dataset(self, mocker):
self.gas_client,
status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID),
alias="",
is_public=DEFAULT_IS_PUBLIC,
is_public=False,
),
)
checkout = mocker.patch(f"{gas.__name__}.DatasetClient.checkout")
Expand Down
6 changes: 3 additions & 3 deletions tensorbay/client/tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from tensorbay.client.dataset import DatasetClient
from tensorbay.client.gas import DEFAULT_BRANCH, DEFAULT_IS_PUBLIC, GAS
from tensorbay.client.gas import DEFAULT_BRANCH, GAS
from tensorbay.client.job import SquashAndMergeJob
from tensorbay.client.lazy import ReturnGenerator
from tensorbay.client.status import Status
Expand All @@ -22,7 +22,7 @@ class TestVersionControlMixin:
gas_client,
status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID),
alias="",
is_public=DEFAULT_IS_PUBLIC,
is_public=False,
)


Expand All @@ -34,7 +34,7 @@ class TestJobMixin:
gas_client,
status=Status(DEFAULT_BRANCH, commit_id=ROOT_COMMIT_ID),
alias="",
is_public=DEFAULT_IS_PUBLIC,
is_public=False,
)

def test__create_job(self, mocker, mock_create_job):
Expand Down
84 changes: 82 additions & 2 deletions tensorbay/client/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

"""Related methods of the TensorBay version control."""

from typing import Any, Callable, Dict, Generator, Optional, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

from tensorbay.client.job import SquashAndMergeJob
from tensorbay.client.job import BasicSearchJob, SquashAndMergeJob
from tensorbay.client.lazy import PagingList
from tensorbay.client.requests import Client
from tensorbay.client.status import Status
Expand Down Expand Up @@ -752,3 +752,83 @@ def list_jobs(self, status: Optional[str] = None) -> PagingList[SquashAndMergeJo
lambda offset, limit: self._generate_jobs(status, offset, limit),
128,
)


class BasicSearch(JobMixin):
"""This class defines :class:`BasicSearch`.
Arguments:
client: The :class:`~tensorbay.client.requests.Client`.
dataset_id: Dataset ID.
status: The version control status of the dataset.
"""

_JOB_TYPE = "basicSearch"

def __init__(
self,
client: Client,
dataset_id: str,
status: Status,
) -> None:
pass

def _generate_jobs(
self,
status: Optional[str] = None,
offset: int = 0,
limit: int = 128,
) -> Generator[BasicSearchJob, None, int]:
pass

def create_job(
self,
title: str = "",
description: str = "",
*,
conjunction: str,
filters: List[Tuple[Any]],
unit: str = "FILE",
) -> BasicSearchJob:
"""Create a :class:`BasicSearchJob`.
Arguments:
title: The BasicSearchJob title.
description: The BasicSearchJob description.
conjunction: The logical conjunction between search filters, which includes "AND" and
"OR".
filters: The list of basic search criteria.
unit: The unit of basic search from fusion dataset. There are two options:
1. "FILE": get the data that meets search filters;
2. "FRAME": if at least one data in a frame meets search filters, all data in
the frame will be get.
Return:
The BasicSearchJob.
"""

def get_job(self, job_id: str) -> BasicSearchJob:
"""Get a :class:`BasicSearchJob`.
Arguments:
job_id: The BasicSearchJob id.
Return:
The BasicSearchJob.
"""

def list_jobs(self, status: Optional[str] = None) -> PagingList[BasicSearchJob]:
"""List the BasicSearchJob.
Arguments:
status: The BasicSearchJob status which includes "QUEUING", "PROCESSING", "SUCCESS",
"FAIL", "ABORT" and None. None means all kinds of status.
Return:
The PagingList of BasicSearchJob.
"""
Loading

0 comments on commit d9f0b8a

Please sign in to comment.