# Methods of ``SimpleHttpOperator`` touched by this patch. They are written at module
# indentation because the enclosing ``class`` statement is not part of the patch text.


def execute(self, context: Context) -> Any:
    """
    Send the configured HTTP request, either directly or through a trigger.

    When ``self.deferrable`` is set, the task is deferred to an ``HttpTrigger`` and resumes in
    ``execute_complete``. Otherwise the request is sent with ``HttpHook`` and the response is
    handed to ``process_response``.

    :param context: task context, forwarded to ``process_response``.
    :return: the value produced by ``process_response`` on the non-deferred path.
    """
    if self.deferrable:
        # The tcp_keep_alive* settings are not forwarded here: only the connection id, auth
        # type, method, endpoint, headers, data and extra options are passed to the trigger.
        self.defer(
            trigger=HttpTrigger(
                http_conn_id=self.http_conn_id,
                auth_type=self.auth_type,
                method=self.method,
                endpoint=self.endpoint,
                headers=self.headers,
                data=self.data,
                extra_options=self.extra_options,
            ),
            method_name="execute_complete",
        )
        return None

    http = HttpHook(
        self.method,
        http_conn_id=self.http_conn_id,
        auth_type=self.auth_type,
        tcp_keep_alive=self.tcp_keep_alive,
        tcp_keep_alive_idle=self.tcp_keep_alive_idle,
        tcp_keep_alive_count=self.tcp_keep_alive_count,
        tcp_keep_alive_interval=self.tcp_keep_alive_interval,
    )

    self.log.info("Calling HTTP method")

    response = http.run(self.endpoint, self.data, self.headers, self.extra_options)
    return self.process_response(context=context, response=response)


def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
    """
    Resume the task from the event emitted by the trigger.

    :param context: task context, forwarded to ``process_response``.
    :param event: trigger event payload. A ``"status"`` of ``"success"`` must come with a
        ``"response"`` entry holding the base64 text of a pickled response object. Any other
        (or missing) status is treated as a failure and ``"message"`` is reported if present.
    :return: the value produced by ``process_response`` for the decoded response.
    :raises AirflowException: if the event does not report success.
    """
    # Imported locally so this method carries its own standard-library dependencies.
    import base64
    import pickle

    if event.get("status") != "success":
        # Fall back to the whole event so a malformed payload still yields a readable error
        # instead of a KeyError that hides the real failure.
        detail = event.get("message", event)
        raise AirflowException(f"Unexpected error in the operation: {detail}")

    # SECURITY: unpickling runs arbitrary code if the payload can be tampered with.
    # NOTE(review): assumes the event can only originate from the trusted trigger -- confirm.
    response = pickle.loads(base64.standard_b64decode(event["response"]))
    return self.process_response(context=context, response=response)
class HttpTrigger(BaseTrigger):
    """
    Trigger that sends one HTTP request through ``HttpAsyncHook`` and emits a single event.

    :param http_conn_id: http connection id that has the base
        API url i.e https://www.google.com/ and optional authentication credentials. Default
        headers can also be specified in the Extra field in json format.
    :param auth_type: The auth type for the service
    :param method: the API method to be called
    :param endpoint: Endpoint to be called, i.e. ``resource/v1/query?``.
    :param headers: Additional headers to be passed through as a dict.
    :param data: Payload to be uploaded or request parameters.
    :param extra_options: Additional kwargs to pass when creating a request.
        For example, ``run(json=obj)`` is passed as
        ``aiohttp.ClientSession().get(json=obj)``.
    """

    def __init__(
        self,
        http_conn_id: str = "http_default",
        auth_type: Any = None,
        method: str = "POST",
        endpoint: str | None = None,
        headers: dict[str, str] | None = None,
        data: Any = None,
        extra_options: dict[str, Any] | None = None,
    ):
        super().__init__()
        self.http_conn_id = http_conn_id
        self.method = method
        self.auth_type = auth_type
        self.endpoint = endpoint
        self.headers = headers
        self.data = data
        self.extra_options = extra_options

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the trigger classpath and the keyword arguments needed to rebuild it."""
        return (
            "airflow.providers.http.triggers.http.HttpTrigger",
            {
                "http_conn_id": self.http_conn_id,
                "method": self.method,
                "auth_type": self.auth_type,
                "endpoint": self.endpoint,
                "headers": self.headers,
                "data": self.data,
                "extra_options": self.extra_options,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Send the request once and yield exactly one event describing the outcome.

        On success the event is ``{"status": "success", "response": <payload>}``, where the
        payload is the base64 text of the pickled ``requests.Response``. Any exception raised
        while sending, converting or encoding is reported as
        ``{"status": "error", "message": str(exc)}`` instead of being propagated.
        """
        hook = HttpAsyncHook(
            method=self.method,
            http_conn_id=self.http_conn_id,
            auth_type=self.auth_type,
        )
        try:
            client_response = await hook.run(
                endpoint=self.endpoint,
                data=self.data,
                headers=self.headers,
                extra_options=self.extra_options,
            )
            response = await self._convert_response(client_response)
            payload = base64.standard_b64encode(pickle.dumps(response)).decode("ascii")
        except Exception as e:
            # Deliberately broad: every failure is turned into an error event for the operator.
            event = TriggerEvent({"status": "error", "message": str(e)})
        else:
            event = TriggerEvent({"status": "success", "response": payload})
        # Yield outside the ``try`` so the handler only covers the request/encoding work.
        yield event

    @staticmethod
    async def _convert_response(client_response: ClientResponse) -> requests.Response:
        """
        Build a ``requests.Response`` carrying the data of an aiohttp ``ClientResponse``.

        :param client_response: the aiohttp response to copy from; its body is read here.
        :return: a ``requests.Response`` with content, status, headers, url, history,
            encoding, reason and cookies copied over.
        """
        response = requests.Response()
        response._content = await client_response.read()
        response.status_code = client_response.status
        response.headers = CaseInsensitiveDict(client_response.headers)
        # NOTE(review): url and cookies keep their aiohttp-side types (URL object, cookie
        # container); requests itself uses ``str`` / a cookie jar here -- confirm callers cope.
        response.url = client_response.url
        # Convert every redirect hop as well: the result is pickled by ``run`` and raw aiohttp
        # responses are not expected to survive pickling. The tuple type is kept on purpose.
        response.history = tuple(
            [await HttpTrigger._convert_response(hop) for hop in client_response.history]
        )
        response.encoding = client_response.get_encoding()
        response.reason = client_response.reason
        response.cookies = client_response.cookies
        return response
# Test methods added to the operator test class by this patch. They are written at module
# indentation because the enclosing ``class`` statement is not part of the patch text.


def test_async_defer_successfully(self, requests_mock):
    """A deferrable operator raises ``TaskDeferred`` carrying an ``HttpTrigger`` on execute."""
    deferrable_op = SimpleHttpOperator(task_id="test_HTTP_op", deferrable=True)

    with pytest.raises(TaskDeferred) as deferral:
        deferrable_op.execute({})

    assert isinstance(deferral.value.trigger, HttpTrigger), "Trigger is not a HttpTrigger"


def test_async_execute_successfully(self, requests_mock):
    """``execute_complete`` decodes a success event and returns the response text."""
    deferrable_op = SimpleHttpOperator(task_id="test_HTTP_op", deferrable=True)

    # Build the same payload shape the trigger produces: base64 text of a pickled Response.
    canned = Response()
    canned._content = b"content"
    encoded = base64.standard_b64encode(pickle.dumps(canned)).decode("ascii")

    outcome = deferrable_op.execute_complete(
        context={},
        event={"status": "success", "response": encoded},
    )

    assert "content" == outcome
HTTP_PATH = "airflow.providers.http.triggers.http.{}"
TEST_CONN_ID = "http_default"
TEST_AUTH_TYPE = None
TEST_METHOD = "POST"
TEST_ENDPOINT = "endpoint"
TEST_HEADERS = {"Authorization": "Bearer test"}
TEST_DATA = ""
TEST_EXTRA_OPTIONS = {}


@pytest.fixture
def trigger():
    """Provide an ``HttpTrigger`` built from the module-level test constants."""
    init_kwargs = {
        "http_conn_id": TEST_CONN_ID,
        "auth_type": TEST_AUTH_TYPE,
        "method": TEST_METHOD,
        "endpoint": TEST_ENDPOINT,
        "headers": TEST_HEADERS,
        "data": TEST_DATA,
        "extra_options": TEST_EXTRA_OPTIONS,
    }
    return HttpTrigger(**init_kwargs)


@pytest.fixture
def client_response():
    """Provide an ``AsyncMock`` specced on aiohttp's ``ClientResponse`` with canned values."""
    fake = mock.AsyncMock(ClientResponse)
    # Values returned by the methods the conversion calls.
    fake.read.return_value = b"content"
    fake.get_encoding.return_value = "utf-8"
    # Plain attributes the conversion copies.
    fake.status = 200
    fake.reason = "reason"
    fake.url = URL("https://example.com")
    fake.headers = CIMultiDictProxy(CIMultiDict([("header", "value")]))
    fake.history = ()
    fake.cookies = SimpleCookie()
    return fake


class TestHttpTrigger:
    @staticmethod
    def _mock_run_result(result_to_mock):
        """Wrap ``result_to_mock`` in an already-resolved ``Future`` so it can be awaited."""
        resolved = Future()
        resolved.set_result(result_to_mock)
        return resolved

    def test_serialization(self, trigger):
        """``serialize`` returns the trigger classpath and the constructor arguments."""
        expected_kwargs = {
            "http_conn_id": TEST_CONN_ID,
            "auth_type": TEST_AUTH_TYPE,
            "method": TEST_METHOD,
            "endpoint": TEST_ENDPOINT,
            "headers": TEST_HEADERS,
            "data": TEST_DATA,
            "extra_options": TEST_EXTRA_OPTIONS,
        }

        classpath, kwargs = trigger.serialize()

        assert classpath == "airflow.providers.http.triggers.http.HttpTrigger"
        assert kwargs == expected_kwargs

    @pytest.mark.asyncio
    @mock.patch(HTTP_PATH.format("HttpAsyncHook"))
    async def test_trigger_on_success_yield_successfully(self, mock_hook, trigger, client_response):
        """A successful hook call yields a success event carrying the encoded response."""
        mock_hook.return_value.run.return_value = self._mock_run_result(client_response)
        converted = await HttpTrigger._convert_response(client_response)

        events = trigger.run()
        first_event = await events.asend(None)

        expected_payload = base64.standard_b64encode(pickle.dumps(converted)).decode("ascii")
        assert first_event == TriggerEvent({"status": "success", "response": expected_payload})

    @pytest.mark.asyncio
    @mock.patch(HTTP_PATH.format("HttpAsyncHook"))
    async def test_trigger_on_exec_yield_successfully(self, mock_hook, trigger):
        """An exception raised by the hook yields an error event with the exception text."""
        mock_hook.return_value.run.side_effect = Exception("Test exception")

        events = trigger.run()
        first_event = await events.asend(None)

        assert first_event == TriggerEvent({"status": "error", "message": "Test exception"})

    @pytest.mark.asyncio
    async def test_convert_response(self, client_response):
        """``_convert_response`` copies every field from the aiohttp response."""
        converted = await HttpTrigger._convert_response(client_response)

        assert converted.content == await client_response.read()
        assert converted.status_code == client_response.status
        assert converted.headers == CaseInsensitiveDict(client_response.headers)
        assert converted.url == client_response.url
        assert converted.history == client_response.history
        assert converted.encoding == client_response.get_encoding()
        assert converted.reason == client_response.reason
        assert converted.cookies == client_response.cookies