Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add download_job_result method to JobAPI #119

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tdclient/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import csv
import email.utils
import gzip
import http
import io
import json
import logging
Expand All @@ -13,7 +14,6 @@
import tempfile
import time
import urllib.parse as urlparse
import warnings
from array import array

import msgpack
Expand Down Expand Up @@ -206,6 +206,8 @@ def get(self, path, params=None, headers=None, **kwargs):
urllib3.exceptions.TimeoutStateError,
urllib3.exceptions.TimeoutError,
urllib3.exceptions.PoolError,
http.client.IncompleteRead,
TimeoutError,
socket.error,
):
pass
Expand Down Expand Up @@ -494,7 +496,6 @@ def build_request(self, path=None, headers=None, endpoint=None):
return (url, _headers)

def send_request(self, method, url, fields=None, body=None, headers=None, **kwargs):

if body is None:
return self.http.request(
method, url, fields=fields, headers=headers, **kwargs
Expand Down
13 changes: 13 additions & 0 deletions tdclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,19 @@ def job_result_format_each(self, job_id, format, header=False):
for row in self.api.job_result_format_each(job_id, format, header=header):
yield row

def download_job_result(self, job_id, path, num_threads=4):
    """Download a job result and store it as a msgpack.gz file.

    Thin delegation wrapper around the Job API's ``download_job_result``.

    Args:
        job_id (str): job id
        path (str): destination path for the downloaded result
        num_threads (int, optional): number of worker threads used for the
            download. Default: 4

    Returns:
        `True` if success
    """
    # No extra processing at the client layer; forward everything as-is.
    succeeded = self.api.download_job_result(job_id, path, num_threads=num_threads)
    return succeeded

def kill(self, job_id):
"""
Args:
Expand Down
72 changes: 71 additions & 1 deletion tdclient/job_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

import codecs
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor

import msgpack

from .util import create_url, get_or_else, parse_date


log = logging.getLogger(__name__)


class JobAPI:
"""Access to Job API

Expand Down Expand Up @@ -160,7 +166,7 @@ def show_job(self, job_id):
return job

def job_status(self, job_id):
""""Show job status
"""Show job status
Args:
job_id (str): job ID

Expand Down Expand Up @@ -257,6 +263,70 @@ def job_result_format_each(self, job_id, format, header=False):
else:
yield res.read()

def download_job_result(self, job_id, path, num_threads=4):
    """Download the job result to the specified path.

    The result is fetched as msgpack.gz with HTTP Range requests so that
    large results can be downloaded by several threads in parallel; the
    chunk files are then concatenated into the final file.

    Args:
        job_id (int): Job ID
        path (str): Path to save the job result
        num_threads (int): Number of threads to download the job result. Default is 4.

    Returns:
        `True` if the download succeeded.

    Raises:
        RuntimeError: if any chunk could not be downloaded (e.g. the server
            did not honor the Range request with a 206 response).
    """

    # Format should be msgpack.gz because file size of job is compressed in msgpack.gz format.
    file_size = self.show_job(job_id)["result_size"]
    url = create_url(
        "/v3/job/result/{job_id}?format={format}",
        job_id=job_id,
        format="msgpack.gz",
    )

    def get_chunk(url, start, end):
        # Request only the inclusive [start, end] byte range of the result.
        chunk_headers = {"Range": f"bytes={start}-{end}"}

        response = self.get(url, headers=chunk_headers)
        return response

    def download_chunk(url, start, end, index, file_name):
        # Returns True on success so the caller can verify every chunk.
        with get_chunk(url, start, end) as response:
            if response.status == 206:  # Partial content (range supported)
                with open(f"{file_name}.part{index}", "wb") as f:
                    for chunk in response.stream(1024):
                        f.write(chunk)
                return True
            else:
                log.warning(
                    f"Unexpected response status: {response.status}. Body: {response.data}"
                )
                return False

    def combine_chunks(file_name, total_parts):
        # Concatenate the part files in order, removing each one afterwards.
        with open(file_name, "wb") as final_file:
            for i in range(total_parts):
                part_name = f"{file_name}.part{i}"
                with open(part_name, "rb") as part_file:
                    final_file.write(part_file.read())
                os.remove(part_name)

    def download_file_multithreaded(
        url, file_name, file_size, num_threads=4, chunk_size=100 * 1024**2
    ):
        futures = []
        start = 0
        part_index = 0

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            while start < file_size:
                end = min(start + chunk_size - 1, file_size - 1)
                futures.append(
                    executor.submit(
                        download_chunk, url, start, end, part_index, file_name
                    )
                )

                start += chunk_size
                part_index += 1

        # The executor has shut down here, so every future is done. The
        # results were previously ignored, which could silently combine
        # missing chunks into a truncated/corrupt file; fail loudly instead.
        # Future.result() also re-raises any exception from a worker thread.
        if not all(future.result() for future in futures):
            raise RuntimeError(f"Failed to download the result of job {job_id}")

        combine_chunks(file_name, part_index)

    download_file_multithreaded(url, path, file_size, num_threads=num_threads)
    return True

def kill(self, job_id):
"""Stop the specific job if it is running.

Expand Down
50 changes: 50 additions & 0 deletions tdclient/test/job_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import datetime
import json
import tempfile
import uuid
from unittest import mock

import dateutil.tz
Expand Down Expand Up @@ -207,6 +209,54 @@ def test_job_result_json_with_header_each_success():
assert result == rows


def test_download_job_result():
    """download_job_result fetches the size via /v3/job/show, issues a ranged
    GET for the result, and writes the payload to the given path."""
    # NOTE(review): os.path.join is used below, but this change only adds
    # `tempfile`/`uuid` imports; import os locally so the test cannot fail
    # with NameError if the module header does not already import it.
    import os

    td = api.API("APIKEY")
    # "result_size": 22 drives a single Range request (bytes=0-21); the mocked
    # response returns the full gzipped payload regardless of the range.
    body = b"""
    {
        "cpu_time": null,
        "created_at": "2015-02-09 11:44:25 UTC",
        "database": "sample_datasets",
        "debug": {
            "cmdout": "started at 2015-02-09T11:44:27Z\\nexecuting query: SELECT COUNT(1) FROM nasdaq\\n",
            "stderr": null
        },
        "duration": 1,
        "end_at": "2015-02-09 11:44:28 UTC",
        "hive_result_schema": "[[\\"cnt\\", \\"bigint\\"]]",
        "job_id": "12345",
        "organization": null,
        "priority": 1,
        "query": "SELECT COUNT(1) FROM nasdaq",
        "result": "",
        "result_size": 22,
        "retry_limit": 0,
        "start_at": "2015-02-09 11:44:27 UTC",
        "status": "success",
        "type": "presto",
        "updated_at": "2015-02-09 11:44:28 UTC",
        "url": "http://console.example.com/jobs/12345",
        "user_name": "[email protected]",
        "linked_result_export_job_id": null,
        "result_export_target_job_id": null,
        "num_records": 4
    }
    """
    data = [
        {"str": "value1", "int": 1, "float": 2.3},
        {"str": "value3", "int": 4, "float": 5.6},
    ]
    body_download = gzipb(msgpackb(data))
    td.get = mock.MagicMock()
    # First call: job show; second call: the 206 partial-content chunk.
    td.get.side_effect = [make_response(200, body), make_response(206, body_download)]
    with tempfile.TemporaryDirectory() as tempdir:
        temp = os.path.join(tempdir, str(uuid.uuid4()))
        td.download_job_result(12345, temp)
        td.get.assert_any_call("/v3/job/show/12345")
        td.get.assert_any_call(
            "/v3/job/result/12345?format=msgpack.gz",
            headers={"Range": "bytes=0-21"},
        )
        # Round-trip the written file to confirm the payload arrived intact.
        with open(temp, "rb") as f:
            result = msgunpackb(gunzipb(f.read()))
        assert result == data

def test_kill_success():
td = api.API("APIKEY")
# TODO: should be replaced by wire dump
Expand Down
4 changes: 4 additions & 0 deletions tdclient/test/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def read(size=None):
else:
return b""

def stream(size=None):
yield read(size)

response.read.side_effect = read
response.stream.side_effect = stream
return response


Expand Down
Loading