Support for Hive/Spark on Python 3.12 #148

Merged
87 changes: 14 additions & 73 deletions apollo/integrations/db/hive_proxy_client.py
@@ -1,101 +1,42 @@
import time
from base64 import standard_b64encode
from typing import (
Any,
Dict,
Optional,
)

from pyhive import hive
from thrift.transport import THttpClient
from TCLIService.ttypes import TOperationState
from impala import dbapi

from apollo.agent.models import AgentError
from apollo.integrations.db.base_db_proxy_client import BaseDbProxyClient

_ATTR_CONNECT_ARGS = "connect_args"


class HiveProxyCursor(hive.Cursor):
def async_execute(self, query: str, timeout: int, **kwargs: Any) -> None: # noqa
start_time = time.time()

self.execute(query, async_=True)

pending_states = (
TOperationState.INITIALIZED_STATE,
TOperationState.PENDING_STATE,
TOperationState.RUNNING_STATE,
)
time_passed = 0
while self.poll().operationState in pending_states:
time_passed = time.time() - start_time
if time_passed > timeout:
self.cancel()
break
time.sleep(10)

resp = self.poll()
if resp.operationState == TOperationState.ERROR_STATE:
msg = "Query failed, see cluster logs for details"
if time_passed > 0:
msg += f" (runtime: {time_passed}s)"
raise AgentError(msg, query, resp)
elif resp.operationState == TOperationState.CANCELED_STATE:
raise AgentError(f"Time out executing query: {time_passed}s", query, resp)


class HiveProxyConnection(hive.Connection):
def cursor(self, *args: Any, **kwargs: Any):
return HiveProxyCursor(self, *args, **kwargs)


class HiveProxyClient(BaseDbProxyClient):
"""
Proxy client for Hive. Credentials are expected to be supplied under "connect_args" and
will be passed directly to `hive.Connection`, so only attributes supported as parameters by
`hive.Connection` should be passed. If "mode" is not set to "binary", then the "connect_args"
will be used to create a new thrift transport that will be passed to `hive.Connection`.
`hive.Connection` should be passed.
"""

_MODE_BINARY = "binary"

def __init__(self, credentials: Optional[Dict], **kwargs: Any): # noqa
super().__init__(connection_type="hive")
if not credentials or _ATTR_CONNECT_ARGS not in credentials:
raise ValueError(
f"Hive agent client requires {_ATTR_CONNECT_ARGS} in credentials"
)

if credentials.get("mode") != self._MODE_BINARY:
self._connection = self._create_http_connection(
**credentials[_ATTR_CONNECT_ARGS]
)
else:
self._connection = HiveProxyConnection(**credentials[_ATTR_CONNECT_ARGS])

@classmethod
def _create_http_connection(
cls,
url: str,
username: str,
password: str,
user_agent: Optional[str] = None,
**kwargs: Any, # noqa
) -> hive.Connection:
transport = THttpClient.THttpClient(url)

auth = standard_b64encode(f"{username}:{password}".encode()).decode()
headers = dict(Authorization=f"Basic {auth}")
if user_agent:
headers["User-Agent"] = user_agent

transport.setCustomHeaders(headers)

try:
return HiveProxyConnection(thrift_transport=transport)
except EOFError or MemoryError:
raise AgentError("Error creating connection - credentials might be invalid")
self._connection = dbapi.connect(**credentials[_ATTR_CONNECT_ARGS])

def cursor(self):
# If close_finished_queries is true, impala will close every query once a DDL/DML query execution is finished
# or all rows are fetched. It will also call GetLog() before closing the query to get query metadata from Hive.
# GetLog() is not available for spark databricks causing this to break.
#
# Setting close_finished_queries to false will only close queries when execute() is called again
# or the cursor is closed. GetLog() is not automatically called so spark databricks works.
# With False the cursor will not have a rowcount for DML statements, this is fine for MC.
# https://github.com/cloudera/impyla/blob/e4c76169f7e5765c09b11c92fceb862dbb9b72be/impala/hiveserver2.py#L122
return self._connection.cursor(close_finished_queries=False)

@property
def wrapped_client(self):
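
Not part of the diff, but for context: a minimal sketch of what the new impyla-backed path boils down to. The connect_args values are placeholders modeled on the test fixture further down; in practice they come from the agent's "connect_args" credentials.

```python
# Minimal sketch (not part of this PR) of the new impyla-backed connection path.
# The connect_args values below are placeholders modeled on the test fixture.
from impala import dbapi

connect_args = {
    "host": "localhost",       # placeholder HiveServer2 / Spark Thrift endpoint
    "port": 10000,
    "user": "foo",
    "database": "fizz",
    "auth_mechanism": "PLAIN",
    "timeout": 870,
    "use_ssl": False,
}

connection = dbapi.connect(**connect_args)

# close_finished_queries=False keeps impyla from calling GetLog() once a query
# finishes (GetLog() is not implemented by Spark/Databricks); finished queries
# are instead closed on the next execute() or when the cursor is closed.
cursor = connection.cursor(close_finished_queries=False)
cursor.execute("SELECT 1")
print(cursor.fetchall())
cursor.close()
connection.close()
```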
8 changes: 3 additions & 5 deletions requirements.in
@@ -39,9 +39,7 @@ tableauserverclient @ git+https://github.com/tableau/server-client-python.git@ma
teradatasql>=20.0.0.15
oscrypto @ git+https://github.com/wbond/oscrypto@master

# Note: 'pyhive[hive]' extras uses sasl that does not support Python 3.11,
# See https://github.com/cloudera/python-sasl/issues/30. Hence PyHive also supports
# pure-sasl via additional extras 'pyhive[hive_pure_sasl]' which support Python 3.11.
pyhive[hive_pure_sasl]==0.7.0 ; python_version >= "3.11"
pyhive[hive]==0.6.5 ; python_version < "3.11"
# Note this is a beta version of impyla that is needed in order to support HTTPS connections on python 3.12.
# It should be updated to stable version 0.20.0 once that is released.
impyla==0.20a1
werkzeug==3.0.3
20 changes: 9 additions & 11 deletions requirements.txt
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.11
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile requirements.in
@@ -29,6 +29,8 @@ azure-mgmt-storage==21.2.1
# via -r requirements.in
azure-storage-blob==12.23.0
# via -r requirements.in
bitarray==3.0.0
# via impyla
blinker==1.8.2
# via flask
boto3==1.34.151
@@ -89,8 +91,6 @@ flask==2.3.3
# flask-compress
flask-compress==1.14
# via -r requirements.in
future==1.0.0
# via pyhive
google-api-core==2.19.1
# via
# google-api-python-client
@@ -129,6 +129,8 @@ idna==3.7
# via
# requests
# snowflake-connector-python
impyla==0.20a1
# via -r requirements.in
isodate==0.6.1
# via
# azure-mgmt-storage
@@ -206,9 +208,7 @@ protobuf==4.25.3
psycopg2-binary==2.9.9
# via -r requirements.in
pure-sasl==0.6.2
# via
# pyhive
# thrift-sasl
# via thrift-sasl
pyarrow==14.0.1
# via
# -r requirements.in
@@ -225,8 +225,6 @@ pycryptodome==3.20.0
# via
# -r requirements.in
# teradatasql
pyhive[hive-pure-sasl]==0.7.0 ; python_version >= "3.11"
# via -r requirements.in
pyjwt[crypto]==2.8.0
# via
# -r requirements.in
@@ -246,7 +244,6 @@ python-dateutil==2.9.0.post0
# via
# botocore
# pandas
# pyhive
pytz==2024.1
# via
# pandas
@@ -274,6 +271,7 @@ s3transfer==0.10.2
six==1.16.0
# via
# azure-core
# impyla
# isodate
# presto-python-client
# python-dateutil
@@ -294,10 +292,10 @@ teradatasql==20.0.0.15
thrift==0.16.0
# via
# databricks-sql-connector
# pyhive
# impyla
# thrift-sasl
thrift-sasl==0.4.3
# via pyhive
# via impyla
tomlkit==0.12.5
# via snowflake-connector-python
typing-extensions==4.12.2
23 changes: 9 additions & 14 deletions tests/test_hive_client.py
@@ -5,11 +5,7 @@
Optional,
)
from unittest import TestCase
from unittest.mock import (
Mock,
call,
patch,
)
from unittest.mock import Mock, call, patch

from apollo.agent.agent import Agent
from apollo.agent.constants import (
@@ -22,9 +18,11 @@
_HIVE_CREDENTIALS = {
"host": "localhost",
"port": "10000",
"username": "foo",
"user": "foo",
"database": "fizz",
"auth": None,
"auth_mechanism": "PLAIN",
"timeout": 870,
"use_ssl": False,
}


@@ -35,7 +33,7 @@ def setUp(self) -> None:
self._mock_cursor = Mock()
self._mock_connection.cursor.return_value = self._mock_cursor

@patch("apollo.integrations.db.hive_proxy_client.HiveProxyConnection")
@patch("apollo.integrations.db.hive_proxy_client.dbapi.connect")
def test_query(self, mock_connect):
query = "SELECT idx, value FROM table" # noqa
expected_data = [
@@ -70,7 +68,7 @@ def _test_run_query(
{"method": "cursor", "store": "_cursor"},
{
"target": "_cursor",
"method": "async_execute",
"method": "execute",
"args": [
query,
None,
@@ -104,10 +102,7 @@ def _test_run_query(
"hive",
"run_query",
operation_dict,
{
"connect_args": _HIVE_CREDENTIALS,
"mode": "binary",
},
{"connect_args": _HIVE_CREDENTIALS},
)

if raise_exception:
Expand All @@ -124,7 +119,7 @@ def _test_run_query(
result = response.result.get(ATTRIBUTE_NAME_RESULT)

mock_connect.assert_called_with(**_HIVE_CREDENTIALS)
self._mock_cursor.async_execute.assert_has_calls(
self._mock_cursor.execute.assert_has_calls(
[
call(query, None),
]
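
For reference, a condensed sketch of the mocking pattern the updated test relies on. It assumes this repo's module layout (it imports HiveProxyClient), and the trimmed connect_args are placeholders rather than the full fixture.

```python
# Condensed sketch of the test's mocking pattern (assumes this repo's module
# layout); dbapi.connect is patched so no real cluster is contacted.
from unittest.mock import Mock, patch

from apollo.integrations.db.hive_proxy_client import HiveProxyClient

with patch("apollo.integrations.db.hive_proxy_client.dbapi.connect") as mock_connect:
    mock_cursor = Mock()
    mock_connect.return_value.cursor.return_value = mock_cursor

    client = HiveProxyClient(credentials={"connect_args": {"host": "localhost"}})
    cursor = client.cursor()
    cursor.execute("SELECT idx, value FROM table", None)

    mock_connect.assert_called_once_with(host="localhost")
    mock_cursor.execute.assert_called_once_with("SELECT idx, value FROM table", None)
```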