Move auth parameter from extra to Hook parameter (#30212)
For consistency, we are moving the Hive auth parameter from the connection extras to the Hook.
potiuk authored Mar 21, 2023
1 parent 05c0841 commit f011401
Showing 9 changed files with 33 additions and 7 deletions.
10 changes: 10 additions & 0 deletions airflow/providers/apache/hive/CHANGELOG.rst
@@ -24,6 +24,16 @@
 Changelog
 ---------
 
+6.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+The auth option is moved from the connection's extra field to the auth parameter of the Hook. If you
+have auth defined in your connection extras, move it to the DAG where your HiveOperator or other
+Hive-related operators are used.
+
 5.1.3
 .....
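For illustration, a minimal before/after sketch of the migration described in the changelog entry above; the connection id, task id, and the kerberos auth value are hypothetical, not part of this commit:

from airflow.providers.apache.hive.operators.hive import HiveOperator

# Before (provider < 6.0.0): auth was read from the connection's extras,
# e.g. Connection "my_hive_cli" with extra {"use_beeline": true, "auth": "kerberos"}.
run_hql = HiveOperator(
    task_id="run_hql",
    hql="SELECT 1",
    hive_cli_conn_id="my_hive_cli",
)

# After (provider >= 6.0.0): drop "auth" from the extras and pass it where
# the operator is instantiated in the DAG.
run_hql = HiveOperator(
    task_id="run_hql",
    hql="SELECT 1",
    hive_cli_conn_id="my_hive_cli",
    auth="kerberos",
)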
3 changes: 2 additions & 1 deletion airflow/providers/apache/hive/hooks/hive.py
@@ -99,12 +99,13 @@ def __init__(
         mapred_queue_priority: str | None = None,
         mapred_job_name: str | None = None,
         hive_cli_params: str = "",
+        auth: str | None = None,
     ) -> None:
         super().__init__()
         conn = self.get_connection(hive_cli_conn_id)
         self.hive_cli_params: str = hive_cli_params
         self.use_beeline: bool = conn.extra_dejson.get("use_beeline", False)
-        self.auth = conn.extra_dejson.get("auth", "noSasl")
+        self.auth = auth
         self.conn = conn
         self.run_as = run_as
         self.sub_process: Any = None
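A usage sketch of the new Hook signature, assuming a reachable Hive CLI connection; the auth value shown is only an example. Note that the old behaviour of defaulting to "noSasl" from the extras is gone: the parameter now defaults to None.

from airflow.providers.apache.hive.hooks.hive import HiveCliHook

# auth is now an explicit constructor argument instead of being read from
# conn.extra_dejson; omitting it leaves self.auth set to None.
hook = HiveCliHook(hive_cli_conn_id="hive_cli_default", auth="noSasl")
hook.run_cli("SHOW DATABASES;")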
5 changes: 5 additions & 0 deletions airflow/providers/apache/hive/operators/hive.py
@@ -56,6 +56,8 @@ class HiveOperator(BaseOperator):
         Possible settings include: VERY_HIGH, HIGH, NORMAL, LOW, VERY_LOW
     :param mapred_job_name: This name will appear in the jobtracker.
         This can make monitoring easier.
+    :param hive_cli_params: parameters passed to the Hive CLI
+    :param auth: optional authentication option passed for the Hive connection
     """
 
     template_fields: Sequence[str] = (
@@ -88,6 +90,7 @@ def __init__(
         mapred_queue_priority: str | None = None,
         mapred_job_name: str | None = None,
         hive_cli_params: str = "",
+        auth: str | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -104,6 +107,7 @@ def __init__(
         self.mapred_queue_priority = mapred_queue_priority
         self.mapred_job_name = mapred_job_name
         self.hive_cli_params = hive_cli_params
+        self.auth = auth
 
         job_name_template = conf.get_mandatory_value(
             "hive",
@@ -127,6 +131,7 @@ def get_hook(self) -> HiveCliHook:
             mapred_queue_priority=self.mapred_queue_priority,
             mapred_job_name=self.mapred_job_name,
             hive_cli_params=self.hive_cli_params,
+            auth=self.auth,
         )
 
     def prepare_template(self) -> None:
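A sketch of the operator-level parameter inside a DAG; the dag id, table, and auth value are hypothetical. The auth argument is forwarded to HiveCliHook by get_hook(), as shown in the hunk above.

from datetime import datetime

from airflow import DAG
from airflow.providers.apache.hive.operators.hive import HiveOperator

with DAG(dag_id="hive_auth_example", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    count_rows = HiveOperator(
        task_id="count_rows",
        hql="SELECT COUNT(*) FROM my_table",
        hive_cli_conn_id="hive_cli_default",
        auth="kerberos",  # forwarded to HiveCliHook via get_hook()
    )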
1 change: 1 addition & 0 deletions airflow/providers/apache/hive/provider.yaml
@@ -22,6 +22,7 @@ description: |
   `Apache Hive <https://hive.apache.org/>`__
 versions:
+  - 6.0.0
   - 5.1.3
   - 5.1.2
   - 5.1.1
5 changes: 4 additions & 1 deletion airflow/providers/apache/hive/transfers/mssql_to_hive.py
@@ -60,6 +60,7 @@ class MsSqlToHiveOperator(BaseOperator):
     :param mssql_conn_id: source Microsoft SQL Server connection
     :param hive_cli_conn_id: Reference to the
         :ref:`Hive CLI connection id <howto/connection:hive_cli>`.
+    :param hive_auth: optional authentication option passed for the Hive connection
     :param tblproperties: TBLPROPERTIES of the hive table being created
     """

@@ -79,6 +80,7 @@ def __init__(
         delimiter: str = chr(1),
         mssql_conn_id: str = "mssql_default",
         hive_cli_conn_id: str = "hive_cli_default",
+        hive_auth: str | None = None,
         tblproperties: dict | None = None,
         **kwargs,
     ) -> None:
@@ -93,6 +95,7 @@ def __init__(
         self.hive_cli_conn_id = hive_cli_conn_id
         self.partition = partition or {}
         self.tblproperties = tblproperties
+        self.hive_auth = hive_auth
 
     @classmethod
     def type_map(cls, mssql_type: int) -> str:
@@ -119,7 +122,7 @@ def execute(self, context: Context):
             csv_writer.writerows(cursor)
             tmp_file.flush()
 
-        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
+        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id, auth=self.hive_auth)
         self.log.info("Loading file into Hive")
         hive.load_file(
             tmp_file.name,
5 changes: 4 additions & 1 deletion airflow/providers/apache/hive/transfers/mysql_to_hive.py
@@ -65,6 +65,7 @@ class MySqlToHiveOperator(BaseOperator):
     :param mysql_conn_id: source mysql connection
     :param hive_cli_conn_id: Reference to the
         :ref:`Hive CLI connection id <howto/connection:hive_cli>`.
+    :param hive_auth: optional authentication option passed for the Hive connection
     :param tblproperties: TBLPROPERTIES of the hive table being created
     """

@@ -87,6 +88,7 @@ def __init__(
         escapechar: str | None = None,
         mysql_conn_id: str = "mysql_default",
         hive_cli_conn_id: str = "hive_cli_default",
+        hive_auth: str | None = None,
         tblproperties: dict | None = None,
         **kwargs,
     ) -> None:
@@ -104,6 +106,7 @@ def __init__(
         self.hive_cli_conn_id = hive_cli_conn_id
         self.partition = partition or {}
         self.tblproperties = tblproperties
+        self.hive_auth = hive_auth
 
     @classmethod
     def type_map(cls, mysql_type: int) -> str:
@@ -126,7 +129,7 @@ def type_map(cls, mysql_type: int) -> str:
         return type_map.get(mysql_type, "STRING")
 
     def execute(self, context: Context):
-        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
+        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id, auth=self.hive_auth)
         mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
 
         self.log.info("Dumping MySQL query results to local file")
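The transfer operators expose the same option under the name hive_auth and hand it to HiveCliHook in execute(). A sketch for the MySQL variant, with hypothetical table and task names; the MsSql, S3, and Vertica variants in this commit follow the same pattern:

from airflow.providers.apache.hive.transfers.mysql_to_hive import MySqlToHiveOperator

load_orders = MySqlToHiveOperator(
    task_id="mysql_to_hive",
    sql="SELECT * FROM orders",
    hive_table="staging.orders",
    mysql_conn_id="mysql_default",
    hive_cli_conn_id="hive_cli_default",
    hive_auth="kerberos",  # passed through as HiveCliHook(auth=...)
)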
4 changes: 3 additions & 1 deletion airflow/providers/apache/hive/transfers/s3_to_hive.py
@@ -109,6 +109,7 @@ def __init__(
         input_compressed: bool = False,
         tblproperties: dict | None = None,
         select_expression: str | None = None,
+        hive_auth: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -128,14 +129,15 @@ def __init__(
         self.input_compressed = input_compressed
         self.tblproperties = tblproperties
         self.select_expression = select_expression
+        self.hive_auth = hive_auth
 
         if self.check_headers and not (self.field_dict is not None and self.headers):
             raise AirflowException("To check_headers provide field_dict and headers")
 
     def execute(self, context: Context):
         # Downloading file from S3
         s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-        hive_hook = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
+        hive_hook = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id, auth=self.hive_auth)
         self.log.info("Downloading S3 file")
 
         if self.wildcard_match:
5 changes: 4 additions & 1 deletion airflow/providers/apache/hive/transfers/vertica_to_hive.py
@@ -58,6 +58,7 @@ class VerticaToHiveOperator(BaseOperator):
     :param vertica_conn_id: source Vertica connection
     :param hive_cli_conn_id: Reference to the
         :ref:`Hive CLI connection id <howto/connection:hive_cli>`.
+    :param hive_auth: optional authentication option passed for the Hive connection
     """
 
     template_fields: Sequence[str] = ("sql", "partition", "hive_table")
@@ -76,6 +77,7 @@ def __init__(
         delimiter: str = chr(1),
         vertica_conn_id: str = "vertica_default",
         hive_cli_conn_id: str = "hive_cli_default",
+        hive_auth: str | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -88,6 +90,7 @@ def __init__(
         self.vertica_conn_id = vertica_conn_id
         self.hive_cli_conn_id = hive_cli_conn_id
         self.partition = partition or {}
+        self.hive_auth = hive_auth
 
     @classmethod
     def type_map(cls, vertica_type):
@@ -107,7 +110,7 @@ def type_map(cls, vertica_type):
         return type_map.get(vertica_type, "STRING")
 
     def execute(self, context: Context):
-        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
+        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id, auth=self.hive_auth)
         vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
 
         self.log.info("Dumping Vertica query results to local file")
2 changes: 0 additions & 2 deletions
@@ -70,8 +70,6 @@ Extra (optional)
 
 * ``use_beeline``
     Specify as ``True`` if using the Beeline CLI. Default is ``False``.
-* ``auth``
-    Specify the auth type for use with Hive Beeline CLI.
 * ``proxy_user``
     Specify a proxy user as an ``owner`` or ``login`` or keep blank if using a
     custom proxy user.
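With auth removed from the documented extras, a Hive CLI connection carries only the remaining options. A hypothetical definition for comparison; the host and extras values are illustrative, not taken from this commit:

from airflow.models.connection import Connection

# "auth" no longer belongs in extra; it is supplied to the Hook or operator instead.
conn = Connection(
    conn_id="hive_cli_default",
    conn_type="hive_cli",
    host="hiveserver2.example.com",
    port=10000,
    extra='{"use_beeline": true, "proxy_user": "login"}',
)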
