diff --git a/airflow/providers/airbyte/hooks/airbyte.py b/airflow/providers/airbyte/hooks/airbyte.py
index b1f6317530514..ab0d7e4baf15d 100644
--- a/airflow/providers/airbyte/hooks/airbyte.py
+++ b/airflow/providers/airbyte/hooks/airbyte.py
@@ -107,6 +107,19 @@ def get_job(self, job_id: int) -> Any:
             headers={"accept": "application/json"},
         )
 
+    def cancel_job(self, job_id: int) -> Any:
+        """
+        Cancel the job when task is cancelled
+
+        :param job_id: Required. Id of the Airbyte job
+        :return: Whatever ``self.run`` returns for the ``jobs/cancel`` request
+        """
+        return self.run(
+            endpoint=f"api/{self.api_version}/jobs/cancel",
+            json={"id": job_id},
+            headers={"accept": "application/json"},
+        )
+
     def test_connection(self):
         """Tests the Airbyte connection by hitting the health API"""
         self.method = 'GET'
diff --git a/airflow/providers/airbyte/operators/airbyte.py b/airflow/providers/airbyte/operators/airbyte.py
index ef2e2c1559902..7677795a6a524 100644
--- a/airflow/providers/airbyte/operators/airbyte.py
+++ b/airflow/providers/airbyte/operators/airbyte.py
@@ -67,14 +67,28 @@ def __init__(
 
     def execute(self, context: 'Context') -> None:
         """Create Airbyte Job and wait to finish"""
-        hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
-        job_object = hook.submit_sync_connection(connection_id=self.connection_id)
-        job_id = job_object.json()['job']['id']
+        self.hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
+        job_object = self.hook.submit_sync_connection(connection_id=self.connection_id)
+        self.job_id = job_object.json()['job']['id']
 
-        self.log.info("Job %s was submitted to Airbyte Server", job_id)
+        self.log.info("Job %s was submitted to Airbyte Server", self.job_id)
         if not self.asynchronous:
-            self.log.info('Waiting for job %s to complete', job_id)
-            hook.wait_for_job(job_id=job_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
-            self.log.info('Job %s completed successfully', job_id)
+            self.log.info('Waiting for job %s to complete', self.job_id)
+            self.hook.wait_for_job(job_id=self.job_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
+            self.log.info('Job %s completed successfully', self.job_id)
 
-        return job_id
+        return self.job_id
+
+    def on_kill(self):
+        """
+        Cancel the Airbyte job if the task is cancelled.
+
+        ``self.hook`` and ``self.job_id`` are only assigned inside ``execute``,
+        so they are read defensively here: if the task is killed before the job
+        was submitted there is nothing to cancel and no attribute to read.
+        """
+        hook = getattr(self, 'hook', None)
+        job_id = getattr(self, 'job_id', None)
+        if hook is not None and job_id:
+            self.log.info('on_kill: cancel the airbyte Job %s', job_id)
+            hook.cancel_job(job_id)
diff --git a/tests/providers/airbyte/hooks/test_airbyte.py b/tests/providers/airbyte/hooks/test_airbyte.py
index 77432cdec4c65..31b1d0ea81269 100644
--- a/tests/providers/airbyte/hooks/test_airbyte.py
+++ b/tests/providers/airbyte/hooks/test_airbyte.py
@@ -38,9 +38,12 @@ class TestAirbyteHook(unittest.TestCase):
     job_id = 1
     sync_connection_endpoint = 'http://test-airbyte:8001/api/v1/connections/sync'
     get_job_endpoint = 'http://test-airbyte:8001/api/v1/jobs/get'
+    cancel_job_endpoint = 'http://test-airbyte:8001/api/v1/jobs/cancel'
+
     health_endpoint = 'http://test-airbyte:8001/api/v1/health'
     _mock_sync_conn_success_response_body = {'job': {'id': 1}}
     _mock_job_status_success_response_body = {'job': {'status': 'succeeded'}}
+    _mock_job_cancel_status = 'cancelled'
 
     def setUp(self):
         db.merge_conn(
@@ -71,6 +74,12 @@ def test_get_job_status(self, m):
         assert resp.status_code == 200
         assert resp.json() == self._mock_job_status_success_response_body
 
+    @requests_mock.mock()
+    def test_cancel_job(self, m):
+        m.post(self.cancel_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body)
+        resp = self.hook.cancel_job(job_id=self.job_id)
+        assert resp.status_code == 200
+
     @mock.patch('airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job')
     def test_wait_for_job_succeeded(self, mock_get_job):
         mock_get_job.side_effect = [self.return_value_get_job(self.hook.SUCCEEDED)]