Skip to content

Commit

Permalink
feat(client): adapt openAPI "getData"
Browse files Browse the repository at this point in the history
  • Loading branch information
graczhual committed Nov 9, 2021
1 parent c57ea4b commit 5069316
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 1 deletion.
69 changes: 68 additions & 1 deletion tensorbay/client/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from tensorbay.client.requests import config
from tensorbay.client.status import Status
from tensorbay.dataset import AuthData, Data, Frame, RemoteData
from tensorbay.exception import FrameError, InvalidParamsError, ResponseError
from tensorbay.exception import FrameError, InvalidParamsError, ResourceNotExistError, ResponseError
from tensorbay.label import Label
from tensorbay.sensor.sensor import Sensor, Sensors
from tensorbay.utility import URL, FileMixin, chunked, locked
Expand Down Expand Up @@ -123,6 +123,19 @@ def _list_urls(self, offset: int = 0, limit: int = 128) -> Dict[str, Any]:
response = self._client.open_api_do("GET", "data/urls", self._dataset_id, params=params)
return response.json() # type: ignore[no-any-return]

def _get_data_details(self, remote_path: str) -> Dict[str, Any]:
    """Fetch the details entry of one data item in this segment.

    Arguments:
        remote_path: The remote path of the data to look up.

    Returns:
        The first element of the ``dataDetails`` list in the response body.

    """
    query: Dict[str, Any] = {"segmentName": self._name, "remotePath": remote_path}
    # Status info is merged after the base keys, so it wins on any key collision.
    query.update(self._status.get_status_info())
    if config.is_internal:
        query["isInternal"] = True

    body = self._client.open_api_do(
        "GET", "data/details", self._dataset_id, params=query
    ).json()
    # Only a single remote path was requested, so only the first entry is returned.
    details: Dict[str, Any] = body["dataDetails"][0]
    return details

def _list_data_details(self, offset: int = 0, limit: int = 128) -> Dict[str, Any]:
params: Dict[str, Any] = {
"segmentName": self._name,
Expand All @@ -137,6 +150,20 @@ def _list_data_details(self, offset: int = 0, limit: int = 128) -> Dict[str, Any
response = self._client.open_api_do("GET", "data/details", self._dataset_id, params=params)
return response.json() # type: ignore[no-any-return]

def _get_mask_url(self, mask_type: str, remote_path: str) -> str:
    """Fetch the URL of one mask file attached to a data item in this segment.

    Arguments:
        mask_type: The mask type sent to the server as ``maskType``.
        remote_path: The remote path of the data the mask belongs to.

    Returns:
        The ``url`` field of the first element of ``urls`` in the response body.

    """
    query: Dict[str, Any] = {
        "segmentName": self._name,
        "maskType": mask_type,
        "remotePath": remote_path,
    }
    # Status info is merged after the base keys, so it wins on any key collision.
    query.update(self._status.get_status_info())
    if config.is_internal:
        query["isInternal"] = True

    body = self._client.open_api_do(
        "GET", "masks/urls", self._dataset_id, params=query
    ).json()
    mask_url: str = body["urls"][0]["url"]
    return mask_url

def _list_mask_urls(self, mask_type: str, offset: int = 0, limit: int = 128) -> Dict[str, Any]:
params: Dict[str, Any] = {
"segmentName": self._name,
Expand Down Expand Up @@ -653,6 +680,46 @@ def list_data_paths(self) -> PagingList[str]:
"""
return PagingList(self._generate_data_paths, 128)

def get_data(self, remote_path: str) -> RemoteData:
    """Get required Data object from a dataset segment.

    Arguments:
        remote_path: The remote path of the required data.

    Returns:
        :class:`~tensorbay.dataset.data.RemoteData`.

    Raises:
        ResourceNotExistError: When the required data does not exist.

    """
    # NOTE(review): this membership test walks the paged path listing of the whole
    # segment before the single-item request is made.
    if remote_path not in self.list_data_paths():
        raise ResourceNotExistError(resource="data", identification=remote_path)

    # ``_get_data_details`` already returns the single ``dataDetails`` entry, so the
    # result must not be indexed with ``["dataDetails"][0]`` a second time.
    data_details = self._get_data_details(remote_path)

    data = RemoteData.from_response_body(
        data_details,
        url=URL(data_details["url"], lambda: self._get_url(remote_path)),
        cache_path=self._cache_path,
    )
    label = data.label

    for key in _MASK_KEYS:
        mask = getattr(label, key, None)
        if mask:
            # ``key`` and ``remote_path`` are bound as lambda defaults so each mask
            # keeps its own values instead of the last ones seen by the loop.
            # NOTE(review): the second callable returns a one-element tuple rather
            # than a plain string; confirm this matches what ``URL.from_getter``
            # expects from its updater.
            # pylint: disable=protected-access
            mask.url = URL.from_getter(
                lambda k=key.upper(), r=remote_path: self._get_mask_url(k, r),
                lambda k=key.upper(), r=remote_path: (  # type: ignore[misc, arg-type]
                    self._get_mask_url(k, r),
                ),
            )
            mask.cache_path = os.path.join(self._cache_path, key, mask.path)

    return data

def list_data(self) -> PagingList[RemoteData]:
"""List required Data object in a dataset segment.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,32 @@


class TestData:
def test_get_data(self, accesskey, url, tmp_path):
    """Upload ten labeled text files, then read each one back with ``get_data``."""
    gas_client = GAS(access_key=accesskey, url=url)
    dataset_name = get_dataset_name()
    dataset_client = gas_client.create_dataset(dataset_name)

    dataset_client.create_draft("draft-1")
    dataset_client.upload_catalog(Catalog.loads(CATALOG))
    segment_client = dataset_client.get_or_create_segment("segment1")
    upload_dir = tmp_path / "sub"
    upload_dir.mkdir()

    # Upload data with label
    for index in range(10):
        source_file = upload_dir / f"hello{index}.txt"
        source_file.write_text(f"CONTENT{index}")
        uploaded = Data(local_path=str(source_file))
        uploaded.label = Label.loads(LABEL)
        segment_client.upload_data(uploaded)

    # Every uploaded file must be retrievable by its remote path, label included.
    for index in range(10):
        expected_path = f"hello{index}.txt"
        fetched = segment_client.get_data(expected_path)
        assert fetched.path == expected_path
        assert fetched.label == Label.loads(LABEL)

    gas_client.delete_dataset(dataset_name)

def test_list_file_order(self, accesskey, url, tmp_path):
gas_client = GAS(access_key=accesskey, url=url)
dataset_name = get_dataset_name()
Expand Down

0 comments on commit 5069316

Please sign in to comment.