From 7fe5b6d75542ace8b4fba99f13f6d193718d52f7 Mon Sep 17 00:00:00 2001 From: "changjun.zhu" Date: Fri, 24 Dec 2021 15:41:53 +0800 Subject: [PATCH] fix(client): fix AttributeError of result property of SquashAndMergeJob PR Closed: https://github.com/Graviti-AI/tensorbay-python-sdk/pull/1183 --- tensorbay/client/dataset.py | 2 +- tensorbay/client/job.py | 46 +++++++++++++++++++++++++++++++++++++ tensorbay/client/version.py | 37 ++++++++++++++++++----------- 3 files changed, 71 insertions(+), 14 deletions(-) diff --git a/tensorbay/client/dataset.py b/tensorbay/client/dataset.py index c67e0ffdd..e274d53bf 100644 --- a/tensorbay/client/dataset.py +++ b/tensorbay/client/dataset.py @@ -226,7 +226,7 @@ def squash_and_merge(self) -> SquashAndMerge: Required :class:`~tensorbay.client.version.SquashAndMerge`. """ - return SquashAndMerge(self._dataset_id, self._client, self._status) + return SquashAndMerge(self._client, self._dataset_id, self._status, self.get_draft) def enable_cache(self, cache_path: str = "") -> None: """Enable cache when open the remote data of the dataset. diff --git a/tensorbay/client/job.py b/tensorbay/client/job.py index 74eb67a77..50fdc254b 100644 --- a/tensorbay/client/job.py +++ b/tensorbay/client/job.py @@ -186,6 +186,8 @@ class SquashAndMergeJob(Job): """ + _T = TypeVar("_T", bound="SquashAndMergeJob") + def __init__( # pylint: disable=too-many-arguments self, client: Client, @@ -233,3 +235,47 @@ def result(self) -> Optional[Draft]: return self._draft_getter(draft_number) return None + + @classmethod + def from_response_body( # type: ignore[override] # pylint: disable=arguments-differ + cls: Type[_T], + body: Dict[str, Any], + *, + client: Client, + dataset_id: str, + job_updater: Callable[[str], Dict[str, Any]], # noqa: DAR101 + draft_getter: Callable[[int], Draft], + ) -> _T: + """Loads a :class:`SquashAndMergeJob` object from a response body. + + Arguments: + body: The response body which contains the information of a SquashAndMergeJob, + whose format should be like:: + + { + "title": + "jobId": + "arguments": + "createdAt": + "startedAt": + "finishedAt": + "status": + "errorMessage": + "result": + "description": + } + client: The :class:`~tensorbay.client.requests.Client`. + dataset_id: Dataset ID. + job_updater: The function to update the information of the SquashAndMergeJob instance. + draft_getter: The function to get draft by draft_number. + + Returns: + The loaded :class:`SquashAndMergeJob` object. + + """ + job = super().from_response_body( + body, client=client, dataset_id=dataset_id, job_updater=job_updater + ) + job._draft_getter = draft_getter # pylint: disable=protected-access + + return job diff --git a/tensorbay/client/version.py b/tensorbay/client/version.py index dda4d3fba..35e9799d9 100644 --- a/tensorbay/client/version.py +++ b/tensorbay/client/version.py @@ -5,7 +5,7 @@ """Related methods of the TensorBay version control.""" -from typing import Any, Dict, Generator, Optional, Union +from typing import Any, Callable, Dict, Generator, Optional, Union from tensorbay.client.job import SquashAndMergeJob from tensorbay.client.lazy import PagingList @@ -612,9 +612,10 @@ class SquashAndMerge(JobMixin): """This class defines :class:`SquashAndMerge`. Arguments: - dataset_id: Dataset ID. client: The :class:`~tensorbay.client.requests.Client`. + dataset_id: Dataset ID. status: The version control status of the dataset. + draft_getter: The function to get draft by draft_number. """ @@ -622,13 +623,15 @@ class SquashAndMerge(JobMixin): def __init__( self, - dataset_id: str, client: Client, + dataset_id: str, status: Status, + draft_getter: Callable[[int], Draft], ) -> None: - self._dataset_id = dataset_id self._client = client + self._dataset_id = dataset_id self._status = status + self._draft_getter = draft_getter def _generate_jobs( self, @@ -639,7 +642,11 @@ def _generate_jobs( response = self._list_jobs(self._JOB_TYPE, status, offset, limit) for item in response["jobs"]: yield SquashAndMergeJob.from_response_body( - item, dataset_id=self._dataset_id, client=self._client, job_updater=self._get_job + item, + dataset_id=self._dataset_id, + client=self._client, + job_updater=self._get_job, + draft_getter=self._draft_getter, ) return response["totalCount"] # type: ignore[no-any-return] @@ -703,12 +710,14 @@ def create_job( arguments["description"] = draft_description job_info = self._create_job(title, self._JOB_TYPE, arguments, description) - job = SquashAndMergeJob.from_response_body( - job_info, dataset_id=self._dataset_id, client=self._client, job_updater=self._get_job + return SquashAndMergeJob.from_response_body( + job_info, + dataset_id=self._dataset_id, + client=self._client, + job_updater=self._get_job, + draft_getter=self._draft_getter, ) - return job - def get_job(self, job_id: str) -> SquashAndMergeJob: """Get a :class:`SquashAndMergeJob`. @@ -720,12 +729,14 @@ def get_job(self, job_id: str) -> SquashAndMergeJob: """ job_info = self._get_job(job_id) - job = SquashAndMergeJob.from_response_body( - job_info, dataset_id=self._dataset_id, client=self._client, job_updater=self._get_job + return SquashAndMergeJob.from_response_body( + job_info, + dataset_id=self._dataset_id, + client=self._client, + job_updater=self._get_job, + draft_getter=self._draft_getter, ) - return job - def list_jobs(self, status: Optional[str] = None) -> PagingList[SquashAndMergeJob]: """List the SquashAndMergeJob.