From 408a7d64f0e049279fb40048556d04b4037c78ee Mon Sep 17 00:00:00 2001 From: Megan Parker <91739562+megan-parker@users.noreply.github.com> Date: Mon, 14 Mar 2022 04:30:03 -0400 Subject: [PATCH] Add REST API endpoint for bulk update of DAGs (#19758) Added endpoint for bulk update of DAGs in the airflow stable API --- .../api_connexion/endpoints/dag_endpoint.py | 58 ++- airflow/api_connexion/openapi/v1.yaml | 69 ++- .../endpoints/test_dag_endpoint.py | 423 ++++++++++++++++++ 3 files changed, 533 insertions(+), 17 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 39c492e0563a8..e94707b127a69 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -89,7 +89,7 @@ def get_dags( cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags] dags_query = dags_query.filter(or_(*cond)) - total_entries = len(dags_query.all()) + total_entries = dags_query.count() dags = dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all() @@ -100,25 +100,67 @@ def get_dags( @provide_session def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse: """Update the specific DAG""" + try: + patch_body = dag_schema.load(request.json, session=session) + except ValidationError as err: + raise BadRequest(detail=str(err.messages)) + if update_mask: + patch_body_ = {} + if update_mask != ['is_paused']: + raise BadRequest(detail="Only `is_paused` field can be updated through the REST API") + patch_body_[update_mask[0]] = patch_body[update_mask[0]] + patch_body = patch_body_ dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none() if not dag: raise NotFound(f"Dag with id: '{dag_id}' not found") + dag.is_paused = patch_body['is_paused'] + session.flush() + return dag_schema.dump(dag) + + +@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)]) +@format_parameters({'limit': check_limit}) +@provide_session +def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None): + """Patch multiple DAGs.""" try: patch_body = dag_schema.load(request.json, session=session) except ValidationError as err: - raise BadRequest("Invalid Dag schema", detail=str(err.messages)) + raise BadRequest(detail=str(err.messages)) if update_mask: patch_body_ = {} - if len(update_mask) > 1: + if update_mask != ['is_paused']: raise BadRequest(detail="Only `is_paused` field can be updated through the REST API") update_mask = update_mask[0] - if update_mask != 'is_paused': - raise BadRequest(detail="Only `is_paused` field can be updated through the REST API") patch_body_[update_mask] = patch_body[update_mask] patch_body = patch_body_ - setattr(dag, 'is_paused', patch_body['is_paused']) - session.commit() - return dag_schema.dump(dag) + if only_active: + dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active) + else: + dags_query = session.query(DagModel).filter(~DagModel.is_subdag) + + if dag_id_pattern == '~': + dag_id_pattern = '%' + dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%')) + editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user) + + dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags)) + if tags: + cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags] + dags_query = dags_query.filter(or_(*cond)) + + total_entries = dags_query.count() + + dags = dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all() + + dags_to_update = {dag.dag_id for dag in dags} + session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update)).update( + {DagModel.is_paused: patch_body['is_paused']}, synchronize_session='fetch' + ) + + session.flush() + + return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG)]) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 1e8b6eb1c5e7d..a25ad735adada 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -396,6 +396,10 @@ paths: /dags: get: summary: List DAGs + description: > + List DAGs in the database. + + `dag_id_pattern` can be set to match dags of a specific pattern x-openapi-router-controller: airflow.api_connexion.endpoints.dag_endpoint operationId: get_dags tags: [DAG] @@ -404,25 +408,56 @@ paths: - $ref: '#/components/parameters/PageOffset' - $ref: '#/components/parameters/OrderBy' - $ref: '#/components/parameters/FilterTags' - - name: only_active + - $ref: '#/components/parameters/OnlyActive' + - name: dag_id_pattern in: query schema: - type: boolean - default: true + type: string required: false description: | - Only return active DAGs. + If set, only return DAGs with dag_ids matching this pattern. + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/DAGCollection' + '401': + $ref: '#/components/responses/Unauthenticated' - *New in version 2.1.1* + patch: + summary: Update DAGs + description: > + Update DAGs of a given dag_id_pattern using UpdateMask. + + This endpoint allows specifying `~` as the dag_id_pattern to update all DAGs. + + *New in version 2.3.0* + x-openapi-router-controller: airflow.api_connexion.endpoints.dag_endpoint + operationId: patch_dags + parameters: + - $ref: '#/components/parameters/PageLimit' + - $ref: '#/components/parameters/PageOffset' + - $ref: '#/components/parameters/FilterTags' + - $ref: '#/components/parameters/UpdateMask' + - $ref: '#/components/parameters/OnlyActive' - name: dag_id_pattern in: query schema: type: string - required: false + required: true description: | - If set, only return DAGs with dag_ids matching this pattern. - - *New in version 2.3.0* + If set, only update DAGs with dag_ids matching this pattern. + tags: [DAG] + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/DAG' + example: + is_paused: true responses: '200': description: Success. @@ -432,6 +467,10 @@ paths: $ref: '#/components/schemas/DAGCollection' '401': $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' /dags/{dag_id}: parameters: @@ -3804,6 +3843,18 @@ components: *New in version 2.1.0* + OnlyActive: + in: query + name: only_active + schema: + type: boolean + default: true + required: false + description: | + Only filter active DAGs. + + *New in version 2.1.1* + # Other parameters FileToken: diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 7f9addd6548d1..82a2403f2fb4d 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -834,3 +834,426 @@ def test_should_respond_403_unauthorized(self): ) assert response.status_code == 403 + + +class TestPatchDags(TestDagEndpoint): + + file_token = SERIALIZER.dumps("/tmp/dag_1.py") + file_token2 = SERIALIZER.dumps("/tmp/dag_2.py") + + @provide_session + def test_should_respond_200_on_patch_is_paused(self, session): + self._create_dag_models(2) + self._create_deactivated_dag() + + dags_query = session.query(DagModel).filter(~DagModel.is_subdag) + assert len(dags_query.all()) == 3 + + response = self.client.patch( + "/api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + + assert response.status_code == 200 + assert { + "dags": [ + { + "dag_id": "TEST_DAG_1", + "description": None, + "fileloc": "/tmp/dag_1.py", + "file_token": self.file_token, + "is_paused": False, + "is_active": True, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + }, + { + "dag_id": "TEST_DAG_2", + "description": None, + "fileloc": "/tmp/dag_2.py", + "file_token": self.file_token2, + "is_paused": False, + "is_active": True, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + }, + ], + "total_entries": 2, + } == response.json + + def test_only_active_true_returns_active_dags(self): + self._create_dag_models(1) + self._create_deactivated_dag() + response = self.client.patch( + "/api/v1/dags?only_active=True&dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + assert response.status_code == 200 + assert { + "dags": [ + { + "dag_id": "TEST_DAG_1", + "description": None, + "fileloc": "/tmp/dag_1.py", + "file_token": self.file_token, + "is_paused": False, + "is_active": True, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + } + ], + "total_entries": 1, + } == response.json + + def test_only_active_false_returns_all_dags(self): + self._create_dag_models(1) + self._create_deactivated_dag() + response = self.client.patch( + "/api/v1/dags?only_active=False&dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + + file_token_2 = SERIALIZER.dumps("/tmp/dag_del_1.py") + assert response.status_code == 200 + assert { + "dags": [ + { + "dag_id": "TEST_DAG_1", + "description": None, + "fileloc": "/tmp/dag_1.py", + "file_token": self.file_token, + "is_paused": False, + "is_active": True, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + }, + { + "dag_id": "TEST_DAG_DELETED_1", + "description": None, + "fileloc": "/tmp/dag_del_1.py", + "file_token": file_token_2, + "is_paused": False, + "is_active": False, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + }, + ], + "total_entries": 2, + } == response.json + + @parameterized.expand( + [ + ("api/v1/dags?tags=t1&dag_id_pattern=~", ['TEST_DAG_1', 'TEST_DAG_3']), + ("api/v1/dags?tags=t2&dag_id_pattern=~", ['TEST_DAG_2', 'TEST_DAG_3']), + ("api/v1/dags?tags=t1,t2&dag_id_pattern=~", ["TEST_DAG_1", "TEST_DAG_2", "TEST_DAG_3"]), + ("api/v1/dags?dag_id_pattern=~", ["TEST_DAG_1", "TEST_DAG_2", "TEST_DAG_3", "TEST_DAG_4"]), + ] + ) + def test_filter_dags_by_tags_works(self, url, expected_dag_ids): + # test filter by tags + dag1 = DAG(dag_id="TEST_DAG_1", tags=['t1']) + dag2 = DAG(dag_id="TEST_DAG_2", tags=['t2']) + dag3 = DAG(dag_id="TEST_DAG_3", tags=['t1', 't2']) + dag4 = DAG(dag_id="TEST_DAG_4") + dag1.sync_to_db() + dag2.sync_to_db() + dag3.sync_to_db() + dag4.sync_to_db() + response = self.client.patch( + url, + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + assert response.status_code == 200 + dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + + assert expected_dag_ids == dag_ids + + @parameterized.expand( + [ + ("api/v1/dags?dag_id_pattern=DAG_1", {'TEST_DAG_1', 'SAMPLE_DAG_1'}), + ("api/v1/dags?dag_id_pattern=SAMPLE_DAG", {'SAMPLE_DAG_1', 'SAMPLE_DAG_2'}), + ( + "api/v1/dags?dag_id_pattern=_DAG_", + {"TEST_DAG_1", "TEST_DAG_2", 'SAMPLE_DAG_1', 'SAMPLE_DAG_2'}, + ), + ] + ) + def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): + # test filter by tags + dag1 = DAG(dag_id="TEST_DAG_1") + dag2 = DAG(dag_id="TEST_DAG_2") + dag3 = DAG(dag_id="SAMPLE_DAG_1") + dag4 = DAG(dag_id="SAMPLE_DAG_2") + dag1.sync_to_db() + dag2.sync_to_db() + dag3.sync_to_db() + dag4.sync_to_db() + + response = self.client.patch( + url, + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + assert response.status_code == 200 + dag_ids = {dag["dag_id"] for dag in response.json["dags"]} + + assert expected_dag_ids == dag_ids + + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(3) + response = self.client.patch( + "api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test_granular_permissions"}, + ) + assert response.status_code == 200 + assert len(response.json['dags']) == 1 + assert response.json['dags'][0]['dag_id'] == 'TEST_DAG_1' + + @parameterized.expand( + [ + ("api/v1/dags?limit=1&dag_id_pattern=~", ["TEST_DAG_1"]), + ("api/v1/dags?limit=2&dag_id_pattern=~", ["TEST_DAG_1", "TEST_DAG_10"]), + ( + "api/v1/dags?offset=5&dag_id_pattern=~", + ["TEST_DAG_5", "TEST_DAG_6", "TEST_DAG_7", "TEST_DAG_8", "TEST_DAG_9"], + ), + ( + "api/v1/dags?offset=0&dag_id_pattern=~", + [ + "TEST_DAG_1", + "TEST_DAG_10", + "TEST_DAG_2", + "TEST_DAG_3", + "TEST_DAG_4", + "TEST_DAG_5", + "TEST_DAG_6", + "TEST_DAG_7", + "TEST_DAG_8", + "TEST_DAG_9", + ], + ), + ("api/v1/dags?limit=1&offset=5&dag_id_pattern=~", ["TEST_DAG_5"]), + ("api/v1/dags?limit=1&offset=1&dag_id_pattern=~", ["TEST_DAG_10"]), + ("api/v1/dags?limit=2&offset=2&dag_id_pattern=~", ["TEST_DAG_2", "TEST_DAG_3"]), + ] + ) + def test_should_respond_200_and_handle_pagination(self, url, expected_dag_ids): + self._create_dag_models(10) + + response = self.client.patch( + url, + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + + assert response.status_code == 200 + + dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + + assert expected_dag_ids == dag_ids + assert 10 == response.json["total_entries"] + + def test_should_respond_200_default_limit(self): + self._create_dag_models(101) + + response = self.client.patch( + "api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + + assert response.status_code == 200 + + assert 100 == len(response.json["dags"]) + assert 101 == response.json["total_entries"] + + def test_should_raises_401_unauthenticated(self): + response = self.client.patch( + "api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": False, + }, + ) + + assert_401(response) + + def test_should_respond_403_unauthorized(self): + self._create_dag_models(1) + response = self.client.patch( + "api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test_no_permissions"}, + ) + + assert response.status_code == 403 + + def test_should_respond_200_and_pause_dags(self): + self._create_dag_models(2) + + response = self.client.patch( + "/api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": True, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + + assert response.status_code == 200 + assert { + "dags": [ + { + "dag_id": "TEST_DAG_1", + "description": None, + "fileloc": "/tmp/dag_1.py", + "file_token": self.file_token, + "is_paused": True, + "is_active": True, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + }, + { + "dag_id": "TEST_DAG_2", + "description": None, + "fileloc": "/tmp/dag_2.py", + "file_token": self.file_token2, + "is_paused": True, + "is_active": True, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + }, + ], + "total_entries": 2, + } == response.json + + @provide_session + def test_should_respond_200_and_pause_dag_pattern(self, session): + self._create_dag_models(10) + file_token10 = SERIALIZER.dumps("/tmp/dag_10.py") + + response = self.client.patch( + "/api/v1/dags?dag_id_pattern=TEST_DAG_1", + json={ + "is_paused": True, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + + assert response.status_code == 200 + assert { + "dags": [ + { + "dag_id": "TEST_DAG_1", + "description": None, + "fileloc": "/tmp/dag_1.py", + "file_token": self.file_token, + "is_paused": True, + "is_active": True, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + }, + { + "dag_id": "TEST_DAG_10", + "description": None, + "fileloc": "/tmp/dag_10.py", + "file_token": file_token10, + "is_paused": True, + "is_active": True, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + }, + ], + "total_entries": 2, + } == response.json + + dags_not_updated = session.query(DagModel).filter(~DagModel.is_paused) + assert len(dags_not_updated.all()) == 8 + dags_updated = session.query(DagModel).filter(DagModel.is_paused) + assert len(dags_updated.all()) == 2 + + def test_should_respons_400_dag_id_pattern_missing(self): + self._create_dag_models(1) + response = self.client.patch( + "/api/v1/dags?only_active=True", + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test"}, + ) + assert response.status_code == 400