diff --git a/airflow/providers/http/CHANGELOG.rst b/airflow/providers/http/CHANGELOG.rst
index 81f9ca1ad3a6e..a784cb73ed74c 100644
--- a/airflow/providers/http/CHANGELOG.rst
+++ b/airflow/providers/http/CHANGELOG.rst
@@ -24,6 +24,16 @@ Changelog
 ---------
 
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+The ``SimpleHttpOperator``, ``HttpSensor`` and ``HttpHook`` now use TCP keepalive by default.
+You can disable it by setting ``tcp_keep_alive`` to False, and you can tune the keepalive behavior
+via the new ``tcp_keep_alive_*`` parameters added to the constructors of the hook, operator and sensor.
+With TCP keepalive enabled, empty TCP packets are sent periodically during long periods of inactivity,
+which prevents some firewalls from closing long-running connections that would otherwise appear idle.
+This adds only negligible network traffic and keeps idle connections from being dropped by such firewalls.
+
 3.0.0
 .....
 
diff --git a/airflow/providers/http/hooks/http.py b/airflow/providers/http/hooks/http.py
index 98a36ec747073..373cd866c02ad 100644
--- a/airflow/providers/http/hooks/http.py
+++ b/airflow/providers/http/hooks/http.py
@@ -20,6 +20,7 @@
 import requests
 import tenacity
 from requests.auth import HTTPBasicAuth
+from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
@@ -34,6 +35,11 @@ class HttpHook(BaseHook):
         API url i.e https://www.google.com/ and optional authentication credentials. Default
         headers can also be specified in the Extra field in json format.
     :param auth_type: The auth type for the service
+    :param tcp_keep_alive: Enable TCP Keep Alive for the connection.
+    :param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``).
+    :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
+    :param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
+        ``socket.TCP_KEEPINTVL``)
     """
 
     conn_name_attr = 'http_conn_id'
@@ -46,6 +52,10 @@ def __init__(
         method: str = 'POST',
         http_conn_id: str = default_conn_name,
         auth_type: Any = HTTPBasicAuth,
+        tcp_keep_alive: bool = True,
+        tcp_keep_alive_idle: int = 120,
+        tcp_keep_alive_count: int = 20,
+        tcp_keep_alive_interval: int = 30,
     ) -> None:
         super().__init__()
         self.http_conn_id = http_conn_id
@@ -53,6 +63,10 @@ def __init__(
         self.base_url: str = ""
         self._retry_obj: Callable[..., Any]
         self.auth_type: Any = auth_type
+        self.tcp_keep_alive = tcp_keep_alive
+        self.keep_alive_idle = tcp_keep_alive_idle
+        self.keep_alive_count = tcp_keep_alive_count
+        self.keep_alive_interval = tcp_keep_alive_interval
 
     # headers may be passed through directly or in the "extra" field in the connection
     # definition
@@ -115,6 +129,11 @@ def run(
 
         url = self.url_from_endpoint(endpoint)
 
+        if self.tcp_keep_alive:
+            keep_alive_adapter = TCPKeepAliveAdapter(
+                idle=self.keep_alive_idle, count=self.keep_alive_count, interval=self.keep_alive_interval
+            )
+            session.mount(url, keep_alive_adapter)
         if self.method == 'GET':
             # GET uses params
             req = requests.Request(self.method, url, params=data, headers=headers, **request_kwargs)
diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py
index a304c1d79faeb..425f2ca602def 100644
--- a/airflow/providers/http/operators/http.py
+++ b/airflow/providers/http/operators/http.py
@@ -54,6 +54,11 @@ class SimpleHttpOperator(BaseOperator):
         'requests' documentation (options to modify timeout, ssl, etc.)
     :param log_response: Log the response (default: False)
     :param auth_type: The auth type for the service
+    :param tcp_keep_alive: Enable TCP Keep Alive for the connection.
+    :param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``).
+    :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
+    :param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
+        ``socket.TCP_KEEPINTVL``)
     """
 
     template_fields: Sequence[str] = (
@@ -78,6 +83,10 @@ def __init__(
         http_conn_id: str = 'http_default',
         log_response: bool = False,
         auth_type: Type[AuthBase] = HTTPBasicAuth,
+        tcp_keep_alive: bool = True,
+        tcp_keep_alive_idle: int = 120,
+        tcp_keep_alive_count: int = 20,
+        tcp_keep_alive_interval: int = 30,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -91,11 +100,23 @@ def __init__(
         self.extra_options = extra_options or {}
         self.log_response = log_response
         self.auth_type = auth_type
+        self.tcp_keep_alive = tcp_keep_alive
+        self.tcp_keep_alive_idle = tcp_keep_alive_idle
+        self.tcp_keep_alive_count = tcp_keep_alive_count
+        self.tcp_keep_alive_interval = tcp_keep_alive_interval
 
     def execute(self, context: 'Context') -> Any:
         from airflow.utils.operator_helpers import determine_kwargs
 
-        http = HttpHook(self.method, http_conn_id=self.http_conn_id, auth_type=self.auth_type)
+        http = HttpHook(
+            self.method,
+            http_conn_id=self.http_conn_id,
+            auth_type=self.auth_type,
+            tcp_keep_alive=self.tcp_keep_alive,
+            tcp_keep_alive_idle=self.tcp_keep_alive_idle,
+            tcp_keep_alive_count=self.tcp_keep_alive_count,
+            tcp_keep_alive_interval=self.tcp_keep_alive_interval,
+        )
 
         self.log.info("Calling HTTP method")
 
diff --git a/airflow/providers/http/provider.yaml b/airflow/providers/http/provider.yaml
index 4437744507fd8..bf1b8a5b18c87 100644
--- a/airflow/providers/http/provider.yaml
+++ b/airflow/providers/http/provider.yaml
@@ -38,6 +38,7 @@ dependencies:
   # The 2.26.0 release of requests got rid of the chardet LGPL mandatory dependency, allowing us to
   # release it as a requirement for airflow
   - requests>=2.26.0
+  - requests_toolbelt
 
 integrations:
   - integration-name: Hypertext Transfer Protocol (HTTP)
diff --git a/airflow/providers/http/sensors/http.py b/airflow/providers/http/sensors/http.py
index 75b4ccd3c225f..f2fd4add06df9 100644
--- a/airflow/providers/http/sensors/http.py
+++ b/airflow/providers/http/sensors/http.py
@@ -67,6 +67,11 @@ def response_check(response, task_instance):
         It should return True for 'pass' and False otherwise.
     :param extra_options: Extra options for the 'requests' library, see the
         'requests' documentation (options to modify timeout, ssl, etc.)
+    :param tcp_keep_alive: Enable TCP Keep Alive for the connection.
+    :param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``).
+    :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
+    :param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
+        ``socket.TCP_KEEPINTVL``)
     """
 
     template_fields: Sequence[str] = ('endpoint', 'request_params', 'headers')
@@ -81,6 +86,10 @@ def __init__(
         headers: Optional[Dict[str, Any]] = None,
         response_check: Optional[Callable[..., bool]] = None,
         extra_options: Optional[Dict[str, Any]] = None,
+        tcp_keep_alive: bool = True,
+        tcp_keep_alive_idle: int = 120,
+        tcp_keep_alive_count: int = 20,
+        tcp_keep_alive_interval: int = 30,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -91,15 +100,26 @@ def __init__(
         self.headers = headers or {}
         self.extra_options = extra_options or {}
         self.response_check = response_check
-
-        self.hook = HttpHook(method=method, http_conn_id=http_conn_id)
+        self.tcp_keep_alive = tcp_keep_alive
+        self.tcp_keep_alive_idle = tcp_keep_alive_idle
+        self.tcp_keep_alive_count = tcp_keep_alive_count
+        self.tcp_keep_alive_interval = tcp_keep_alive_interval
 
     def poke(self, context: 'Context') -> bool:
         from airflow.utils.operator_helpers import determine_kwargs
 
+        hook = HttpHook(
+            method=self.method,
+            http_conn_id=self.http_conn_id,
+            tcp_keep_alive=self.tcp_keep_alive,
+            tcp_keep_alive_idle=self.tcp_keep_alive_idle,
+            tcp_keep_alive_count=self.tcp_keep_alive_count,
+            tcp_keep_alive_interval=self.tcp_keep_alive_interval,
+        )
+
         self.log.info('Poking: %s', self.endpoint)
         try:
-            response = self.hook.run(
+            response = hook.run(
                 self.endpoint,
                 data=self.request_params,
                 headers=self.headers,
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 70ad0658fbf34..203fa5e22aff1 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -376,7 +376,8 @@
   },
   "http": {
     "deps": [
-      "requests>=2.26.0"
+      "requests>=2.26.0",
+      "requests_toolbelt"
     ],
     "cross-providers-deps": []
   },
diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py
index 69855f21a001b..0fdc8dd2aba84 100644
--- a/tests/providers/http/hooks/test_http.py
+++ b/tests/providers/http/hooks/test_http.py
@@ -19,6 +19,7 @@
 import os
 import unittest
 from collections import OrderedDict
+from http import HTTPStatus
 from unittest import mock
 
 import pytest
@@ -26,6 +27,7 @@
 import requests_mock
 import tenacity
 from parameterized import parameterized
+from requests.adapters import Response
 
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
@@ -370,4 +372,40 @@ def test_connection_failure(self, m):
         assert msg == '500:NOT_OK'
 
 
+class TestKeepAlive:
+    def test_keep_alive_enabled(self):
+        with mock.patch(
+            'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port
+        ), mock.patch(
+            'requests_toolbelt.adapters.socket_options.TCPKeepAliveAdapter.send'
+        ) as tcp_keep_alive_send, mock.patch(
+            'requests.adapters.HTTPAdapter.send'
+        ) as http_send:
+            hook = HttpHook(method='GET')
+            response = Response()
+            response.status_code = HTTPStatus.OK
+            tcp_keep_alive_send.return_value = response
+            http_send.return_value = response
+            hook.run('v1/test')
+            tcp_keep_alive_send.assert_called()
+            http_send.assert_not_called()
+
+    def test_keep_alive_disabled(self):
+        with mock.patch(
+            'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port
+        ), mock.patch(
+            'requests_toolbelt.adapters.socket_options.TCPKeepAliveAdapter.send'
+        ) as tcp_keep_alive_send, mock.patch(
+            'requests.adapters.HTTPAdapter.send'
+        ) as http_send:
+            hook = HttpHook(method='GET', tcp_keep_alive=False)
+            response = Response()
+            response.status_code = HTTPStatus.OK
+            tcp_keep_alive_send.return_value = response
+            http_send.return_value = response
+            hook.run('v1/test')
+            tcp_keep_alive_send.assert_not_called()
+            http_send.assert_called()
+
+
 send_email_test = mock.Mock()
diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py
index 2c9a2bbb7f5c8..d0dbaa24215f2 100644
--- a/tests/providers/http/sensors/test_http.py
+++ b/tests/providers/http/sensors/test_http.py
@@ -169,18 +169,18 @@ def resp_check(_):
             poke_interval=1,
         )
 
-        with mock.patch.object(task.hook.log, 'error') as mock_errors:
+        with mock.patch('airflow.providers.http.hooks.http.HttpHook.log') as mock_log:
             with pytest.raises(AirflowSensorTimeout):
                 task.execute(None)
 
-            assert mock_errors.called
+            assert mock_log.error.called
             calls = [
                 mock.call('HTTP error: %s', 'Not Found'),
                 mock.call("This endpoint doesn't exist"),
                 mock.call('HTTP error: %s', 'Not Found'),
                 mock.call("This endpoint doesn't exist"),
             ]
-            mock_errors.assert_has_calls(calls)
+            mock_log.error.assert_has_calls(calls)
 
 
 class FakeSession:
@@ -200,6 +200,9 @@ def prepare_request(self, request):
     def merge_environment_settings(self, _url, **kwargs):
         return kwargs
 
+    def mount(self, prefix, adapter):
+        pass
+
 
 class TestHttpOpSensor(unittest.TestCase):
     def setUp(self):
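
For reviewers: a minimal sketch of how a DAG author could use the new parameters once this lands. The DAG id, connection id and endpoint paths below are placeholders; only the `tcp_keep_alive*` arguments come from this patch, and their defaults (idle=120, count=20, interval=30) match the constructors above.

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.http.operators.http import SimpleHttpOperator
from airflow.providers.http.sensors.http import HttpSensor

with DAG(
    dag_id="http_keep_alive_example",  # placeholder DAG id
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
) as dag:
    # Keepalive is now on by default; only the probe timings are tuned here.
    call_api = SimpleHttpOperator(
        task_id="call_api",
        http_conn_id="http_default",
        method="GET",
        endpoint="v1/long-running-report",  # placeholder endpoint
        tcp_keep_alive_idle=60,  # start probing after 60s of inactivity
        tcp_keep_alive_interval=20,  # send a probe every 20s
        tcp_keep_alive_count=10,  # give up after 10 unanswered probes
    )

    # Opt out entirely to restore the pre-patch (no keepalive) behavior.
    wait_for_api = HttpSensor(
        task_id="wait_for_api",
        http_conn_id="http_default",
        endpoint="v1/health",  # placeholder endpoint
        tcp_keep_alive=False,
        poke_interval=30,
    )

    wait_for_api >> call_api
```

One side effect worth noting: the sensor now builds its hook inside `poke()` instead of exposing `self.hook` from `__init__`, so code that previously reached into `task.hook` (as the old sensor test did) will need to construct its own `HttpHook`; that may deserve a mention in the changelog as well.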
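For context, here is a rough sketch of what `HttpHook.run()` now does at the `requests` level when `tcp_keep_alive` is left enabled; the URL is a placeholder and the timings mirror the hook's defaults.

```python
import requests
from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter

session = requests.Session()

# Mirror the hook's defaults: start probing after 120s idle, probe every 30s,
# and drop the connection after 20 unanswered probes.
keep_alive_adapter = TCPKeepAliveAdapter(idle=120, count=20, interval=30)

# The hook mounts the adapter for the resolved request URL, so only requests
# to that prefix go through the keepalive-enabled adapter.
session.mount("https://example.com/api/", keep_alive_adapter)

response = session.get("https://example.com/api/v1/test")
print(response.status_code)
```

Because the adapter is mounted per URL rather than for all of `https://`, the change stays scoped to the configured connection, which is also why the unit tests above can assert that only `TCPKeepAliveAdapter.send` is hit when keepalive is enabled.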