Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature qubole hook support headers #15683

Merged
merged 13 commits into from
May 11, 2021
3 changes: 2 additions & 1 deletion airflow/providers/qubole/hooks/qubole.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,9 @@ def get_results(
cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id)
self.cmd = self.cls.find(cmd_id)

include_headers_str = 'true' if include_headers else 'false'
self.cmd.get_results(
fp, inline, delim, fetch, arguments=[include_headers]
fp, inline, delim, fetch, arguments=[include_headers_str]
) # type: ignore[attr-defined]
fp.flush()
fp.close()
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/qubole/hooks/qubole_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_query_results(self) -> Optional[str]:
cmd_id = self.cmd.id
self.log.info("command id: %d", cmd_id)
query_result_buffer = StringIO()
self.cmd.get_results(fp=query_result_buffer, inline=True, delim=COL_DELIM, arguments=[True])
self.cmd.get_results(fp=query_result_buffer, inline=True, delim=COL_DELIM, arguments=['true'])
query_result = query_result_buffer.getvalue()
query_result_buffer.close()
return query_result
Expand Down
44 changes: 42 additions & 2 deletions tests/providers/qubole/hooks/test_qubole.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,29 @@
# specific language governing permissions and limitations
# under the License.
#
import unittest
from unittest import TestCase, mock

from qds_sdk.commands import PrestoCommand

from airflow.providers.qubole.hooks.qubole import QuboleHook

DAG_ID = "qubole_test_dag"
TASK_ID = "test_task"
RESULTS_WITH_HEADER = 'header1\theader2\nval1\tval2'
RESULTS_WITH_NO_HEADER = 'val1\tval2'

add_tags = QuboleHook._add_tags


class TestQuboleHook(unittest.TestCase):
# pylint: disable = unused-argument
def get_result_mock(fp, inline, delim, fetch, arguments):
    """Write canned query results to ``fp`` in place of a real results fetch.

    Only ``fp`` and ``arguments`` are used; the remaining parameters exist so
    the signature accepts the same call the code under test makes.

    :param fp: writable binary file-like object that receives the results
    :param arguments: sequence whose first item is the include-headers flag;
        the string ``'true'`` selects the payload that contains a header row
    """
    wants_headers = arguments[0] == 'true'
    payload = RESULTS_WITH_HEADER if wants_headers else RESULTS_WITH_NO_HEADER
    fp.write(bytearray(payload, 'utf-8'))


class TestQuboleHook(TestCase):
def test_add_string_to_tags(self):
tags = {'dag_id', 'task_id'}
add_tags(tags, 'string')
Expand All @@ -38,3 +53,28 @@ def test_add_tuple_to_tags(self):
tags = {'dag_id', 'task_id'}
add_tags(tags, ('value1', 'value2'))
assert {'dag_id', 'task_id', 'value1', 'value2'} == tags

@mock.patch('qds_sdk.commands.Command.get_results', new=get_result_mock)
def test_get_results_with_headers(self):
    """``QuboleHook.get_results(include_headers=True)`` yields the header row."""
    dag = mock.MagicMock()
    dag.dag_id = DAG_ID
    hook = QuboleHook(task_id=TASK_ID, command_type='prestocmd', dag=dag)

    task = mock.MagicMock()
    task.xcom_pull.return_value = 'test_command_id'
    with mock.patch('qds_sdk.resource.Resource.find', return_value=PrestoCommand):
        results_path = hook.get_results(ti=task, include_headers=True)

    # Read through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(results_path) as results_file:
        results = results_file.read()
    assert results == RESULTS_WITH_HEADER

@mock.patch('qds_sdk.commands.Command.get_results', new=get_result_mock)
def test_get_results_without_headers(self):
    """``QuboleHook.get_results(include_headers=False)`` omits the header row."""
    dag = mock.MagicMock()
    dag.dag_id = DAG_ID
    hook = QuboleHook(task_id=TASK_ID, command_type='prestocmd', dag=dag)

    task = mock.MagicMock()
    task.xcom_pull.return_value = 'test_command_id'

    with mock.patch('qds_sdk.resource.Resource.find', return_value=PrestoCommand):
        results_path = hook.get_results(ti=task, include_headers=False)

    # Read through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(results_path) as results_file:
        results = results_file.read()
    assert results == RESULTS_WITH_NO_HEADER
18 changes: 16 additions & 2 deletions tests/providers/qubole/operators/test_qubole.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
#

import unittest
from unittest import TestCase, mock

from airflow import settings
from airflow.models import DAG, Connection
Expand All @@ -36,7 +36,7 @@
DEFAULT_DATE = datetime(2017, 1, 1)


class TestQuboleOperator(unittest.TestCase):
class TestQuboleOperator(TestCase):
def setUp(self):
db.merge_conn(Connection(conn_id=DEFAULT_CONN, conn_type='HTTP'))
db.merge_conn(Connection(conn_id=TEST_CONN, conn_type='HTTP', host='http://localhost/api'))
Expand Down Expand Up @@ -180,3 +180,17 @@ def test_parameter_pool_passed(self):
test_pool = 'test_pool'
op = QuboleOperator(task_id=TASK_ID, pool=test_pool)
assert op.pool == test_pool

@mock.patch('airflow.providers.qubole.hooks.qubole.QuboleHook.get_results')
def test_parameter_include_header_passed(self, mock_get_results):
    """``include_headers=True`` given to the operator reaches the hook."""
    dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
    qubole_operator = QuboleOperator(task_id=TASK_ID, dag=dag, command_type='prestocmd')
    qubole_operator.get_results(include_headers=True)

    # The previous ``asset_called_with`` was a typo: a MagicMock silently
    # creates that attribute on access, so the test asserted nothing.
    mock_get_results.assert_called_once()
    args, kwargs = mock_get_results.call_args
    # NOTE(review): assumes the operator forwards the flag either by keyword
    # or as the last positional argument -- confirm against the operator.
    if 'include_headers' in kwargs:
        forwarded = kwargs['include_headers']
    else:
        forwarded = args[-1] if args else None
    assert forwarded is True

@mock.patch('airflow.providers.qubole.hooks.qubole.QuboleHook.get_results')
def test_parameter_include_header_missing(self, mock_get_results):
    """Omitting ``include_headers`` on the operator forwards ``False`` to the hook."""
    dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
    qubole_operator = QuboleOperator(task_id=TASK_ID, dag=dag, command_type='prestocmd')
    qubole_operator.get_results()

    # The previous ``asset_called_with`` was a typo: a MagicMock silently
    # creates that attribute on access, so the test asserted nothing.
    mock_get_results.assert_called_once()
    args, kwargs = mock_get_results.call_args
    # NOTE(review): assumes the operator forwards the flag either by keyword
    # or as the last positional argument -- confirm against the operator.
    if 'include_headers' in kwargs:
        forwarded = kwargs['include_headers']
    else:
        forwarded = args[-1] if args else None
    assert forwarded is False