Skip to content

Commit

Permalink
refactor(client): add "_JOB_TYPE" to class JobMixin in version.py
Browse files Browse the repository at this point in the history
PR Closed: #1242
  • Loading branch information
graczhual committed Apr 6, 2022
1 parent 29d8a32 commit c6d90c1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 29 deletions.
16 changes: 8 additions & 8 deletions tensorbay/client/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__( # pylint: disable=too-many-arguments
self,
client: Client,
dataset_id: str,
job_updater: Callable[[str, str], Dict[str, Any]],
job_updater: Callable[[str], Dict[str, Any]],
title: str,
job_id: str,
job_type: str,
Expand Down Expand Up @@ -107,7 +107,7 @@ def from_response_body(
*,
client: Client,
dataset_id: str,
job_updater: Callable[[str, str], Dict[str, Any]], # noqa: DAR101
job_updater: Callable[[str], Dict[str, Any]], # noqa: DAR101
) -> _T:
"""Loads a :class:`Job` object from a response body.
Expand Down Expand Up @@ -150,12 +150,12 @@ def update(self, until_complete: bool = False) -> None:
until_complete: Whether to update job information until it is complete.
"""
job_info = self._job_updater(self.job_id, self._job_type)
job_info = self._job_updater(self.job_id)

if until_complete:
while job_info["status"] in _JOB_NOT_COMPLETE_STATUS:
sleep(_JOB_UPDATE_INTERVAL)
job_info = self._job_updater(self.job_id, self._job_type)
job_info = self._job_updater(self.job_id)

self.started_at = job_info.get("startedAt")
self.finished_at = job_info.get("finishedAt")
Expand Down Expand Up @@ -201,7 +201,7 @@ def __init__( # pylint: disable=too-many-locals
client: Client,
*,
dataset_id: str,
job_updater: Callable[[str, str], Dict[str, Any]],
job_updater: Callable[[str], Dict[str, Any]],
draft_getter: Callable[[int], Draft],
title: str,
job_id: str,
Expand Down Expand Up @@ -254,7 +254,7 @@ def from_response_body( # type: ignore[override] # pylint: disable=arguments-d
*,
client: Client,
dataset_id: str,
job_updater: Callable[[str, str], Dict[str, Any]], # noqa: DAR101
job_updater: Callable[[str], Dict[str, Any]], # noqa: DAR101
draft_getter: Callable[[int], Draft],
) -> _T:
"""Loads a :class:`SquashAndMergeJob` object from a response body.
Expand Down Expand Up @@ -303,7 +303,7 @@ def __init__( # pylint: disable=too-many-locals
client: Client,
*,
dataset_id: str,
job_updater: Callable[[str, str], Dict[str, Any]],
job_updater: Callable[[str], Dict[str, Any]],
is_fusion: bool,
title: str,
job_id: str,
Expand Down Expand Up @@ -358,7 +358,7 @@ def from_response_body( # type: ignore[override] # pylint: disable=arguments-d
*,
client: Client,
dataset_id: str,
job_updater: Callable[[str, str], Dict[str, Any]], # noqa: DAR101
job_updater: Callable[[str], Dict[str, Any]], # noqa: DAR101
is_fusion: bool,
) -> _T:
"""Loads a :class:`BasicSearchJob` object from a response body.
Expand Down
5 changes: 2 additions & 3 deletions tensorbay/client/tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def test__create_job(self, mocker, mock_create_job):
)
assert response_data == self.dataset_client.squash_and_merge._create_job(
post_data["title"],
post_data["jobType"],
post_data["arguments"],
post_data["description"],
)
Expand All @@ -65,7 +64,7 @@ def test__get_job(self, mocker, mock_get_job):
job_id = "123"
job_type = "squashAndMerge"
open_api_do, response_data = mock_get_job(mocker)
assert response_data == self.dataset_client.squash_and_merge._get_job(job_id, job_type)
assert response_data == self.dataset_client.squash_and_merge._get_job(job_id)
open_api_do.assert_called_once_with(
"GET", f"jobs/{job_id}", self.dataset_client.dataset_id, params={"jobType": job_type}
)
Expand All @@ -79,7 +78,7 @@ def test__list_jobs(self, mocker, mock_list_jobs):
}
open_api_do, response_data = mock_list_jobs(mocker)
assert response_data == self.dataset_client.squash_and_merge._list_jobs(
params["jobType"], params["status"], params["offset"], params["limit"]
params["status"], params["offset"], params["limit"]
)
open_api_do.assert_called_once_with(
"GET", "jobs", self.dataset_client.dataset_id, params=params
Expand Down
36 changes: 18 additions & 18 deletions tensorbay/client/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,27 +534,30 @@ class JobMixin:
_dataset_id: str
_client: Client
_status: Status
_JOB_TYPE: str

def _create_job(
self,
title: str,
job_type: str,
arguments: Dict[str, Any],
description: str = "",
) -> Dict[str, Any]:
"""Create a :class:`Job`.
Arguments:
title: The Job title.
job_type: The type of Job.
arguments: The arguments dict of the specific job.
description: The Job description.
Returns:
The info of the job.
"""
post_data: Dict[str, Any] = {"title": title, "jobType": job_type, "arguments": arguments}
post_data: Dict[str, Any] = {
"title": title,
"jobType": self._JOB_TYPE,
"arguments": arguments,
}
if description:
post_data["description"] = description

Expand All @@ -564,44 +567,41 @@ def _create_job(

response.update(
title=title,
jobType=job_type,
jobType=self._JOB_TYPE,
arguments=arguments,
status="QUEUING",
description=description,
)
return response

def _get_job(self, job_id: str, job_type: str) -> Dict[str, Any]:
def _get_job(self, job_id: str) -> Dict[str, Any]:
"""Get a :class:`Job`.
Arguments:
job_id: The Job id.
job_type: The type of Job.
Returns:
The info of Job.
"""
params = {"jobType": job_type}
params = {"jobType": self._JOB_TYPE}

response: Dict[str, Any] = self._client.open_api_do(
"GET", f"jobs/{job_id}", self._dataset_id, params=params
).json()

response.update(jobType=job_type)
response.update(jobType=self._JOB_TYPE)
return response

def _list_jobs(
self,
job_type: str,
status: Optional[str] = None,
offset: int = 0,
limit: int = 128,
) -> Dict[str, Any]:
"""Get a dict containing the information of :class:`Job` list.
Arguments:
job_type: Type of the Job.
status: The Job status which includes "QUEUING", "PROCESSING", "SUCCESS", "FAILED",
"ABORTED" and None. None means all kinds of status.
offset: The offset of the page.
Expand All @@ -611,13 +611,13 @@ def _list_jobs(
A dict containing the information of Job list.
"""
params = {"jobType": job_type, "status": status, "offset": offset, "limit": limit}
params = {"jobType": self._JOB_TYPE, "status": status, "offset": offset, "limit": limit}

response: Dict[str, Any] = self._client.open_api_do(
"GET", "jobs", self._dataset_id, params=params
).json()

response.update(jobType=job_type)
response.update(jobType=self._JOB_TYPE)
return response

def delete_job(self, job_id: str) -> None:
Expand Down Expand Up @@ -661,7 +661,7 @@ def _generate_jobs(
offset: int = 0,
limit: int = 128,
) -> Generator[SquashAndMergeJob, None, int]:
response = self._list_jobs(self._JOB_TYPE, status, offset, limit)
response = self._list_jobs(status, offset, limit)
for item in response["jobs"]:
yield SquashAndMergeJob.from_response_body(
item,
Expand Down Expand Up @@ -731,7 +731,7 @@ def create_job(
if draft_description:
arguments["description"] = draft_description

job_info = self._create_job(title, self._JOB_TYPE, arguments, description)
job_info = self._create_job(title, arguments, description)
return SquashAndMergeJob.from_response_body(
job_info,
dataset_id=self._dataset_id,
Expand All @@ -750,7 +750,7 @@ def get_job(self, job_id: str) -> SquashAndMergeJob:
The SquashAndMergeJob.
"""
job_info = self._get_job(job_id, self._JOB_TYPE)
job_info = self._get_job(job_id)
return SquashAndMergeJob.from_response_body(
job_info,
dataset_id=self._dataset_id,
Expand Down Expand Up @@ -801,7 +801,7 @@ def _generate_jobs(
offset: int = 0,
limit: int = 128,
) -> Generator[BasicSearchJob, None, int]:
response = self._list_jobs(self._JOB_TYPE, status, offset, limit)
response = self._list_jobs(status, offset, limit)
for item in response["jobs"]:
yield BasicSearchJob.from_response_body(
item,
Expand Down Expand Up @@ -862,7 +862,7 @@ def create_job(
"unit": unit,
}

job_info = self._create_job(title, self._JOB_TYPE, arguments, description)
job_info = self._create_job(title, arguments, description)
return BasicSearchJob.from_response_body(
job_info,
dataset_id=self._dataset_id,
Expand All @@ -881,7 +881,7 @@ def get_job(self, job_id: str) -> BasicSearchJob:
The BasicSearchJob.
"""
job_info = self._get_job(job_id, self._JOB_TYPE)
job_info = self._get_job(job_id)
return BasicSearchJob.from_response_body(
job_info,
dataset_id=self._dataset_id,
Expand Down

0 comments on commit c6d90c1

Please sign in to comment.