diff --git a/airflow/providers/amazon/aws/waiters/batch.json b/airflow/providers/amazon/aws/waiters/batch.json index 3fbdd433771c..a8cd489ea380 100644 --- a/airflow/providers/amazon/aws/waiters/batch.json +++ b/airflow/providers/amazon/aws/waiters/batch.json @@ -17,7 +17,7 @@ "argument": "jobs[].status", "expected": "FAILED", "matcher": "pathAll", - "state": "failed" + "state": "failure" } ] }, @@ -37,13 +37,13 @@ "argument": "computeEnvironments[].status", "expected": "INVALID", "matcher": "pathAny", - "state": "failed" + "state": "failure" }, { "argument": "computeEnvironments[].status", "expected": "DELETED", "matcher": "pathAny", - "state": "failed" + "state": "failure" } ] } diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py b/tests/providers/amazon/aws/waiters/test_custom_waiters.py index 21c051f3b490..09b6742f7ae3 100644 --- a/tests/providers/amazon/aws/waiters/test_custom_waiters.py +++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py @@ -27,6 +27,7 @@ from moto import mock_eks from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, EcsTaskDefinitionStates from airflow.providers.amazon.aws.hooks.eks import EksHook @@ -295,3 +296,58 @@ def test_export_table_to_point_in_time_failed(self, mock_describe_export): ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, ) + + +class TestCustomBatchServiceWaiters: + """Test waiters from ``amazon/aws/waiters/batch.json``.""" + + JOB_ID = "test_job_id" + + @pytest.fixture(autouse=True) + def setup_test_cases(self, monkeypatch): + self.client = boto3.client("batch", region_name="eu-west-3") + monkeypatch.setattr(BatchClientHook, "conn", self.client) + + @pytest.fixture + def mock_describe_jobs(self): + """Mock ``BatchClientHook.Client.describe_jobs`` method.""" + with mock.patch.object(self.client, "describe_jobs") as m: + yield m + + def test_service_waiters(self): + hook_waiters = BatchClientHook(aws_conn_id=None).list_waiters() + assert "batch_job_complete" in hook_waiters + + @staticmethod + def describe_jobs(status: str): + """ + Helper function for generate minimal DescribeJobs response for a single job. + https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html + """ + return { + "jobs": [ + { + "status": status, + }, + ], + } + + def test_job_succeeded(self, mock_describe_jobs): + """Test job succeeded""" + mock_describe_jobs.side_effect = [ + self.describe_jobs(BatchClientHook.RUNNING_STATE), + self.describe_jobs(BatchClientHook.SUCCESS_STATE), + ] + waiter = BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete") + waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01, "MaxAttempts": 2}) + + def test_job_failed(self, mock_describe_jobs): + """Test job failed""" + mock_describe_jobs.side_effect = [ + self.describe_jobs(BatchClientHook.RUNNING_STATE), + self.describe_jobs(BatchClientHook.FAILURE_STATE), + ] + waiter = BatchClientHook(aws_conn_id=None).get_waiter("batch_job_complete") + + with pytest.raises(WaiterError, match="Waiter encountered a terminal failure state"): + waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01, "MaxAttempts": 2})