From 7e08c3a3df816cad6eb20f4d8a141a211525b0d1 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Wed, 22 Jun 2022 23:26:28 +0200 Subject: [PATCH] Upgrade FAB to 4.1.1 (#24399) * Upgrade FAB to 4.1.1 The Flask Application Builder have been updated recently to support a number of newer dependencies. This PR is the attempt to migrate FAB to newer version. This includes: * update setup.py and setup.cfg upper and lower bounds to account for proper version of dependencies that FAB < 4.0.0 was blocking from upgrade * added typed Flask application retrieval with a custom application fields available for MyPy typing checks. * fix typing to account for typing hints added in multiple upgraded libraries optional values and content of request returned as Mapping * switch to PyJWT 2.* by using non-deprecated "required" claim as list rather than separate fields * add possibiliyt to install providers without constraints so that we could avoid errors on conflicting constraints when upgrade-to-newer-dependencies is used * add pre-commit to check that 2.4+ only get_airflow_app is not used in providers * avoid Bad Request in case the request sent to Flask 2.0 is not JSon content type * switch imports of internal classes to direct packages where classes are available rather than from "airflow.models" to satisfy MyPY * synchronize changes of FAB Security Manager 4.1.1 with our copy of the Security Manager. * add error handling for a few "None" cases detected by MyPY * corrected test cases that were broken by immutability of Flask 2 objects and better escaping done by Flask 2 * updated test cases to account for redirection to "path" rather than full URL by Flask2 Fixes: #22397 * fixup! Upgrade FAB to 4.1.1 (cherry picked from commit e2f19505bf3622935480e80bee55bf5b6d80097b) --- .github/workflows/ci.yml | 4 + .pre-commit-config.yaml | 2 +- Dockerfile.ci | 36 +- airflow/api/auth/backend/basic_auth.py | 5 +- .../api_connexion/endpoints/dag_endpoint.py | 9 +- .../endpoints/dag_run_endpoint.py | 20 +- .../endpoints/extra_link_endpoint.py | 4 +- .../api_connexion/endpoints/log_endpoint.py | 11 +- .../api_connexion/endpoints/pool_endpoint.py | 14 +- .../api_connexion/endpoints/request_dict.py | 24 ++ .../endpoints/role_and_permission_endpoint.py | 15 +- .../api_connexion/endpoints/task_endpoint.py | 7 +- .../endpoints/task_instance_endpoint.py | 17 +- .../api_connexion/endpoints/user_endpoint.py | 13 +- .../endpoints/variable_endpoint.py | 7 +- .../api_connexion/endpoints/xcom_endpoint.py | 5 +- airflow/api_connexion/schemas/dag_schema.py | 2 +- .../schemas/task_instance_schema.py | 2 +- airflow/api_connexion/security.py | 7 +- airflow/models/abstractoperator.py | 1 - airflow/operators/trigger_dagrun.py | 5 +- .../common/auth_backend/google_openid.py | 2 +- airflow/sensors/external_task.py | 6 +- airflow/utils/airflow_flask_app.py | 37 ++ airflow/utils/jwt_signer.py | 4 +- airflow/www/api/experimental/endpoints.py | 3 +- airflow/www/auth.py | 5 +- .../www/extensions/init_wsgi_middlewares.py | 2 +- airflow/www/fab_security/manager.py | 37 +- airflow/www/views.py | 118 ++++--- dev/breeze/README.md | 2 +- dev/breeze/setup.cfg | 2 +- .../commands/release_management_commands.py | 9 + .../src/airflow_breeze/params/shell_params.py | 1 + .../utils/docker_command_utils.py | 1 + .../src/airflow_breeze/utils/recording.py | 4 +- dev/send_email.py | 3 +- images/breeze/output-commands-hash.txt | 5 - .../output-verify-provider-packages.svg | 140 ++++---- newsfragments/24399.significant.rst | 31 ++ scripts/ci/docker-compose/_docker.env | 1 + scripts/ci/docker-compose/base.yml | 1 + scripts/ci/docker-compose/devcontainer.env | 1 + .../pre_commit_check_2_1_compatibility.py | 44 ++- scripts/docker/entrypoint_ci.sh | 36 +- scripts/in_container/_in_container_utils.sh | 29 +- setup.cfg | 47 +-- setup.py | 12 +- .../endpoints/test_dag_endpoint.py | 97 +++--- .../endpoints/test_dag_source_endpoint.py | 27 +- .../endpoints/test_xcom_endpoint.py | 6 +- .../api_connexion/schemas/test_dag_schema.py | 321 +++++++++--------- tests/conftest.py | 24 ++ .../remote_user_api_auth_backend.py | 6 +- tests/utils/test_serve_logs.py | 13 +- tests/www/views/test_views.py | 25 +- tests/www/views/test_views_decorators.py | 6 +- tests/www/views/test_views_log.py | 2 +- tests/www/views/test_views_mount.py | 4 +- 59 files changed, 765 insertions(+), 559 deletions(-) create mode 100644 airflow/api_connexion/endpoints/request_dict.py create mode 100644 airflow/utils/airflow_flask_app.py delete mode 100644 images/breeze/output-commands-hash.txt create mode 100644 newsfragments/24399.significant.rst diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c0cbb9f09bc33..57b396f6ace7b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -805,6 +805,8 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" run: > breeze verify-provider-packages --use-airflow-version wheel --use-packages-from-dist --package-format wheel + env: + SKIP_CONSTRAINTS: "${{ needs.build-info.outputs.upgradeToNewerDependencies }}" - name: "Remove airflow package and replace providers with 2.1-compliant versions" run: | rm -vf dist/apache_airflow-*.whl \ @@ -882,6 +884,8 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" run: > breeze verify-provider-packages --use-airflow-version sdist --use-packages-from-dist --package-format sdist + env: + SKIP_CONSTRAINTS: "${{ needs.build-info.outputs.upgradeToNewerDependencies }}" - name: "Fix ownership" run: breeze fix-ownership if: always() diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 99db08571455e..af0c2b0e1c84d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -347,7 +347,7 @@ repos: language: python files: ^BREEZE\.rst$|^dev/breeze/.*$ pass_filenames: false - additional_dependencies: ['rich>=12.4.4', 'rich-click'] + additional_dependencies: ['rich>=12.4.4', 'rich-click>=1.5'] - id: update-local-yml-file name: Update mounts in the local yml file entry: ./scripts/ci/pre_commit/pre_commit_local_yml_mounts.py diff --git a/Dockerfile.ci b/Dockerfile.ci index 537f84a71f2d2..337901d4a817b 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -686,9 +686,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo "${COLOR_BLUE}Uninstalling airflow and providers" echo uninstall_airflow_and_providers - echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers elif [[ ${USE_AIRFLOW_VERSION} == "sdist" ]]; then echo @@ -696,9 +702,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers else echo @@ -706,9 +718,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "none" + else + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi fi if [[ ${USE_PACKAGES_FROM_DIST=} == "true" ]]; then echo diff --git a/airflow/api/auth/backend/basic_auth.py b/airflow/api/auth/backend/basic_auth.py index 397a722a98cf2..12f00b435fe11 100644 --- a/airflow/api/auth/backend/basic_auth.py +++ b/airflow/api/auth/backend/basic_auth.py @@ -18,10 +18,11 @@ from functools import wraps from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast -from flask import Response, current_app, request +from flask import Response, request from flask_appbuilder.const import AUTH_LDAP from flask_login import login_user +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.www.fab_security.sqla.models import User CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None @@ -40,7 +41,7 @@ def auth_current_user() -> Optional[User]: if auth is None or not auth.username or not auth.password: return None - ab_security_manager = current_app.appbuilder.sm + ab_security_manager = get_airflow_app().appbuilder.sm user = None if ab_security_manager.auth_type == AUTH_LDAP: user = ab_security_manager.auth_user_ldap(auth.username, auth.password) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index e94707b127a69..40113021cfad6 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -18,7 +18,7 @@ from typing import Collection, Optional from connexion import NoContent -from flask import current_app, g, request +from flask import g, request from marshmallow import ValidationError from sqlalchemy.orm import Session from sqlalchemy.sql.expression import or_ @@ -37,6 +37,7 @@ from airflow.exceptions import AirflowException, DagNotFound from airflow.models.dag import DagModel, DagTag from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -55,7 +56,7 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) def get_dag_details(*, dag_id: str) -> APIResponse: """Get details of DAG.""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found") return dag_detail_schema.dump(dag) @@ -82,7 +83,7 @@ def get_dags( if dag_id_pattern: dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%')) - readable_dags = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags)) if tags: @@ -142,7 +143,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat if dag_id_pattern == '~': dag_id_pattern = '%' dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%')) - editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user) + editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user) dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags)) if tags: diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index a83ca223b07ac..1fad48f7b6fe7 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -18,13 +18,14 @@ import pendulum from connexion import NoContent -from flask import current_app, g, request +from flask import g from marshmallow import ValidationError from sqlalchemy import or_ from sqlalchemy.orm import Query, Session from airflow.api.common.mark_tasks import set_dag_run_state_to_failed, set_dag_run_state_to_success from airflow.api_connexion import security +from airflow.api_connexion.endpoints.request_dict import get_json_request_dict from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters from airflow.api_connexion.schemas.dag_run_schema import ( @@ -37,6 +38,7 @@ from airflow.api_connexion.types import APIResponse from airflow.models import DagModel, DagRun from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType @@ -157,7 +159,7 @@ def get_dag_runs( # This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs. if dag_id == "~": - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder query = query.filter(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user))) else: query = query.filter(DagRun.dag_id == dag_id) @@ -189,13 +191,13 @@ def get_dag_runs( @provide_session def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: """Get list of DAG Runs""" - body = request.get_json() + body = get_json_request_dict() try: data = dagruns_batch_form_schema.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) query = session.query(DagRun) if data.get("dag_ids"): @@ -242,7 +244,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: detail=f"DAG with dag_id: '{dag_id}' has import errors", ) try: - post_body = dagrun_schema.load(request.json, session=session) + post_body = dagrun_schema.load(get_json_request_dict(), session=session) except ValidationError as err: raise BadRequest(detail=str(err)) @@ -258,7 +260,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: ) if not dagrun_instance: try: - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_run = dag.create_dagrun( run_type=DagRunType.MANUAL, run_id=run_id, @@ -267,7 +269,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: state=DagRunState.QUEUED, conf=post_body.get("conf"), external_trigger=True, - dag_hash=current_app.dag_bag.dags_hash.get(dag_id), + dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id), ) return dagrun_schema.dump(dag_run) except ValueError as ve: @@ -300,12 +302,12 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' raise NotFound(error_message) try: - post_body = set_dagrun_state_form_schema.load(request.json) + post_body = set_dagrun_state_form_schema.load(get_json_request_dict()) except ValidationError as err: raise BadRequest(detail=str(err)) state = post_body['state'] - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if state == DagRunState.SUCCESS: set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True) else: diff --git a/airflow/api_connexion/endpoints/extra_link_endpoint.py b/airflow/api_connexion/endpoints/extra_link_endpoint.py index 3e9535603bda3..94b36928bfd0c 100644 --- a/airflow/api_connexion/endpoints/extra_link_endpoint.py +++ b/airflow/api_connexion/endpoints/extra_link_endpoint.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from flask import current_app from sqlalchemy.orm.session import Session from airflow import DAG @@ -25,6 +24,7 @@ from airflow.exceptions import TaskNotFound from airflow.models.dagbag import DagBag from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -46,7 +46,7 @@ def get_extra_links( """Get extra links for task instance""" from airflow.models.taskinstance import TaskInstance - dagbag: DagBag = current_app.dag_bag + dagbag: DagBag = get_airflow_app().dag_bag dag: DAG = dagbag.get_dag(dag_id) if not dag: raise NotFound("DAG not found", detail=f'DAG with ID = "{dag_id}" not found') diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index f1335fe527451..171cacb076e7c 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -14,10 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - from typing import Any, Optional -from flask import Response, current_app, request +from flask import Response, request from itsdangerous.exc import BadSignature from itsdangerous.url_safe import URLSafeSerializer from sqlalchemy.orm.session import Session @@ -29,6 +28,7 @@ from airflow.exceptions import TaskNotFound from airflow.models import TaskInstance from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.session import NEW_SESSION, provide_session @@ -52,7 +52,7 @@ def get_log( session: Session = NEW_SESSION, ) -> APIResponse: """Get logs for specific task instance""" - key = current_app.config["SECRET_KEY"] + key = get_airflow_app().config["SECRET_KEY"] if not token: metadata = {} else: @@ -87,7 +87,7 @@ def get_log( metadata['end_of_log'] = True raise NotFound(title="TaskInstance not found") - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if dag: try: ti.task = dag.get_task(ti.task_id) @@ -101,7 +101,8 @@ def get_log( if return_type == 'application/json' or return_type is None: # default logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata) logs = logs[0] if task_try_number is not None else logs - token = URLSafeSerializer(key).dumps(metadata) + # we must have token here, so we can safely ignore it + token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment] return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs)) # text/plain. Stream logs = task_log_reader.read_log_stream(ti, task_try_number, metadata) diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index e9c8aee252bec..8c3d3f3b86d38 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -16,13 +16,14 @@ # under the License. from typing import Optional -from flask import Response, request +from flask import Response from marshmallow import ValidationError from sqlalchemy import func from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from airflow.api_connexion import security +from airflow.api_connexion.endpoints.request_dict import get_json_request_dict from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.pool_schema import PoolCollection, pool_collection_schema, pool_schema @@ -83,9 +84,10 @@ def patch_pool( session: Session = NEW_SESSION, ) -> APIResponse: """Update a pool""" + request_dict = get_json_request_dict() # Only slots can be modified in 'default_pool' try: - if pool_name == Pool.DEFAULT_POOL_NAME and request.json["name"] != Pool.DEFAULT_POOL_NAME: + if pool_name == Pool.DEFAULT_POOL_NAME and request_dict["name"] != Pool.DEFAULT_POOL_NAME: if update_mask and len(update_mask) == 1 and update_mask[0].strip() == "slots": pass else: @@ -98,7 +100,7 @@ def patch_pool( raise NotFound(detail=f"Pool with name:'{pool_name}' not found") try: - patch_body = pool_schema.load(request.json) + patch_body = pool_schema.load(request_dict) except ValidationError as err: raise BadRequest(detail=str(err.messages)) @@ -119,7 +121,7 @@ def patch_pool( else: required_fields = {"name", "slots"} - fields_diff = required_fields - set(request.json.keys()) + fields_diff = required_fields - set(get_json_request_dict().keys()) if fields_diff: raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}") @@ -134,12 +136,12 @@ def patch_pool( def post_pool(*, session: Session = NEW_SESSION) -> APIResponse: """Create a pool""" required_fields = {"name", "slots"} # Pool would require both fields in the post request - fields_diff = required_fields - set(request.json.keys()) + fields_diff = required_fields - set(get_json_request_dict().keys()) if fields_diff: raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}") try: - post_body = pool_schema.load(request.json, session=session) + post_body = pool_schema.load(get_json_request_dict(), session=session) except ValidationError as err: raise BadRequest(detail=str(err.messages)) diff --git a/airflow/api_connexion/endpoints/request_dict.py b/airflow/api_connexion/endpoints/request_dict.py new file mode 100644 index 0000000000000..4d7ad21250586 --- /dev/null +++ b/airflow/api_connexion/endpoints/request_dict.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Mapping, cast + + +def get_json_request_dict() -> Mapping[str, Any]: + from flask import request + + return cast(Mapping[str, Any], request.get_json()) diff --git a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py index a25856e111b2c..25419066d20fa 100644 --- a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py +++ b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py @@ -18,7 +18,7 @@ from typing import List, Optional, Tuple from connexion import NoContent -from flask import current_app, request +from flask import request from marshmallow import ValidationError from sqlalchemy import asc, desc, func @@ -34,6 +34,7 @@ ) from airflow.api_connexion.types import APIResponse, UpdateMask from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.www.fab_security.sqla.models import Action, Role from airflow.www.security import AirflowSecurityManager @@ -54,7 +55,7 @@ def _check_action_and_resource(sm: AirflowSecurityManager, perms: List[Tuple[str @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE)]) def get_role(*, role_name: str) -> APIResponse: """Get role""" - ab_security_manager = current_app.appbuilder.sm + ab_security_manager = get_airflow_app().appbuilder.sm role = ab_security_manager.find_role(name=role_name) if not role: raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found") @@ -65,7 +66,7 @@ def get_role(*, role_name: str) -> APIResponse: @format_parameters({"limit": check_limit}) def get_roles(*, order_by: str = "name", limit: int, offset: Optional[int] = None) -> APIResponse: """Get roles""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder session = appbuilder.get_session total_entries = session.query(func.count(Role.id)).scalar() direction = desc if order_by.startswith("-") else asc @@ -89,7 +90,7 @@ def get_roles(*, order_by: str = "name", limit: int, offset: Optional[int] = Non @format_parameters({'limit': check_limit}) def get_permissions(*, limit: int, offset: Optional[int] = None) -> APIResponse: """Get permissions""" - session = current_app.appbuilder.get_session + session = get_airflow_app().appbuilder.get_session total_entries = session.query(func.count(Action.id)).scalar() query = session.query(Action) actions = query.offset(offset).limit(limit).all() @@ -99,7 +100,7 @@ def get_permissions(*, limit: int, offset: Optional[int] = None) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ROLE)]) def delete_role(*, role_name: str) -> APIResponse: """Delete a role""" - ab_security_manager = current_app.appbuilder.sm + ab_security_manager = get_airflow_app().appbuilder.sm role = ab_security_manager.find_role(name=role_name) if not role: raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found") @@ -110,7 +111,7 @@ def delete_role(*, role_name: str) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE)]) def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse: """Update a role""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder security_manager = appbuilder.sm body = request.json try: @@ -144,7 +145,7 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ROLE)]) def post_role() -> APIResponse: """Create a new role""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder security_manager = appbuilder.sm body = request.json try: diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py index 28c39b000c28d..74b6e7e9ee8ed 100644 --- a/airflow/api_connexion/endpoints/task_endpoint.py +++ b/airflow/api_connexion/endpoints/task_endpoint.py @@ -16,8 +16,6 @@ # under the License. from operator import attrgetter -from flask import current_app - from airflow import DAG from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound @@ -25,6 +23,7 @@ from airflow.api_connexion.types import APIResponse from airflow.exceptions import TaskNotFound from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app @security.requires_access( @@ -35,7 +34,7 @@ ) def get_task(*, dag_id: str, task_id: str) -> APIResponse: """Get simplified representation of a task.""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound("DAG not found") @@ -54,7 +53,7 @@ def get_task(*, dag_id: str, task_id: str) -> APIResponse: ) def get_tasks(*, dag_id: str, order_by: str = "task_id") -> APIResponse: """Get tasks for DAG""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound("DAG not found") tasks = dag.tasks diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index c2416ab0d9d44..6cc3e784e62a3 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -16,7 +16,6 @@ # under the License. from typing import Any, Iterable, List, Optional, Tuple, TypeVar -from flask import current_app, request from marshmallow import ValidationError from sqlalchemy import and_, func, or_ from sqlalchemy.exc import MultipleResultsFound @@ -25,6 +24,7 @@ from sqlalchemy.sql import ClauseElement from airflow.api_connexion import security +from airflow.api_connexion.endpoints.request_dict import get_json_request_dict from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import format_datetime, format_parameters from airflow.api_connexion.schemas.task_instance_schema import ( @@ -42,6 +42,7 @@ from airflow.models.dagrun import DagRun as DR from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState, State @@ -188,7 +189,7 @@ def get_mapped_task_instances( # 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404 if base_query.with_entities(func.count('*')).scalar() == 0: - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: error_message = f"DAG {dag_id} not found" raise NotFound(error_message) @@ -364,7 +365,7 @@ def get_task_instances( @provide_session def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: """Get list of task instances.""" - body = request.get_json() + body = get_json_request_dict() try: data = task_instance_batch_form.load(body) except ValidationError as err: @@ -423,20 +424,20 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: @provide_session def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear task instances.""" - body = request.get_json() + body = get_json_request_dict() try: data = clear_task_instance_form.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: error_message = f"Dag id {dag_id} not found" raise NotFound(error_message) reset_dag_runs = data.pop('reset_dag_runs') dry_run = data.pop('dry_run') # We always pass dry_run here, otherwise this would try to confirm on the terminal! - task_instances = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **data) + task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, **data) if not dry_run: clear_task_instances( task_instances.all(), @@ -460,14 +461,14 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> @provide_session def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of task instances.""" - body = request.get_json() + body = get_json_request_dict() try: data = set_task_instance_state_form.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) error_message = f"Dag ID {dag_id} not found" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound(error_message) diff --git a/airflow/api_connexion/endpoints/user_endpoint.py b/airflow/api_connexion/endpoints/user_endpoint.py index 82375cebcaf16..2ed0db2aae864 100644 --- a/airflow/api_connexion/endpoints/user_endpoint.py +++ b/airflow/api_connexion/endpoints/user_endpoint.py @@ -17,7 +17,7 @@ from typing import List, Optional from connexion import NoContent -from flask import current_app, request +from flask import request from marshmallow import ValidationError from sqlalchemy import asc, desc, func from werkzeug.security import generate_password_hash @@ -33,13 +33,14 @@ ) from airflow.api_connexion.types import APIResponse, UpdateMask from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.www.fab_security.sqla.models import Role, User @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_USER)]) def get_user(*, username: str) -> APIResponse: """Get a user""" - ab_security_manager = current_app.appbuilder.sm + ab_security_manager = get_airflow_app().appbuilder.sm user = ab_security_manager.find_user(username=username) if not user: raise NotFound(title="User not found", detail=f"The User with username `{username}` was not found") @@ -50,7 +51,7 @@ def get_user(*, username: str) -> APIResponse: @format_parameters({"limit": check_limit}) def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None) -> APIResponse: """Get users""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder session = appbuilder.get_session total_entries = session.query(func.count(User.id)).scalar() direction = desc if order_by.startswith("-") else asc @@ -86,7 +87,7 @@ def post_user() -> APIResponse: except ValidationError as e: raise BadRequest(detail=str(e.messages)) - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm username = data["username"] email = data["email"] @@ -129,7 +130,7 @@ def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: except ValidationError as e: raise BadRequest(detail=str(e.messages)) - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm user = security_manager.find_user(username=username) if user is None: @@ -193,7 +194,7 @@ def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_USER)]) def delete_user(*, username: str) -> APIResponse: """Delete a user""" - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm user = security_manager.find_user(username=username) if user is None: diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index 067d163401454..4dfc0803c5c62 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -16,12 +16,13 @@ # under the License. from typing import Optional -from flask import Response, request +from flask import Response from marshmallow import ValidationError from sqlalchemy import func from sqlalchemy.orm import Session from airflow.api_connexion import security +from airflow.api_connexion.endpoints.request_dict import get_json_request_dict from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.variable_schema import variable_collection_schema, variable_schema @@ -78,7 +79,7 @@ def get_variables( def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Response: """Update a variable by key""" try: - data = variable_schema.load(request.json) + data = variable_schema.load(get_json_request_dict()) except ValidationError as err: raise BadRequest("Invalid Variable schema", detail=str(err.messages)) @@ -99,7 +100,7 @@ def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Resp def post_variables() -> Response: """Create a variable""" try: - data = variable_schema.load(request.json) + data = variable_schema.load(get_json_request_dict()) except ValidationError as err: raise BadRequest("Invalid Variable schema", detail=str(err.messages)) diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py index 9cc6b6d79a933..62c7262f7ed2c 100644 --- a/airflow/api_connexion/endpoints/xcom_endpoint.py +++ b/airflow/api_connexion/endpoints/xcom_endpoint.py @@ -16,7 +16,7 @@ # under the License. from typing import Optional -from flask import current_app, g +from flask import g from sqlalchemy import and_ from sqlalchemy.orm import Session @@ -27,6 +27,7 @@ from airflow.api_connexion.types import APIResponse from airflow.models import DagRun as DR, XCom from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -52,7 +53,7 @@ def get_xcom_entries( """Get all XCom values""" query = session.query(XCom) if dag_id == '~': - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) query = query.filter(XCom.dag_id.in_(readable_dag_ids)) query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id)) diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py index 2f369113290d9..6e7410dc4f2ef 100644 --- a/airflow/api_connexion/schemas/dag_schema.py +++ b/airflow/api_connexion/schemas/dag_schema.py @@ -83,7 +83,7 @@ def get_owners(obj: DagModel): @staticmethod def get_token(obj: DagModel): """Return file token""" - serializer = URLSafeSerializer(conf.get('webserver', 'secret_key')) + serializer = URLSafeSerializer(conf.get_mandatory_value('webserver', 'secret_key')) return serializer.dumps(obj.fileloc) diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py index 37005256f6cdc..74824dbaf87c6 100644 --- a/airflow/api_connexion/schemas/task_instance_schema.py +++ b/airflow/api_connexion/schemas/task_instance_schema.py @@ -60,7 +60,7 @@ class Meta: pid = auto_field() executor_config = auto_field() sla_miss = fields.Nested(SlaMissSchema, dump_default=None) - rendered_fields = JsonObjectField(default={}) + rendered_fields = JsonObjectField(dump_default={}) def get_attribute(self, obj, attr, default): if attr == "sla_miss": diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 3562c98eb4b35..6c84181f91bd3 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -18,16 +18,17 @@ from functools import wraps from typing import Callable, Optional, Sequence, Tuple, TypeVar, cast -from flask import Response, current_app +from flask import Response from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated +from airflow.utils.airflow_flask_app import get_airflow_app T = TypeVar("T", bound=Callable) def check_authentication() -> None: """Checks that the request has valid authorization information.""" - for auth in current_app.api_auth: + for auth in get_airflow_app().api_auth: response = auth.requires_authentication(Response)() if response.status_code == 200: return @@ -38,7 +39,7 @@ def check_authentication() -> None: def requires_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Callable[[T], T]: """Factory for decorator that checks current user's permissions against required permissions.""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder appbuilder.sm.sync_resource_permissions(permissions) def requires_access_decorator(func: T): diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 8d2e06442a2e5..4d50288673be0 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -302,7 +302,6 @@ def get_extra_links(self, ti: "TaskInstance", link_name: str) -> Optional[str]: return link.get_link(self, ti_key=ti.key) else: return link.get_link(self, ti.dag_run.logical_date) # type: ignore[misc] - return None def render_template_fields( self, diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index 0689f14c56261..4578fd2df818b 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -23,7 +23,10 @@ from airflow.api.common.trigger_dag import trigger_dag from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists -from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun +from airflow.models.baseoperator import BaseOperator, BaseOperatorLink +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagBag +from airflow.models.dagrun import DagRun from airflow.models.xcom import XCom from airflow.utils import timezone from airflow.utils.context import Context diff --git a/airflow/providers/google/common/auth_backend/google_openid.py b/airflow/providers/google/common/auth_backend/google_openid.py index 496ac29616686..a267c0e63a1ca 100644 --- a/airflow/providers/google/common/auth_backend/google_openid.py +++ b/airflow/providers/google/common/auth_backend/google_openid.py @@ -88,7 +88,7 @@ def _verify_id_token(id_token: str) -> Optional[str]: def _lookup_user(user_email: str): - security_manager = current_app.appbuilder.sm + security_manager = current_app.appbuilder.sm # type: ignore[attr-defined] user = security_manager.find_user(email=user_email) if not user: diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 40c0a7a5665b7..30c27c7214dc7 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -23,7 +23,11 @@ from sqlalchemy import func from airflow.exceptions import AirflowException -from airflow.models import BaseOperatorLink, DagBag, DagModel, DagRun, TaskInstance +from airflow.models.baseoperator import BaseOperatorLink +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagBag +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator from airflow.sensors.base import BaseSensorOperator from airflow.utils.helpers import build_airflow_url_with_query diff --git a/airflow/utils/airflow_flask_app.py b/airflow/utils/airflow_flask_app.py new file mode 100644 index 0000000000000..a14ff99398d21 --- /dev/null +++ b/airflow/utils/airflow_flask_app.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, List, cast + +from flask import Flask + +from airflow.models.dagbag import DagBag +from airflow.www.extensions.init_appbuilder import AirflowAppBuilder + + +class AirflowApp(Flask): + """Airflow Flask Application""" + + appbuilder: AirflowAppBuilder + dag_bag: DagBag + api_auth: List[Any] + + +def get_airflow_app() -> AirflowApp: + from flask import current_app + + return cast(AirflowApp, current_app) diff --git a/airflow/utils/jwt_signer.py b/airflow/utils/jwt_signer.py index 941a3d05981ce..e767997ebeb78 100644 --- a/airflow/utils/jwt_signer.py +++ b/airflow/utils/jwt_signer.py @@ -73,9 +73,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: algorithms=[self._algorithm], options={ "verify_signature": True, - "require_exp": True, - "require_iat": True, - "require_nbf": True, + "require": ["exp", "iat", "nbf"], }, audience=self._audience, ) diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py index 898988db81c50..75256f13736fd 100644 --- a/airflow/www/api/experimental/endpoints.py +++ b/airflow/www/api/experimental/endpoints.py @@ -70,7 +70,8 @@ def add_deprecation_headers(response: Response): return response -api_experimental.after_request(add_deprecation_headers) +# This API is deprecated. We do not care too much about typing here +api_experimental.after_request(add_deprecation_headers) # type: ignore[arg-type] @api_experimental.route('/dags//dag_runs', methods=['POST']) diff --git a/airflow/www/auth.py b/airflow/www/auth.py index 9d40c00a5cf10..9d36cda883c92 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -37,7 +37,10 @@ def decorated(*args, **kwargs): appbuilder = current_app.appbuilder dag_id = ( - request.args.get("dag_id") or request.form.get("dag_id") or (request.json or {}).get("dag_id") + request.args.get("dag_id") + or request.form.get("dag_id") + or (request.is_json and request.json.get("dag_id")) + or None ) if appbuilder.sm.check_authorization(permissions, dag_id): return func(*args, **kwargs) diff --git a/airflow/www/extensions/init_wsgi_middlewares.py b/airflow/www/extensions/init_wsgi_middlewares.py index 0ed78073e92f5..00c04006ff68e 100644 --- a/airflow/www/extensions/init_wsgi_middlewares.py +++ b/airflow/www/extensions/init_wsgi_middlewares.py @@ -37,7 +37,7 @@ def init_wsgi_middleware(flask_app: Flask): base_url = "" if base_url: flask_app.wsgi_app = DispatcherMiddleware( # type: ignore - _root_app, mounts={base_url: flask_app.wsgi_app} + _root_app, mounts={base_url: flask_app.wsgi_app} # type: ignore ) # Apply ProxyFix middleware diff --git a/airflow/www/fab_security/manager.py b/airflow/www/fab_security/manager.py index 8381f7b08cdc7..2010e58c348bc 100644 --- a/airflow/www/fab_security/manager.py +++ b/airflow/www/fab_security/manager.py @@ -291,7 +291,7 @@ def create_jwt_manager(self, app) -> JWTManager: """ jwt_manager = JWTManager() jwt_manager.init_app(app) - jwt_manager.user_loader_callback_loader(self.load_user_jwt) + jwt_manager.user_lookup_loader(self.load_user_jwt) return jwt_manager def create_builtin_roles(self): @@ -654,6 +654,18 @@ def get_oauth_user_info(self, provider, resp): "email": data.get("email", ""), "role_keys": data.get("groups", []), } + # for Keycloak + if provider in ["keycloak", "keycloak_before_17"]: + me = self.appbuilder.sm.oauth_remotes[provider].get("openid-connect/userinfo") + me.raise_for_status() + data = me.json() + log.debug("User info from Keycloak: %s", data) + return { + "username": data.get("preferred_username", ""), + "first_name": data.get("given_name", ""), + "last_name": data.get("family_name", ""), + "email": data.get("email", ""), + } else: return {} @@ -1027,12 +1039,6 @@ def auth_user_ldap(self, username, password): try: # LDAP certificate settings - if self.auth_ldap_allow_self_signed: - ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) - ldap.set_option(ldap.OPT_X_TLS_NEWCTX, 0) - elif self.auth_ldap_tls_demand: - ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) - ldap.set_option(ldap.OPT_X_TLS_NEWCTX, 0) if self.auth_ldap_tls_cacertdir: ldap.set_option(ldap.OPT_X_TLS_CACERTDIR, self.auth_ldap_tls_cacertdir) if self.auth_ldap_tls_cacertfile: @@ -1041,6 +1047,12 @@ def auth_user_ldap(self, username, password): ldap.set_option(ldap.OPT_X_TLS_CERTFILE, self.auth_ldap_tls_certfile) if self.auth_ldap_tls_keyfile: ldap.set_option(ldap.OPT_X_TLS_KEYFILE, self.auth_ldap_tls_keyfile) + if self.auth_ldap_allow_self_signed: + ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) + ldap.set_option(ldap.OPT_X_TLS_NEWCTX, 0) + elif self.auth_ldap_tls_demand: + ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) + ldap.set_option(ldap.OPT_X_TLS_NEWCTX, 0) # Initialise LDAP connection con = ldap.initialize(self.auth_ldap_server) @@ -1354,7 +1366,10 @@ def get_user_menu_access(self, menu_names: Optional[List[str]] = None) -> Set[st return self._get_user_permission_resources(g.user, "menu_access", resource_names=menu_names) elif current_user_jwt: return self._get_user_permission_resources( - current_user_jwt, "menu_access", resource_names=menu_names + # the current_user_jwt is a lazy proxy, so we need to ignore type checking + current_user_jwt, # type: ignore[arg-type] + "menu_access", + resource_names=menu_names, ) else: return self._get_user_permission_resources(None, "menu_access", resource_names=menu_names) @@ -1660,9 +1675,9 @@ def load_user(self, user_id): """Load user by ID""" return self.get_user_by_id(int(user_id)) - def load_user_jwt(self, user_id): - """Load user JWT""" - user = self.load_user(user_id) + def load_user_jwt(self, _jwt_header, jwt_data): + identity = jwt_data["sub"] + user = self.load_user(identity) # Set flask g.user to JWT user, we can't do it on before request g.user = user return user diff --git a/airflow/www/views.py b/airflow/www/views.py index e9a52611fcd65..fbebef3ecf5d6 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -43,7 +43,6 @@ Response, abort, before_render_template, - current_app, flash, g, jsonify, @@ -118,6 +117,7 @@ from airflow.ti_deps.dependencies_deps import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS from airflow.timetables.base import DataInterval, TimeRestriction from airflow.utils import json as utils_json, timezone, yaml +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.dates import infer_time_unit, scale_time_units from airflow.utils.docs import get_doc_url_for_provider, get_docs_url from airflow.utils.helpers import alchemy_to_dict @@ -622,13 +622,13 @@ def add_user_permissions_to_dag(sender, template, context, **extra): """ if 'dag' in context: dag = context['dag'] - can_create_dag_run = current_app.appbuilder.sm.has_access( + can_create_dag_run = get_airflow_app().appbuilder.sm.has_access( permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN ) - dag.can_edit = current_app.appbuilder.sm.can_edit_dag(dag.dag_id) + dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id) dag.can_trigger = dag.can_edit and can_create_dag_run - dag.can_delete = current_app.appbuilder.sm.can_delete_dag(dag.dag_id) + dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id) context['dag'] = dag @@ -715,7 +715,7 @@ def index(self): end = start + dags_per_page # Get all the dag id the user could access - filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) with create_session() as session: # read orm_dags from the db @@ -824,7 +824,7 @@ def index(self): ) dashboard_alerts = [ - fm for fm in settings.DASHBOARD_UIALERTS if fm.should_show(current_app.appbuilder.sm) + fm for fm in settings.DASHBOARD_UIALERTS if fm.should_show(get_airflow_app().appbuilder.sm) ] def _iter_parsed_moved_data_table_names(): @@ -904,7 +904,7 @@ def dag_stats(self, session=None): """Dag statistics.""" dr = models.DagRun - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) dag_state_stats = session.query(dr.dag_id, dr.state, sqla.func.count(dr.state)).group_by( dr.dag_id, dr.state @@ -949,7 +949,7 @@ def dag_stats(self, session=None): @provide_session def task_stats(self, session=None): """Task Statistics""" - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) if not allowed_dag_ids: return wwwutils.json_response({}) @@ -1058,7 +1058,7 @@ def task_stats(self, session=None): @provide_session def last_dagruns(self, session=None): """Last DAG runs""" - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id} @@ -1182,7 +1182,7 @@ def legacy_dag_details(self): @provide_session def dag_details(self, dag_id, session=None): """Get Dag details.""" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id, session=session) title = "DAG Details" @@ -1258,7 +1258,7 @@ def rendered_templates(self, session): root = request.args.get('root', '') logging.info("Retrieving rendered templates.") - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) dag_run = dag.get_dagrun(execution_date=dttm, session=session) raw_task = dag.get_task(task_id).prepare_for_execution() @@ -1353,15 +1353,17 @@ def rendered_k8s(self, session: Session = NEW_SESSION): abort(404) dag_id = request.args.get('dag_id') task_id = request.args.get('task_id') + if task_id is None: + logging.warning("Task id not passed in the request") + abort(400) execution_date = request.args.get('execution_date') dttm = _safe_parse_datetime(execution_date) - form = DateTimeForm(data={'execution_date': dttm}) root = request.args.get('root', '') map_index = request.args.get('map_index', -1, type=int) logging.info("Retrieving rendered templates.") - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) task = dag.get_task(task_id) dag_run = dag.get_dagrun(execution_date=dttm, session=session) ti = dag_run.get_task_instance(task_id=task.task_id, map_index=map_index, session=session) @@ -1466,7 +1468,7 @@ def get_logs_with_metadata(self, session=None): ) try: - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if dag: ti.task = dag.get_task(ti.task_id) @@ -1597,7 +1599,7 @@ def task(self, session): map_index = request.args.get('map_index', -1, type=int) form = DateTimeForm(data={'execution_date': dttm}) root = request.args.get('root', '') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag or task_id not in dag.task_ids: flash(f"Task [{dag_id}.{task_id}] doesn't seem to exist at the moment", "error") @@ -1776,7 +1778,7 @@ def run(self, session=None): dag_run_id = request.form.get('dag_run_id') map_index = request.args.get('map_index', -1, type=int) origin = get_safe_url(request.form.get('origin')) - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) task = dag.get_task(task_id) ignore_all_deps = request.form.get('ignore_all_deps') == "true" @@ -1877,7 +1879,7 @@ def trigger(self, session=None): request_conf = request.values.get('conf') request_execution_date = request.values.get('execution_date', default=timezone.utcnow().isoformat()) is_dag_run_conf_overrides_params = conf.getboolean('core', 'dag_run_conf_overrides_params') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_orm = session.query(models.DagModel).filter(models.DagModel.dag_id == dag_id).first() if not dag_orm: flash(f"Cannot find dag {dag_id}") @@ -1978,7 +1980,7 @@ def trigger(self, session=None): state=State.QUEUED, conf=run_conf, external_trigger=True, - dag_hash=current_app.dag_bag.dags_hash.get(dag_id), + dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id), run_id=run_id, ) except (ValueError, ParamValidationError) as ve: @@ -2060,7 +2062,7 @@ def clear(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') origin = get_safe_url(request.form.get('origin')) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if 'map_index' not in request.form: map_indexes: Optional[List[int]] = None @@ -2121,7 +2123,7 @@ def dagrun_clear(self): dag_run_id = request.form.get('dag_run_id') confirmed = request.form.get('confirmed') == "true" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dr = dag.get_dagrun(run_id=dag_run_id) start_date = dr.logical_date end_date = dr.logical_date @@ -2145,7 +2147,7 @@ def dagrun_clear(self): @provide_session def blocked(self, session=None): """Mark Dag Blocked.""" - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id} @@ -2168,7 +2170,7 @@ def blocked(self, session=None): payload = [] for dag_id, active_dag_runs in dags: max_active_runs = 0 - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if dag: # TODO: Make max_active_runs a column so we can query for it directly max_active_runs = dag.max_active_runs @@ -2185,7 +2187,7 @@ def _mark_dagrun_state_as_failed(self, dag_id, dag_run_id, confirmed): if not dag_run_id: return {'status': 'error', 'message': 'Invalid dag_run_id'} - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: return {'status': 'error', 'message': f'Cannot find DAG: {dag_id}'} @@ -2203,7 +2205,7 @@ def _mark_dagrun_state_as_success(self, dag_id, dag_run_id, confirmed): if not dag_run_id: return {'status': 'error', 'message': 'Invalid dag_run_id'} - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: return {'status': 'error', 'message': f'Cannot find DAG: {dag_id}'} @@ -2221,7 +2223,7 @@ def _mark_dagrun_state_as_queued(self, dag_id: str, dag_run_id: str, confirmed: if not dag_run_id: return {'status': 'error', 'message': 'Invalid dag_run_id'} - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: return {'status': 'error', 'message': f'Cannot find DAG: {dag_id}'} @@ -2295,7 +2297,7 @@ def dagrun_details(self, session=None): dag_id = request.args.get("dag_id") run_id = request.args.get("run_id") - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_run: Optional[DagRun] = ( session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one_or_none() ) @@ -2346,7 +2348,7 @@ def _mark_task_instance_state( past: bool, state: TaskInstanceState, ): - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not run_id: flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error") @@ -2394,7 +2396,7 @@ def confirm(self): past = to_boolean(args.get('past')) origin = origin or url_for('Airflow.index') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: msg = f'DAG {dag_id} not found' return redirect_or_json(origin, msg, status='error', status_code=404) @@ -2583,7 +2585,7 @@ def tree(self): @provide_session def grid(self, dag_id, session=None): """Get Dag's grid view.""" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) if not dag: flash(f'DAG "{dag_id}" seems to be missing from DagBag.', "error") @@ -2703,7 +2705,7 @@ def _convert_to_date(session, column): else: return func.date(column) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) if not dag: flash(f'DAG "{dag_id}" seems to be missing from DagBag.', "error") @@ -2817,7 +2819,7 @@ def legacy_graph(self): @provide_session def graph(self, dag_id, session=None): """Get DAG as Graph.""" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) if not dag: flash(f'DAG "{dag_id}" seems to be missing.', "error") @@ -2929,7 +2931,7 @@ def duration(self, dag_id, session=None): default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') dag_model = DagModel.get_dagmodel(dag_id) - dag: Optional[DAG] = current_app.dag_bag.get_dag(dag_id) + dag: Optional[DAG] = get_airflow_app().dag_bag.get_dag(dag_id) if dag is None: flash(f'DAG "{dag_id}" seems to be missing.', "error") return redirect(url_for('Airflow.index')) @@ -3081,7 +3083,7 @@ def legacy_tries(self): def tries(self, dag_id, session=None): """Shows all tries.""" default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) base_date = request.args.get('base_date') num_runs = request.args.get('num_runs', default=default_dag_run, type=int) @@ -3171,7 +3173,7 @@ def legacy_landing_times(self): def landing_times(self, dag_id, session=None): """Shows landing times.""" default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) base_date = request.args.get('base_date') num_runs = request.args.get('num_runs', default=default_dag_run, type=int) @@ -3288,7 +3290,7 @@ def legacy_gantt(self): @provide_session def gantt(self, dag_id, session=None): """Show GANTT chart.""" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) root = request.args.get('root') @@ -3414,9 +3416,8 @@ def extra_links(self, session: "Session" = NEW_SESSION): task_id = request.args.get('task_id') map_index = request.args.get('map_index', -1, type=int) execution_date = request.args.get('execution_date') - link_name = request.args.get('link_name') dttm = _safe_parse_datetime(execution_date) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag or task_id not in dag.task_ids: response = jsonify( @@ -3429,6 +3430,11 @@ def extra_links(self, session: "Session" = NEW_SESSION): return response task: "AbstractOperator" = dag.get_task(task_id) + link_name = request.args.get('link_name') + if link_name is None: + response = jsonify({'url': None, 'error': 'Link name not passed'}) + response.status_code = 400 + return response ti = ( session.query(TaskInstance) @@ -3466,7 +3472,7 @@ def extra_links(self, session: "Session" = NEW_SESSION): def task_instances(self): """Shows task instances.""" dag_id = request.args.get('dag_id') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dttm = request.args.get('execution_date') if dttm: @@ -3494,7 +3500,7 @@ def task_instances(self): def grid_data(self): """Returns grid data""" dag_id = request.args.get('dag_id') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: response = jsonify({'error': f"can't find dag {dag_id}"}) @@ -3546,7 +3552,7 @@ def robots(self): of the risk associated with exposing Airflow to the public internet, however it does not address the real security risks associated with such a deployment. """ - return send_from_directory(current_app.static_folder, 'robots.txt') + return send_from_directory(get_airflow_app().static_folder, 'robots.txt') @expose('/audit_log') @auth.has_access( @@ -3558,7 +3564,7 @@ def robots(self): @provide_session def audit_log(self, session=None): dag_id = request.args.get('dag_id') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) included_events = conf.get('webserver', 'audit_view_included_events', fallback=None) excluded_events = conf.get('webserver', 'audit_view_excluded_events', fallback=None) @@ -3663,9 +3669,9 @@ class DagFilter(BaseFilter): """Filter using DagIDs""" def apply(self, query, func): - if current_app.appbuilder.sm.has_all_dags_access(g.user): + if get_airflow_app().appbuilder.sm.has_all_dags_access(g.user): return query - filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) return query.filter(self.model.dag_id.in_(filter_dag_ids)) @@ -3688,7 +3694,7 @@ class AirflowPrivilegeVerifierModelView(AirflowModelView): @staticmethod def validate_dag_edit_access(item: Union[DagRun, TaskInstance]): """Validates whether the user has 'can_edit' access for this specific DAG.""" - if not current_app.appbuilder.sm.can_edit_dag(item.dag_id): + if not get_airflow_app().appbuilder.sm.can_edit_dag(item.dag_id): raise AirflowException(f"Access denied for dag_id {item.dag_id}") def pre_add(self, item: Union[DagRun, TaskInstance]): @@ -3719,7 +3725,7 @@ def check_dag_edit_acl_for_actions( items: Optional[Union[List[TaskInstance], List[DagRun], TaskInstance, DagRun]], *args, **kwargs, - ) -> None: + ) -> Callable: if items is None: dag_ids: Set[str] = set() elif isinstance(items, list): @@ -3734,7 +3740,7 @@ def check_dag_edit_acl_for_actions( ) for dag_id in dag_ids: - if not current_app.appbuilder.sm.can_edit_dag(dag_id): + if not get_airflow_app().appbuilder.sm.can_edit_dag(dag_id): flash(f"Access denied for dag_id {dag_id}", "danger") logging.warning("User %s tried to modify %s without having access.", g.user.username, dag_id) return redirect(self.get_default_url()) @@ -4337,7 +4343,9 @@ def fqueued_slots(self): def _can_create_variable() -> bool: - return current_app.appbuilder.sm.has_access(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE) + return get_airflow_app().appbuilder.sm.has_access( + permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE + ) class VariableModelView(AirflowModelView): @@ -4681,7 +4689,10 @@ def action_set_failed(self, drs: List[DagRun], session=None): for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all(): count += 1 altered_tis += set_dag_run_state_to_failed( - dag=current_app.dag_bag.get_dag(dr.dag_id), run_id=dr.run_id, commit=True, session=session + dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), + run_id=dr.run_id, + commit=True, + session=session, ) altered_ti_count = len(altered_tis) flash(f"{count} dag runs and {altered_ti_count} task instances were set to failed") @@ -4706,7 +4717,10 @@ def action_set_success(self, drs: List[DagRun], session=None): for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all(): count += 1 altered_tis += set_dag_run_state_to_success( - dag=current_app.dag_bag.get_dag(dr.dag_id), run_id=dr.run_id, commit=True, session=session + dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), + run_id=dr.run_id, + commit=True, + session=session, ) altered_ti_count = len(altered_tis) flash(f"{count} dag runs and {altered_ti_count} task instances were set to success") @@ -4726,7 +4740,7 @@ def action_clear(self, drs: List[DagRun], session=None): dag_to_tis: Dict[DAG, List[TaskInstance]] = {} for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all(): count += 1 - dag = current_app.dag_bag.get_dag(dr.dag_id) + dag = get_airflow_app().dag_bag.get_dag(dr.dag_id) tis_to_clear = dag_to_tis.setdefault(dag, []) tis_to_clear += dr.get_task_instances() @@ -5019,7 +5033,7 @@ def action_clear(self, task_instances, session=None): dag_to_tis = collections.defaultdict(list) for ti in task_instances: - dag = current_app.dag_bag.get_dag(ti.dag_id) + dag = get_airflow_app().dag_bag.get_dag(ti.dag_id) dag_to_tis[dag].append(ti) for dag, task_instances_list in dag_to_tis.items(): @@ -5135,7 +5149,7 @@ def autocomplete(self, session=None): dag_ids_query = dag_ids_query.filter(DagModel.is_paused) owners_query = owners_query.filter(DagModel.is_paused) - filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids)) owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids)) diff --git a/dev/breeze/README.md b/dev/breeze/README.md index 7794f25e4e8eb..14a9f089a4834 100644 --- a/dev/breeze/README.md +++ b/dev/breeze/README.md @@ -52,6 +52,6 @@ PLEASE DO NOT MODIFY THE HASH BELOW! IT IS AUTOMATICALLY UPDATED BY PRE-COMMIT. --------------------------------------------------------------------------------------------------------- -Package config hash: 40b9b6908905e94c93809cca70c68c632731242798dba9cbe62473e965cb4e5d44eaaa817c5ce9334397f3794a350bc00e3cf319631a25c461a935a389191e7b +Package config hash: a80a853b2c32c284a68ccd6d468804b892a69f14d2ad1886bdaa892755cf6262660e2b9fc582bcae27ae478910055267a76edea2df658196198a0365150e93e5 --------------------------------------------------------------------------------------------------------- diff --git a/dev/breeze/setup.cfg b/dev/breeze/setup.cfg index 9c7154ce52f2a..c974560561053 100644 --- a/dev/breeze/setup.cfg +++ b/dev/breeze/setup.cfg @@ -64,7 +64,7 @@ install_requires = pyyaml requests rich>=12.4.4 - rich_click + rich-click>=1.5 [options.packages.find] where=src diff --git a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py index a5333a08ab0b6..caf43d47a409b 100644 --- a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py @@ -84,6 +84,7 @@ "--airflow-extras", "--use-packages-from-dist", "--package-format", + "--skip-constraints", "--debug", ], } @@ -511,6 +512,12 @@ def generate_constraints( @option_use_airflow_version @option_airflow_extras @option_airflow_constraints_reference +@click.option( + "--skip-constraints", + is_flag=True, + help="Do not use constraints when installing providers.", + envvar='SKIP_CONSTRAINTS', +) @option_use_packages_from_dist @option_installation_package_format @option_verbose @@ -522,6 +529,7 @@ def verify_provider_packages( dry_run: bool, use_airflow_version: Optional[str], airflow_constraints_reference: str, + skip_constraints: bool, airflow_extras: str, use_packages_from_dist: bool, debug: bool, @@ -538,6 +546,7 @@ def verify_provider_packages( airflow_extras=airflow_extras, airflow_constraints_reference=airflow_constraints_reference, use_packages_from_dist=use_packages_from_dist, + skip_constraints=skip_constraints, package_format=package_format, ) rebuild_or_pull_ci_image_if_needed(command_params=shell_params, dry_run=dry_run, verbose=verbose) diff --git a/dev/breeze/src/airflow_breeze/params/shell_params.py b/dev/breeze/src/airflow_breeze/params/shell_params.py index bb4ad3393ead3..b67d362186653 100644 --- a/dev/breeze/src/airflow_breeze/params/shell_params.py +++ b/dev/breeze/src/airflow_breeze/params/shell_params.py @@ -81,6 +81,7 @@ class ShellParams: postgres_version: str = ALLOWED_POSTGRES_VERSIONS[0] python: str = ALLOWED_PYTHON_MAJOR_MINOR_VERSIONS[0] skip_environment_initialization: bool = False + skip_constraints: bool = False start_airflow: str = "false" use_airflow_version: Optional[str] = None use_packages_from_dist: bool = False diff --git a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py index 96b83b5ed7945..8af6010c40403 100644 --- a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py +++ b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py @@ -571,6 +571,7 @@ def update_expected_environment_variables(env: Dict[str, str]) -> None: "POSTGRES_VERSION": "postgres_version", "SQLITE_URL": "sqlite_url", "START_AIRFLOW": "start_airflow", + "SKIP_CONSTRAINTS": "skip_constraints", "SKIP_ENVIRONMENT_INITIALIZATION": "skip_environment_initialization", "USE_AIRFLOW_VERSION": "use_airflow_version", "USE_PACKAGES_FROM_DIST": "use_packages_from_dist", diff --git a/dev/breeze/src/airflow_breeze/utils/recording.py b/dev/breeze/src/airflow_breeze/utils/recording.py index 0ec34edccac94..2fe9f5b5558de 100644 --- a/dev/breeze/src/airflow_breeze/utils/recording.py +++ b/dev/breeze/src/airflow_breeze/utils/recording.py @@ -53,10 +53,10 @@ def __init__(self, **kwargs): atexit.register(save_ouput_as_svg) click.rich_click.MAX_WIDTH = width_int - click.formatting.FORCED_WIDTH = width_int - 2 + click.formatting.FORCED_WIDTH = width_int - 2 # type: ignore[attr-defined] click.rich_click.COLOR_SYSTEM = "standard" # monkeypatch rich_click console to record help (rich_click does not allow passing extra args to console) - click.rich_click.Console = RecordingConsole + click.rich_click.Console = RecordingConsole # type: ignore[misc] if output_file_for_recording and not in_autocomplete(): diff --git a/dev/send_email.py b/dev/send_email.py index 91a35b97cc425..2d796eb80066b 100755 --- a/dev/send_email.py +++ b/dev/send_email.py @@ -83,8 +83,7 @@ def show_message(entity: str, message: str): """ Show message on the Command Line """ - width, _ = click.get_terminal_size() - + width, _ = click.get_terminal_size() # type: ignore[attr-defined] click.secho("-" * width, fg="blue") click.secho(f"{entity} Message:", fg="bright_red", bold=True) click.secho("-" * width, fg="blue") diff --git a/images/breeze/output-commands-hash.txt b/images/breeze/output-commands-hash.txt deleted file mode 100644 index 0ea56c33a726b..0000000000000 --- a/images/breeze/output-commands-hash.txt +++ /dev/null @@ -1,5 +0,0 @@ - -# This file is automatically generated by pre-commit. If you have a conflict with this file -# Please do not solve it but run `breeze regenerate-command-images`. -# This command should fix the conflict and regenerate help images that you have conflict with. -9139ef44b7f1ba24ddee50b71d3867c2 diff --git a/images/breeze/output-verify-provider-packages.svg b/images/breeze/output-verify-provider-packages.svg index 12853b46a203f..fbfd468ffafd1 100644 --- a/images/breeze/output-verify-provider-packages.svg +++ b/images/breeze/output-verify-provider-packages.svg @@ -1,4 +1,4 @@ - + - - + + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + - Command: verify-provider-packages + Command: verify-provider-packages - + - - -Usage: breeze verify-provider-packages [OPTIONS] - -Verifies if all provider code is following expectations for providers. - -╭─ Provider verification flags ────────────────────────────────────────────────────────────────────────────────────────╮ ---use-airflow-versionUse (reinstall at entry) Airflow version from PyPI. It can also be `none`,        -`wheel`, or `sdist` if Airflow should be removed, installed from wheel packages   -or sdist packages available in dist folder respectively. Implies --mount-sources -`remove`.                                                                         -(none | wheel | sdist | <airflow_version>)                                        ---airflow-constraints-referenceConstraint reference to use. Useful with --use-airflow-version parameter to       -specify constraints for the installed version and to find newer dependencies      -(TEXT)                                                                            ---airflow-extrasAirflow extras to install when --use-airflow-version is used(TEXT) ---use-packages-from-distInstall all found packages (--package-format determines type) from 'dist' folder  -when entering breeze.                                                             ---package-formatFormat of packages that should be installed from dist.(wheel | sdist) -[default: wheel]                                       ---debugDrop user in shell instead of running the command. Useful for debugging. -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---verbose-vPrint verbose information about performed steps. ---dry-run-DIf dry-run is set, commands are only printed, not executed. ---github-repository-gGitHub repository used to pull, push run images.(TEXT)[default: apache/airflow] ---help-hShow this message and exit. -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ + + +Usage: breeze verify-provider-packages [OPTIONS] + +Verifies if all provider code is following expectations for providers. + +╭─ Provider verification flags ────────────────────────────────────────────────────────────────────────────────────────╮ +--use-airflow-versionUse (reinstall at entry) Airflow version from PyPI. It can also be `none`,        +`wheel`, or `sdist` if Airflow should be removed, installed from wheel packages   +or sdist packages available in dist folder respectively. Implies --mount-sources +`remove`.                                                                         +(none | wheel | sdist | <airflow_version>)                                        +--airflow-constraints-referenceConstraint reference to use. Useful with --use-airflow-version parameter to       +specify constraints for the installed version and to find newer dependencies      +(TEXT)                                                                            +--airflow-extrasAirflow extras to install when --use-airflow-version is used(TEXT) +--use-packages-from-distInstall all found packages (--package-format determines type) from 'dist' folder  +when entering breeze.                                                             +--package-formatFormat of packages that should be installed from dist.(wheel | sdist) +[default: wheel]                                       +--skip-constraintsDo not use constraints when installing providers. +--debugDrop user in shell instead of running the command. Useful for debugging. +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +--verbose-vPrint verbose information about performed steps. +--dry-run-DIf dry-run is set, commands are only printed, not executed. +--github-repository-gGitHub repository used to pull, push run images.(TEXT)[default: apache/airflow] +--help-hShow this message and exit. +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ diff --git a/newsfragments/24399.significant.rst b/newsfragments/24399.significant.rst new file mode 100644 index 0000000000000..7f1833a8e8c7e --- /dev/null +++ b/newsfragments/24399.significant.rst @@ -0,0 +1,31 @@ +We've upgraded Flask Application Builder to a major version 4.*. + +Flask Application Builder is one of the important components of Airflow Webserver, as +it uses a lof of dependencies that are essential to run the webserver and integrate it +in enterprise environments - especially authentication. + +The FAB 4.* upgrades a number of dependencies to major releases, which upgrades them to versions +that have a number of security issues fixed. A lot of tests were performed to bring the dependencies +in a backwards-compatible way, however the dependencies themselves implement breaking changes in their +internals so it might be that some of those changes might impact the users in case they are using the +libraries for their onw purposes. + +One important change that you likely will need to apply to Oauth configuration is to add +``server_metadata_url`` or ``jwks_uri`` and you can read about it more +in `this issue `_. + +Here is the list of breaking changes in dependencies that comes together with FAB 4: + +* Flask from 1.X to 2.X `breaking changes `_ + +* flask-jwt-extended 3.X to 4.X `breaking changes: `_ + +* Jinja2 2.X to 3.X `breaking changes: `_ + +* Werkzeug 1.X to 2.X `breaking changes `_ + +* pyJWT 1.X to 2.X `breaking changes: `_ + +* Click 7.X to 8.X `breaking changes: `_ + +* itsdangerous 1.X to 2.X `breaking changes `_ diff --git a/scripts/ci/docker-compose/_docker.env b/scripts/ci/docker-compose/_docker.env index 4edc849b57b93..b33cfea3602e5 100644 --- a/scripts/ci/docker-compose/_docker.env +++ b/scripts/ci/docker-compose/_docker.env @@ -59,6 +59,7 @@ RUN_TESTS LIST_OF_INTEGRATION_TESTS_TO_RUN RUN_SYSTEM_TESTS START_AIRFLOW +SKIP_CONSTRAINTS SKIP_ENVIRONMENT_INITIALIZATION SKIP_SSH_SETUP TEST_TYPE diff --git a/scripts/ci/docker-compose/base.yml b/scripts/ci/docker-compose/base.yml index 48e4d3df9606e..c1285eda88193 100644 --- a/scripts/ci/docker-compose/base.yml +++ b/scripts/ci/docker-compose/base.yml @@ -72,6 +72,7 @@ services: - LIST_OF_INTEGRATION_TESTS_TO_RUN=${LIST_OF_INTEGRATION_TESTS_TO_RUN} - RUN_SYSTEM_TESTS=${RUN_SYSTEM_TESTS} - START_AIRFLOW=${START_AIRFLOW} + - SKIP_CONSTRAINTS=${SKIP_CONSTRAINTS} - SKIP_ENVIRONMENT_INITIALIZATION=${SKIP_ENVIRONMENT_INITIALIZATION} - SKIP_SSH_SETUP=${SKIP_SSH_SETUP} - TEST_TYPE=${TEST_TYPE} diff --git a/scripts/ci/docker-compose/devcontainer.env b/scripts/ci/docker-compose/devcontainer.env index 1c4b27b36af67..ae51b204436ac 100644 --- a/scripts/ci/docker-compose/devcontainer.env +++ b/scripts/ci/docker-compose/devcontainer.env @@ -57,6 +57,7 @@ RUN_TESTS="false" LIST_OF_INTEGRATION_TESTS_TO_RUN="" RUN_SYSTEM_TESTS="" START_AIRFLOW="false" +SKIP_CONSTRAINTS="false" SKIP_SSH_SETUP="true" SKIP_ENVIRONMENT_INITIALIZATION="false" TEST_TYPE= diff --git a/scripts/ci/pre_commit/pre_commit_check_2_1_compatibility.py b/scripts/ci/pre_commit/pre_commit_check_2_1_compatibility.py index 0d43959ba5332..c28d29d76dcdb 100755 --- a/scripts/ci/pre_commit/pre_commit_check_2_1_compatibility.py +++ b/scripts/ci/pre_commit/pre_commit_check_2_1_compatibility.py @@ -36,6 +36,8 @@ GET_ATTR_MATCHER = re.compile(r".*getattr\((ti|TI), ['\"]run_id['\"]\).*") TI_RUN_ID_MATCHER = re.compile(r".*(ti|TI)\.run_id.*") TRY_NUM_MATCHER = re.compile(r".*context.*\[[\"']try_number[\"']].*") +GET_MANDATORY_MATCHER = re.compile(r".*conf\.get_mandatory_value") +GET_AIRFLOW_APP_MATCHER = re.compile(r".*get_airflow_app\(\)") def _check_file(_file: Path): @@ -57,13 +59,13 @@ def _check_file(_file: Path): if "if ti_key is not None:" not in lines[index - 1]: errors.append( f"[red]In {_file}:{index} there is a forbidden construct " - f"(Airflow 2.3.0 only):[/]\n\n" + "(Airflow 2.3.0 only):[/]\n\n" f"{lines[index-1]}\n{lines[index]}\n\n" - f"[yellow]When you use XCom.get_value( in providers, it should be in the form:[/]\n\n" - f"if ti_key is not None:\n" - f" value = XCom.get_value(...., ti_key=ti_key)\n\n" - f"See: https://airflow.apache.org/docs/apache-airflow-providers/" - f"howto/create-update-providers.html#using-providers-with-dynamic-task-mapping\n" + "[yellow]When you use XCom.get_value( in providers, it should be in the form:[/]\n\n" + "if ti_key is not None:\n" + " value = XCom.get_value(...., ti_key=ti_key)\n\n" + "See: https://airflow.apache.org/docs/apache-airflow-providers/" + "howto/create-update-providers.html#using-providers-with-dynamic-task-mapping\n" ) if "timezone.coerce_datetime" in line: errors.append( @@ -76,19 +78,37 @@ def _check_file(_file: Path): if "ti.map_index" in line: errors.append( f"[red]In {_file}:{index} there is a forbidden construct " - f"(Airflow 2.3+ only):[/]\n\n" + "(Airflow 2.3+ only):[/]\n\n" f"{lines[index]}\n\n" - f"[yellow]You should not use map_index field in providers " - f"as it is not available in Airflow 2.2[/]" + "[yellow]You should not use map_index field in providers " + "as it is only available in Airflow 2.3+[/]" ) if TRY_NUM_MATCHER.match(line): errors.append( f"[red]In {_file}:{index} there is a forbidden construct " - f"(Airflow 2.3+ only):[/]\n\n" + "(Airflow 2.3+ only):[/]\n\n" f"{lines[index]}\n\n" - f"[yellow]You should not expect try_number field for context in providers " - f"as it is not available in Airflow 2.2[/]" + "[yellow]You should not expect try_number field for context in providers " + "as it is only available in Airflow 2.3+[/]" + ) + + if GET_MANDATORY_MATCHER.match(line): + errors.append( + f"[red]In {_file}:{index} there is a forbidden construct " + "(Airflow 2.3+ only):[/]\n\n" + f"{lines[index]}\n\n" + "[yellow]You should not use conf.get_mandatory_value in providers " + "as it is only available in Airflow 2.3+[/]" + ) + + if GET_AIRFLOW_APP_MATCHER.match(line): + errors.append( + f"[red]In {_file}:{index} there is a forbidden construct " + "(Airflow 2.4+ only):[/]\n\n" + f"{lines[index]}\n\n" + "[yellow]You should not use airflow.utils.airflow_flask_app.get_airflow_app() in providers " + "as it is not available in Airflow 2.4+. Use current_app instead.[/]" ) diff --git a/scripts/docker/entrypoint_ci.sh b/scripts/docker/entrypoint_ci.sh index f5198a556c21c..2994604fc9274 100755 --- a/scripts/docker/entrypoint_ci.sh +++ b/scripts/docker/entrypoint_ci.sh @@ -94,9 +94,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo "${COLOR_BLUE}Uninstalling airflow and providers" echo uninstall_airflow_and_providers - echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers elif [[ ${USE_AIRFLOW_VERSION} == "sdist" ]]; then echo @@ -104,9 +110,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers else echo @@ -114,9 +126,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "none" + else + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi fi if [[ ${USE_PACKAGES_FROM_DIST=} == "true" ]]; then echo diff --git a/scripts/in_container/_in_container_utils.sh b/scripts/in_container/_in_container_utils.sh index d6a637e5c348b..66f2e6b083499 100644 --- a/scripts/in_container/_in_container_utils.sh +++ b/scripts/in_container/_in_container_utils.sh @@ -224,8 +224,12 @@ function install_airflow_from_wheel() { >&2 echo exit 4 fi - pip install "${airflow_package}${extras}" --constraint \ - "https://raw.githubusercontent.com/apache/airflow/${constraints_reference}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + if [[ ${constraints_reference} == "none" ]]; then + pip install "${airflow_package}${extras}" + else + pip install "${airflow_package}${extras}" --constraint \ + "https://raw.githubusercontent.com/apache/airflow/${constraints_reference}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + fi } function install_airflow_from_sdist() { @@ -250,8 +254,12 @@ function install_airflow_from_sdist() { >&2 echo exit 4 fi - pip install "${airflow_package}${extras}" --constraint \ - "https://raw.githubusercontent.com/apache/airflow/${constraints_reference}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + if [[ ${constraints_reference} == "none" ]]; then + pip install "${airflow_package}${extras}" + else + pip install "${airflow_package}${extras}" --constraint \ + "https://raw.githubusercontent.com/apache/airflow/${constraints_reference}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + fi } function uninstall_airflow() { @@ -278,17 +286,20 @@ function uninstall_airflow_and_providers() { function install_released_airflow_version() { local version="${1}" - echo - echo "Installing released ${version} version of airflow with extras: ${AIRFLOW_EXTRAS} and constraints constraints-${version}" - echo + local constraints_reference + constraints_reference="${2:-}" rm -rf "${AIRFLOW_SOURCES}"/*.egg-info if [[ ${AIRFLOW_EXTRAS} != "" ]]; then BRACKETED_AIRFLOW_EXTRAS="[${AIRFLOW_EXTRAS}]" else BRACKETED_AIRFLOW_EXTRAS="" fi - pip install "apache-airflow${BRACKETED_AIRFLOW_EXTRAS}==${version}" \ - --constraint "https://raw.githubusercontent.com/${CONSTRAINTS_GITHUB_REPOSITORY}/constraints-${version}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + if [[ ${constraints_reference} == "none" ]]; then + pip install "${airflow_package}${extras}" + else + pip install "apache-airflow${BRACKETED_AIRFLOW_EXTRAS}==${version}" \ + --constraint "https://raw.githubusercontent.com/${CONSTRAINTS_GITHUB_REPOSITORY}/constraints-${version}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + fi } function install_local_airflow_with_eager_upgrade() { diff --git a/setup.cfg b/setup.cfg index 1512a6201c46e..a32326cee29de 100644 --- a/setup.cfg +++ b/setup.cfg @@ -101,50 +101,31 @@ install_requires = cryptography>=0.9.3 deprecated>=1.2.13 dill>=0.2.2 - # Flask and all related libraries are limited to below 2.0.0 because we expect it to introduce - # Serious breaking changes. Flask 2.0 has been introduced in May 2021 and 2.0.2 version is available - # now (Feb 2022): TODO: we should attempt to migrate to Flask 2 and all below flask libraries soon. - flask>=1.1.0, <2.0 + flask>=2.0 # We are tightly coupled with FAB version because we vendored in part of FAB code related to security manager # This is done as part of preparation to removing FAB as dependency, but we are not ready for it yet # Every time we update FAB version here, please make sure that you review the classes and models in # `airflow/www/fab_security` with their upstream counterparts. In particular, make sure any breaking changes, # for example any new methods, are accounted for. - flask-appbuilder==3.4.5 - flask-caching>=1.5.0, <2.0.0 - flask-login>=0.3, <0.5 - # Strict upper-bound on the latest release of flask-session, - # as any schema changes will require a migration. - flask-session>=0.3.1, <=0.4.0 - flask-wtf>=0.14.3, <0.15 + flask-appbuilder==4.1.1 + flask-caching>=1.5.0 + flask-login>=0.5 + flask-session>=0.4.0 + flask-wtf>=0.14.3 graphviz>=0.12 gunicorn>=20.1.0 httpx importlib_metadata>=1.7;python_version<"3.9" importlib_resources>=5.2;python_version<"3.9" - # Logging is broken with itsdangerous > 2 - likely due to changed serializing support - # https://itsdangerous.palletsprojects.com/en/2.0.x/changes/#version-2-0-0 - # itsdangerous 2 has been released in May 2020 - # TODO: we should attempt to upgrade to line 2 of itsdangerous - itsdangerous>=1.1.0, <2.0 - # Jinja2 3.1 will remove the 'autoescape' and 'with' extensions, which would - # break Flask 1.x, so we limit this for future compatibility. Remove this - # when bumping Flask to >=2. - jinja2>=2.10.1,<3.1 - # Because connexion upper-bound is 5.0.0 and we depend on connexion, - # we pin to the same upper-bound as connexion. - jsonschema>=3.2.0, <5.0 + itsdangerous>=2.0 + jinja2>=2.10.1 + jsonschema>=3.2.0 lazy-object-proxy linkify-it-py>=2.0.0 lockfile>=0.12.2 markdown>=3.0 - # Markupsafe 2.1.0 breaks with error: import name 'soft_unicode' from 'markupsafe'. - # This should be removed when either this issue is closed: - # https://github.com/pallets/markupsafe/issues/284 - # or when we will be able to upgrade JINJA to newer version (currently limited due to Flask and - # Flask Application Builder) markdown-it-py>=2.1.0 - markupsafe>=1.1.1,<2.1.0 + markupsafe>=1.1.1 marshmallow-oneofschema>=2.0.1 mdit-py-plugins>=0.3.0 packaging>=14.0 @@ -153,8 +134,7 @@ install_requires = pluggy>=1.0 psutil>=4.2.0 pygments>=2.0.1 - # python daemon crashes with 'socket operation on non-socket' for python 3.8+ in version < 2.2.4 - # https://pagure.io/python-daemon/issue/34 + pyjwt>=2.0.0 python-daemon>=2.2.4 python-dateutil>=2.3 python-nvd3>=0.15.0 @@ -172,10 +152,7 @@ install_requires = termcolor>=1.1.0 typing-extensions>=3.7.4 unicodecsv>=0.14.1 - # Werkzeug is known to cause breaking changes and it is very closely tied with FlaskAppBuilder and other - # Flask dependencies and the limit to 1.* line should be reviewed when we upgrade Flask and remove - # FlaskAppBuilder. - werkzeug~=1.0, >=1.0.1 + werkzeug>=2.0 [options.packages.find] include = diff --git a/setup.py b/setup.py index f2cce10ed150e..16fcb73102544 100644 --- a/setup.py +++ b/setup.py @@ -617,16 +617,6 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'flake8-implicit-str-concat', 'flaky', 'freezegun', - # Github3 version 3.1.2 requires PyJWT>=2.3.0 which clashes with Flask App Builder where PyJWT is <2.0.0 - # Actually GitHub3.1.0 already introduced PyJWT>=2.3.0 but so far `pip` was able to resolve it without - # getting into a long backtracking loop and figure out that github3 3.0.0 version is the right version - # similarly limiting it to 3.1.2 causes pip not to enter the backtracking loop. Apparently when there - # are 3 versions with PyJWT>=2.3.0 (3.1.0, 3.1.1 an 3.1.2) pip enters into backtrack loop and fails - # to resolve that github3 3.0.0 is the right version to use. - # This limitation could be removed if PyJWT limitation < 2.0.0 is dropped from FAB or when - # pip resolution is improved to handle the case. The issue which describes this PIP behaviour - # and hopefully allowing to improve it is tracked in https://github.com/pypa/pip/issues/10924 - 'github3.py<3.1.0', 'gitpython', 'ipdb', 'jira', @@ -660,7 +650,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'qds-sdk>=1.9.6', 'pytest-httpx', 'requests_mock', - 'rich_click', + 'rich-click>=1.5', 'semver', 'towncrier', 'twine', diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index d95d4c38549df..09ef3c28ae050 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -19,12 +19,10 @@ from datetime import datetime import pytest -from itsdangerous import URLSafeSerializer from parameterized import parameterized from airflow import DAG from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.configuration import conf from airflow.models import DagBag, DagModel from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator @@ -34,8 +32,12 @@ from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags -SERIALIZER = URLSafeSerializer(conf.get('webserver', 'secret_key')) -FILE_TOKEN = SERIALIZER.dumps(__file__) + +@pytest.fixture() +def current_file_token(url_safe_serializer) -> str: + return url_safe_serializer.dumps(__file__) + + DAG_ID = "test_dag" TASK_ID = "op1" DAG2_ID = "test_dag2" @@ -246,7 +248,7 @@ def test_should_respond_403_with_granular_access_for_different_dag(self): class TestGetDagDetails(TestDagEndpoint): - def test_should_respond_200(self): + def test_should_respond_200(self, current_file_token): response = self.client.get( f"/api/v1/dags/{self.dag_id}/details", environ_overrides={'REMOTE_USER': "test"} ) @@ -262,7 +264,7 @@ def test_should_respond_200(self): "description": None, "doc_md": "details", "fileloc": __file__, - "file_token": FILE_TOKEN, + "file_token": current_file_token, "is_paused": None, "is_active": None, "is_subdag": False, @@ -294,7 +296,7 @@ def test_should_respond_200(self): } assert response.json == expected - def test_should_response_200_with_doc_md_none(self): + def test_should_response_200_with_doc_md_none(self, current_file_token): response = self.client.get( f"/api/v1/dags/{self.dag2_id}/details", environ_overrides={'REMOTE_USER': "test"} ) @@ -310,7 +312,7 @@ def test_should_response_200_with_doc_md_none(self): "description": None, "doc_md": None, "fileloc": __file__, - "file_token": FILE_TOKEN, + "file_token": current_file_token, "is_paused": None, "is_active": None, "is_subdag": False, @@ -335,7 +337,7 @@ def test_should_response_200_with_doc_md_none(self): } assert response.json == expected - def test_should_response_200_for_null_start_date(self): + def test_should_response_200_for_null_start_date(self, current_file_token): response = self.client.get( f"/api/v1/dags/{self.dag3_id}/details", environ_overrides={'REMOTE_USER': "test"} ) @@ -351,7 +353,7 @@ def test_should_response_200_for_null_start_date(self): "description": None, "doc_md": None, "fileloc": __file__, - "file_token": FILE_TOKEN, + "file_token": current_file_token, "is_paused": None, "is_active": None, "is_subdag": False, @@ -376,7 +378,7 @@ def test_should_response_200_for_null_start_date(self): } assert response.json == expected - def test_should_respond_200_serialized(self): + def test_should_respond_200_serialized(self, current_file_token): # Get the dag out of the dagbag before we patch it to an empty one SerializedDagModel.write_dag(self.app.dag_bag.get_dag(self.dag_id)) @@ -395,7 +397,7 @@ def test_should_respond_200_serialized(self): "description": None, "doc_md": "details", "fileloc": __file__, - "file_token": FILE_TOKEN, + "file_token": current_file_token, "is_paused": None, "is_active": None, "is_subdag": False, @@ -449,7 +451,7 @@ def test_should_respond_200_serialized(self): 'description': None, 'doc_md': 'details', 'fileloc': __file__, - "file_token": FILE_TOKEN, + "file_token": current_file_token, 'is_paused': None, "is_active": None, 'is_subdag': False, @@ -496,7 +498,7 @@ def test_should_raise_404_when_dag_is_not_found(self): class TestGetDags(TestDagEndpoint): @provide_session - def test_should_respond_200(self, session): + def test_should_respond_200(self, session, url_safe_serializer): self._create_dag_models(2) self._create_deactivated_dag() @@ -504,8 +506,8 @@ def test_should_respond_200(self, session): assert len(dags_query.all()) == 3 response = self.client.get("api/v1/dags", environ_overrides={'REMOTE_USER': "test"}) - file_token = SERIALIZER.dumps("/tmp/dag_1.py") - file_token2 = SERIALIZER.dumps("/tmp/dag_2.py") + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") + file_token2 = url_safe_serializer.dumps("/tmp/dag_2.py") assert response.status_code == 200 assert { @@ -576,11 +578,11 @@ def test_should_respond_200(self, session): "total_entries": 2, } == response.json - def test_only_active_true_returns_active_dags(self): + def test_only_active_true_returns_active_dags(self, url_safe_serializer): self._create_dag_models(1) self._create_deactivated_dag() response = self.client.get("api/v1/dags?only_active=True", environ_overrides={'REMOTE_USER': "test"}) - file_token = SERIALIZER.dumps("/tmp/dag_1.py") + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") assert response.status_code == 200 assert { "dags": [ @@ -619,12 +621,12 @@ def test_only_active_true_returns_active_dags(self): "total_entries": 1, } == response.json - def test_only_active_false_returns_all_dags(self): + def test_only_active_false_returns_all_dags(self, url_safe_serializer): self._create_dag_models(1) self._create_deactivated_dag() response = self.client.get("api/v1/dags?only_active=False", environ_overrides={'REMOTE_USER': "test"}) - file_token = SERIALIZER.dumps("/tmp/dag_1.py") - file_token_2 = SERIALIZER.dumps("/tmp/dag_del_1.py") + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") + file_token_2 = url_safe_serializer.dumps("/tmp/dag_del_1.py") assert response.status_code == 200 assert { "dags": [ @@ -819,10 +821,8 @@ def test_should_respond_403_unauthorized(self): class TestPatchDag(TestDagEndpoint): - - file_token = SERIALIZER.dumps("/tmp/dag_1.py") - - def test_should_respond_200_on_patch_is_paused(self): + def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer): + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") dag_model = self._create_dag_model() response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}", @@ -832,12 +832,11 @@ def test_should_respond_200_on_patch_is_paused(self): environ_overrides={'REMOTE_USER': "test"}, ) assert response.status_code == 200 - expected_response = { "dag_id": "TEST_DAG_1", "description": None, "fileloc": "/tmp/dag_1.py", - "file_token": self.file_token, + "file_token": file_token, "is_paused": False, "is_active": False, "is_subdag": False, @@ -918,7 +917,8 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) - def test_should_respond_200_with_update_mask(self): + def test_should_respond_200_with_update_mask(self, url_safe_serializer): + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") dag_model = self._create_dag_model() payload = { "is_paused": False, @@ -934,7 +934,7 @@ def test_should_respond_200_with_update_mask(self): "dag_id": "TEST_DAG_1", "description": None, "fileloc": "/tmp/dag_1.py", - "file_token": self.file_token, + "file_token": file_token, "is_paused": False, "is_active": False, "is_subdag": False, @@ -1006,12 +1006,10 @@ def test_should_respond_403_unauthorized(self): class TestPatchDags(TestDagEndpoint): - - file_token = SERIALIZER.dumps("/tmp/dag_1.py") - file_token2 = SERIALIZER.dumps("/tmp/dag_2.py") - @provide_session - def test_should_respond_200_on_patch_is_paused(self, session): + def test_should_respond_200_on_patch_is_paused(self, session, url_safe_serializer): + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") + file_token2 = url_safe_serializer.dumps("/tmp/dag_2.py") self._create_dag_models(2) self._create_deactivated_dag() @@ -1033,7 +1031,7 @@ def test_should_respond_200_on_patch_is_paused(self, session): "dag_id": "TEST_DAG_1", "description": None, "fileloc": "/tmp/dag_1.py", - "file_token": self.file_token, + "file_token": file_token, "is_paused": False, "is_active": True, "is_subdag": False, @@ -1064,7 +1062,7 @@ def test_should_respond_200_on_patch_is_paused(self, session): "dag_id": "TEST_DAG_2", "description": None, "fileloc": "/tmp/dag_2.py", - "file_token": self.file_token2, + "file_token": file_token2, "is_paused": False, "is_active": True, "is_subdag": False, @@ -1095,7 +1093,8 @@ def test_should_respond_200_on_patch_is_paused(self, session): "total_entries": 2, } == response.json - def test_only_active_true_returns_active_dags(self): + def test_only_active_true_returns_active_dags(self, url_safe_serializer): + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") self._create_dag_models(1) self._create_deactivated_dag() response = self.client.patch( @@ -1112,7 +1111,7 @@ def test_only_active_true_returns_active_dags(self): "dag_id": "TEST_DAG_1", "description": None, "fileloc": "/tmp/dag_1.py", - "file_token": self.file_token, + "file_token": file_token, "is_paused": False, "is_active": True, "is_subdag": False, @@ -1143,7 +1142,8 @@ def test_only_active_true_returns_active_dags(self): "total_entries": 1, } == response.json - def test_only_active_false_returns_all_dags(self): + def test_only_active_false_returns_all_dags(self, url_safe_serializer): + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") self._create_dag_models(1) self._create_deactivated_dag() response = self.client.patch( @@ -1154,7 +1154,7 @@ def test_only_active_false_returns_all_dags(self): environ_overrides={'REMOTE_USER': "test"}, ) - file_token_2 = SERIALIZER.dumps("/tmp/dag_del_1.py") + file_token_2 = url_safe_serializer.dumps("/tmp/dag_del_1.py") assert response.status_code == 200 assert { "dags": [ @@ -1162,7 +1162,7 @@ def test_only_active_false_returns_all_dags(self): "dag_id": "TEST_DAG_1", "description": None, "fileloc": "/tmp/dag_1.py", - "file_token": self.file_token, + "file_token": file_token, "is_paused": False, "is_active": True, "is_subdag": False, @@ -1399,7 +1399,9 @@ def test_should_respond_403_unauthorized(self): assert response.status_code == 403 - def test_should_respond_200_and_pause_dags(self): + def test_should_respond_200_and_pause_dags(self, url_safe_serializer): + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") + file_token2 = url_safe_serializer.dumps("/tmp/dag_2.py") self._create_dag_models(2) response = self.client.patch( @@ -1417,7 +1419,7 @@ def test_should_respond_200_and_pause_dags(self): "dag_id": "TEST_DAG_1", "description": None, "fileloc": "/tmp/dag_1.py", - "file_token": self.file_token, + "file_token": file_token, "is_paused": True, "is_active": True, "is_subdag": False, @@ -1448,7 +1450,7 @@ def test_should_respond_200_and_pause_dags(self): "dag_id": "TEST_DAG_2", "description": None, "fileloc": "/tmp/dag_2.py", - "file_token": self.file_token2, + "file_token": file_token2, "is_paused": True, "is_active": True, "is_subdag": False, @@ -1480,9 +1482,10 @@ def test_should_respond_200_and_pause_dags(self): } == response.json @provide_session - def test_should_respond_200_and_pause_dag_pattern(self, session): + def test_should_respond_200_and_pause_dag_pattern(self, session, url_safe_serializer): + file_token = url_safe_serializer.dumps("/tmp/dag_1.py") self._create_dag_models(10) - file_token10 = SERIALIZER.dumps("/tmp/dag_10.py") + file_token10 = url_safe_serializer.dumps("/tmp/dag_10.py") response = self.client.patch( "/api/v1/dags?dag_id_pattern=TEST_DAG_1", @@ -1499,7 +1502,7 @@ def test_should_respond_200_and_pause_dag_pattern(self, session): "dag_id": "TEST_DAG_1", "description": None, "fileloc": "/tmp/dag_1.py", - "file_token": self.file_token, + "file_token": file_token, "is_paused": True, "is_active": True, "is_subdag": False, diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index 5fec9f27989c9..cb8b5a9c11e57 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -19,10 +19,8 @@ from typing import Optional import pytest -from itsdangerous import URLSafeSerializer from airflow import DAG -from airflow.configuration import conf from airflow.models import DagBag from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user @@ -73,15 +71,14 @@ def _get_dag_file_docstring(fileloc: str) -> Optional[str]: docstring = ast.get_docstring(module) return docstring - def test_should_respond_200_text(self): - serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) + def test_should_respond_200_text(self, url_safe_serializer): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) dagbag.sync_to_db() first_dag: DAG = next(iter(dagbag.dags.values())) dag_docstring = self._get_dag_file_docstring(first_dag.fileloc) - url = f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}" + url = f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}" response = self.client.get( url, headers={"Accept": "text/plain"}, environ_overrides={'REMOTE_USER': "test"} ) @@ -90,14 +87,13 @@ def test_should_respond_200_text(self): assert dag_docstring in response.data.decode() assert 'text/plain' == response.headers['Content-Type'] - def test_should_respond_200_json(self): - serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) + def test_should_respond_200_json(self, url_safe_serializer): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) dagbag.sync_to_db() first_dag: DAG = next(iter(dagbag.dags.values())) dag_docstring = self._get_dag_file_docstring(first_dag.fileloc) - url = f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}" + url = f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}" response = self.client.get( url, headers={"Accept": 'application/json'}, environ_overrides={'REMOTE_USER': "test"} ) @@ -106,13 +102,12 @@ def test_should_respond_200_json(self): assert dag_docstring in response.json['content'] assert 'application/json' == response.headers['Content-Type'] - def test_should_respond_406(self): - serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) + def test_should_respond_406(self, url_safe_serializer): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) dagbag.sync_to_db() first_dag: DAG = next(iter(dagbag.dags.values())) - url = f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}" + url = f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}" response = self.client.get( url, headers={"Accept": 'image/webp'}, environ_overrides={'REMOTE_USER': "test"} ) @@ -128,27 +123,25 @@ def test_should_respond_404(self): assert 404 == response.status_code - def test_should_raises_401_unauthenticated(self): - serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) + def test_should_raises_401_unauthenticated(self, url_safe_serializer): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) dagbag.sync_to_db() first_dag: DAG = next(iter(dagbag.dags.values())) response = self.client.get( - f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}", + f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}", headers={"Accept": "text/plain"}, ) assert_401(response) - def test_should_raise_403_forbidden(self): - serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) + def test_should_raise_403_forbidden(self, url_safe_serializer): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) dagbag.sync_to_db() first_dag: DAG = next(iter(dagbag.dags.values())) response = self.client.get( - f"/api/v1/dagSources/{serializer.dumps(first_dag.fileloc)}", + f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}", headers={"Accept": "text/plain"}, environ_overrides={'REMOTE_USER': "test_no_permissions"}, ) diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index 7c4452dc2bb63..efcba3271188d 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -185,7 +185,7 @@ def test_should_respond_200(self): response_data = response.json for xcom_entry in response_data['xcom_entries']: xcom_entry['timestamp'] = "TIMESTAMP" - assert response.json == { + assert response_data == { 'xcom_entries': [ { 'dag_id': dag_id, @@ -227,7 +227,7 @@ def test_should_respond_200_with_tilde_and_access_to_all_dags(self): response_data = response.json for xcom_entry in response_data['xcom_entries']: xcom_entry['timestamp'] = "TIMESTAMP" - assert response.json == { + assert response_data == { 'xcom_entries': [ { 'dag_id': dag_id_1, @@ -283,7 +283,7 @@ def test_should_respond_200_with_tilde_and_granular_dag_access(self): response_data = response.json for xcom_entry in response_data['xcom_entries']: xcom_entry['timestamp'] = "TIMESTAMP" - assert response.json == { + assert response_data == { 'xcom_entries': [ { 'dag_id': dag_id_1, diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py index ca7e04ae89226..040ce41f1d066 100644 --- a/tests/api_connexion/schemas/test_dag_schema.py +++ b/tests/api_connexion/schemas/test_dag_schema.py @@ -15,11 +15,8 @@ # specific language governing permissions and limitations # under the License. -import unittest from datetime import datetime -from itsdangerous import URLSafeSerializer - from airflow import DAG from airflow.api_connexion.schemas.dag_schema import ( DAGCollection, @@ -27,174 +24,168 @@ DAGDetailSchema, DAGSchema, ) -from airflow.configuration import conf from airflow.models import DagModel, DagTag -SERIALIZER = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) - -class TestDagSchema(unittest.TestCase): - def test_serialize(self): - dag_model = DagModel( - dag_id="test_dag_id", - root_dag_id="test_root_dag_id", - is_paused=True, - is_active=True, - is_subdag=False, - fileloc="/root/airflow/dags/my_dag.py", - owners="airflow1,airflow2", - description="The description", - schedule_interval="5 4 * * *", - tags=[DagTag(name="tag-1"), DagTag(name="tag-2")], - ) - serialized_dag = DAGSchema().dump(dag_model) +def test_serialize_test_dag_schema(url_safe_serializer): + dag_model = DagModel( + dag_id="test_dag_id", + root_dag_id="test_root_dag_id", + is_paused=True, + is_active=True, + is_subdag=False, + fileloc="/root/airflow/dags/my_dag.py", + owners="airflow1,airflow2", + description="The description", + schedule_interval="5 4 * * *", + tags=[DagTag(name="tag-1"), DagTag(name="tag-2")], + ) + serialized_dag = DAGSchema().dump(dag_model) - assert { - "dag_id": "test_dag_id", - "description": "The description", - "fileloc": "/root/airflow/dags/my_dag.py", - "file_token": SERIALIZER.dumps("/root/airflow/dags/my_dag.py"), - "is_paused": True, - "is_active": True, - "is_subdag": False, - "owners": ["airflow1", "airflow2"], - "root_dag_id": "test_root_dag_id", - "schedule_interval": {"__type": "CronExpression", "value": "5 4 * * *"}, - "tags": [{"name": "tag-1"}, {"name": "tag-2"}], - 'next_dagrun': None, - 'has_task_concurrency_limits': True, - 'next_dagrun_data_interval_start': None, - 'next_dagrun_data_interval_end': None, - 'max_active_runs': 16, - 'next_dagrun_create_after': None, - 'last_expired': None, - 'max_active_tasks': 16, - 'last_pickled': None, - 'default_view': None, - 'last_parsed_time': None, - 'scheduler_lock': None, - 'timetable_description': None, - 'has_import_errors': None, - 'pickle_id': None, - } == serialized_dag + assert { + "dag_id": "test_dag_id", + "description": "The description", + "fileloc": "/root/airflow/dags/my_dag.py", + "file_token": url_safe_serializer.dumps("/root/airflow/dags/my_dag.py"), + "is_paused": True, + "is_active": True, + "is_subdag": False, + "owners": ["airflow1", "airflow2"], + "root_dag_id": "test_root_dag_id", + "schedule_interval": {"__type": "CronExpression", "value": "5 4 * * *"}, + "tags": [{"name": "tag-1"}, {"name": "tag-2"}], + 'next_dagrun': None, + 'has_task_concurrency_limits': True, + 'next_dagrun_data_interval_start': None, + 'next_dagrun_data_interval_end': None, + 'max_active_runs': 16, + 'next_dagrun_create_after': None, + 'last_expired': None, + 'max_active_tasks': 16, + 'last_pickled': None, + 'default_view': None, + 'last_parsed_time': None, + 'scheduler_lock': None, + 'timetable_description': None, + 'has_import_errors': None, + 'pickle_id': None, + } == serialized_dag -class TestDAGCollectionSchema(unittest.TestCase): - def test_serialize(self): - dag_model_a = DagModel(dag_id="test_dag_id_a", fileloc="/tmp/a.py") - dag_model_b = DagModel(dag_id="test_dag_id_b", fileloc="/tmp/a.py") - schema = DAGCollectionSchema() - instance = DAGCollection(dags=[dag_model_a, dag_model_b], total_entries=2) - assert { - "dags": [ - { - "dag_id": "test_dag_id_a", - "description": None, - "fileloc": "/tmp/a.py", - "file_token": SERIALIZER.dumps("/tmp/a.py"), - "is_paused": None, - "is_subdag": None, - "is_active": None, - "owners": [], - "root_dag_id": None, - "schedule_interval": None, - "tags": [], - 'next_dagrun': None, - 'has_task_concurrency_limits': True, - 'next_dagrun_data_interval_start': None, - 'next_dagrun_data_interval_end': None, - 'max_active_runs': 16, - 'next_dagrun_create_after': None, - 'last_expired': None, - 'max_active_tasks': 16, - 'last_pickled': None, - 'default_view': None, - 'last_parsed_time': None, - 'scheduler_lock': None, - 'timetable_description': None, - 'has_import_errors': None, - 'pickle_id': None, - }, - { - "dag_id": "test_dag_id_b", - "description": None, - "fileloc": "/tmp/a.py", - "file_token": SERIALIZER.dumps("/tmp/a.py"), - "is_active": None, - "is_paused": None, - "is_subdag": None, - "owners": [], - "root_dag_id": None, - "schedule_interval": None, - "tags": [], - 'next_dagrun': None, - 'has_task_concurrency_limits': True, - 'next_dagrun_data_interval_start': None, - 'next_dagrun_data_interval_end': None, - 'max_active_runs': 16, - 'next_dagrun_create_after': None, - 'last_expired': None, - 'max_active_tasks': 16, - 'last_pickled': None, - 'default_view': None, - 'last_parsed_time': None, - 'scheduler_lock': None, - 'timetable_description': None, - 'has_import_errors': None, - 'pickle_id': None, - }, - ], - "total_entries": 2, - } == schema.dump(instance) +def test_serialize_test_dag_collection_schema(url_safe_serializer): + dag_model_a = DagModel(dag_id="test_dag_id_a", fileloc="/tmp/a.py") + dag_model_b = DagModel(dag_id="test_dag_id_b", fileloc="/tmp/a.py") + schema = DAGCollectionSchema() + instance = DAGCollection(dags=[dag_model_a, dag_model_b], total_entries=2) + assert { + "dags": [ + { + "dag_id": "test_dag_id_a", + "description": None, + "fileloc": "/tmp/a.py", + "file_token": url_safe_serializer.dumps("/tmp/a.py"), + "is_paused": None, + "is_subdag": None, + "is_active": None, + "owners": [], + "root_dag_id": None, + "schedule_interval": None, + "tags": [], + 'next_dagrun': None, + 'has_task_concurrency_limits': True, + 'next_dagrun_data_interval_start': None, + 'next_dagrun_data_interval_end': None, + 'max_active_runs': 16, + 'next_dagrun_create_after': None, + 'last_expired': None, + 'max_active_tasks': 16, + 'last_pickled': None, + 'default_view': None, + 'last_parsed_time': None, + 'scheduler_lock': None, + 'timetable_description': None, + 'has_import_errors': None, + 'pickle_id': None, + }, + { + "dag_id": "test_dag_id_b", + "description": None, + "fileloc": "/tmp/a.py", + "file_token": url_safe_serializer.dumps("/tmp/a.py"), + "is_active": None, + "is_paused": None, + "is_subdag": None, + "owners": [], + "root_dag_id": None, + "schedule_interval": None, + "tags": [], + 'next_dagrun': None, + 'has_task_concurrency_limits': True, + 'next_dagrun_data_interval_start': None, + 'next_dagrun_data_interval_end': None, + 'max_active_runs': 16, + 'next_dagrun_create_after': None, + 'last_expired': None, + 'max_active_tasks': 16, + 'last_pickled': None, + 'default_view': None, + 'last_parsed_time': None, + 'scheduler_lock': None, + 'timetable_description': None, + 'has_import_errors': None, + 'pickle_id': None, + }, + ], + "total_entries": 2, + } == schema.dump(instance) -class TestDAGDetailSchema: - def test_serialize(self): - dag = DAG( - dag_id="test_dag", - start_date=datetime(2020, 6, 19), - doc_md="docs", - orientation="LR", - default_view="duration", - params={"foo": 1}, - tags=['example1', 'example2'], - ) - schema = DAGDetailSchema() +def test_serialize_test_dag_detail_schema(url_safe_serializer): + dag = DAG( + dag_id="test_dag", + start_date=datetime(2020, 6, 19), + doc_md="docs", + orientation="LR", + default_view="duration", + params={"foo": 1}, + tags=['example1', 'example2'], + ) + schema = DAGDetailSchema() - expected = { - 'catchup': True, - 'concurrency': 16, - 'max_active_tasks': 16, - 'dag_id': 'test_dag', - 'dag_run_timeout': None, - 'default_view': 'duration', - 'description': None, - 'doc_md': 'docs', - 'fileloc': __file__, - "file_token": SERIALIZER.dumps(__file__), - "is_active": None, - 'is_paused': None, - 'is_subdag': False, - 'orientation': 'LR', - 'owners': [], - 'params': { - 'foo': { - '__class': 'airflow.models.param.Param', - 'value': 1, - 'description': None, - 'schema': {}, - } - }, - 'schedule_interval': {'__type': 'TimeDelta', 'days': 1, 'seconds': 0, 'microseconds': 0}, - 'start_date': '2020-06-19T00:00:00+00:00', - 'tags': [{'name': "example1"}, {'name': "example2"}], - 'timezone': "Timezone('UTC')", - 'max_active_runs': 16, - 'pickle_id': None, - "end_date": None, - 'is_paused_upon_creation': None, - 'render_template_as_native_obj': False, - } - obj = schema.dump(dag) - expected.update({'last_parsed': obj['last_parsed']}) - assert obj == expected + expected = { + 'catchup': True, + 'concurrency': 16, + 'max_active_tasks': 16, + 'dag_id': 'test_dag', + 'dag_run_timeout': None, + 'default_view': 'duration', + 'description': None, + 'doc_md': 'docs', + 'fileloc': __file__, + "file_token": url_safe_serializer.dumps(__file__), + "is_active": None, + 'is_paused': None, + 'is_subdag': False, + 'orientation': 'LR', + 'owners': [], + 'params': { + 'foo': { + '__class': 'airflow.models.param.Param', + 'value': 1, + 'description': None, + 'schema': {}, + } + }, + 'schedule_interval': {'__type': 'TimeDelta', 'days': 1, 'seconds': 0, 'microseconds': 0}, + 'start_date': '2020-06-19T00:00:00+00:00', + 'tags': [{'name': "example1"}, {'name': "example2"}], + 'timezone': "Timezone('UTC')", + 'max_active_runs': 16, + 'pickle_id': None, + "end_date": None, + 'is_paused_upon_creation': None, + 'render_template_as_native_obj': False, + } + obj = schema.dump(dag) + expected.update({'last_parsed': obj['last_parsed']}) + assert obj == expected diff --git a/tests/conftest.py b/tests/conftest.py index 468da514b0feb..68d318e13c50a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,8 @@ # We should set these before loading _any_ of the rest of airflow so that the # unit test mode config is set as early as possible. +from itsdangerous import URLSafeSerializer + assert "airflow" not in sys.modules, "No airflow module can be imported before these lines" tests_directory = os.path.dirname(os.path.realpath(__file__)) @@ -55,6 +57,28 @@ def reset_environment(): os.environ[key] = init_env[key] +@pytest.fixture() +def secret_key() -> str: + """ + Return secret key configured. + :return: + """ + from airflow.configuration import conf + + the_key = conf.get('webserver', 'SECRET_KEY') + if the_key is None: + raise RuntimeError( + "The secret key SHOULD be configured as `[webserver] secret_key` in the " + "configuration/environment at this stage! " + ) + return the_key + + +@pytest.fixture() +def url_safe_serializer(secret_key) -> URLSafeSerializer: + return URLSafeSerializer(secret_key) + + @pytest.fixture() def reset_db(): """ diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py index 1e6a0c70adf6d..187f57a7fd114 100644 --- a/tests/test_utils/remote_user_api_auth_backend.py +++ b/tests/test_utils/remote_user_api_auth_backend.py @@ -20,10 +20,12 @@ from functools import wraps from typing import Callable, Optional, Tuple, TypeVar, Union, cast -from flask import Response, current_app, request +from flask import Response, request from flask_login import login_user from requests.auth import AuthBase +from airflow.utils.airflow_flask_app import get_airflow_app + log = logging.getLogger(__name__) CLIENT_AUTH: Optional[Union[Tuple[str, str], AuthBase]] = None @@ -37,7 +39,7 @@ def init_app(_): def _lookup_user(user_email_or_username: str): - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user( username=user_email_or_username ) diff --git a/tests/utils/test_serve_logs.py b/tests/utils/test_serve_logs.py index f8d38817592b8..ebbc663a2cb0d 100644 --- a/tests/utils/test_serve_logs.py +++ b/tests/utils/test_serve_logs.py @@ -21,7 +21,6 @@ import pytest from freezegun import freeze_time -from airflow.configuration import conf from airflow.utils.jwt_signer import JWTSigner from airflow.utils.serve_logs import create_app from tests.test_utils.config import conf_vars @@ -49,18 +48,18 @@ def sample_log(tmpdir): @pytest.fixture -def signer(): +def signer(secret_key): return JWTSigner( - secret_key=conf.get('webserver', 'secret_key'), + secret_key=secret_key, expiration_time_in_seconds=30, audience="task-instance-logs", ) @pytest.fixture -def different_audience(): +def different_audience(secret_key): return JWTSigner( - secret_key=conf.get('webserver', 'secret_key'), + secret_key=secret_key, expiration_time_in_seconds=30, audience="different-audience", ) @@ -180,7 +179,7 @@ def test_wrong_audience(self, client: "FlaskClient", different_audience): ) @pytest.mark.parametrize("claim_to_remove", ["iat", "exp", "nbf", "aud"]) - def test_missing_claims(self, claim_to_remove: str, client: "FlaskClient"): + def test_missing_claims(self, claim_to_remove: str, client: "FlaskClient", secret_key): jwt_dict = { "aud": "task-instance-logs", "iat": datetime.datetime.utcnow(), @@ -191,7 +190,7 @@ def test_missing_claims(self, claim_to_remove: str, client: "FlaskClient"): jwt_dict.update({"filename": 'sample.log'}) token = jwt.encode( jwt_dict, - conf.get('webserver', 'secret_key'), + secret_key, algorithm="HS512", ) assert ( diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py index 887bd4898a0a6..fa79e145cba6c 100644 --- a/tests/www/views/test_views.py +++ b/tests/www/views/test_views.py @@ -375,52 +375,55 @@ def test_get_task_stats_from_query(): assert data == expected_data +INVALID_DATETIME_RESPONSE = "Invalid datetime: 'invalid'" + + @pytest.mark.parametrize( "url, content", [ ( '/rendered-templates?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( '/log?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( '/redirect_to_external_log?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( '/task?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/graph?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/graph?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/duration?base_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/tries?base_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/landing-times?base_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/gantt?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'extra_links?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ], ) diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py index 0e4fc12857a8f..1de80c1214a28 100644 --- a/tests/www/views/test_views_decorators.py +++ b/tests/www/views/test_views_decorators.py @@ -213,9 +213,9 @@ def test_action_has_dag_edit_access(create_task_instance, class_type, no_instanc else: test_items = tis if class_type == TaskInstance else [ti.get_dagrun() for ti in tis] test_items = test_items[0] if len(test_items) == 1 else test_items - - with app.create_app(testing=True).app_context(): - with mock.patch("airflow.www.views.current_app.appbuilder.sm.can_edit_dag") as mocked_can_edit: + application = app.create_app(testing=True) + with application.app_context(): + with mock.patch.object(application.appbuilder.sm, "can_edit_dag") as mocked_can_edit: mocked_can_edit.return_value = True assert not isinstance(test_items, list) or len(test_items) == no_instances assert some_view_action_which_requires_dag_edit_access(None, test_items) is True diff --git a/tests/www/views/test_views_log.py b/tests/www/views/test_views_log.py index f697cd3772c28..988d28593649c 100644 --- a/tests/www/views/test_views_log.py +++ b/tests/www/views/test_views_log.py @@ -464,7 +464,7 @@ def test_redirect_to_external_log_with_local_log_handler(log_admin_client, task_ ) response = log_admin_client.get(url) assert 302 == response.status_code - assert 'http://localhost/home' == response.headers['Location'] + assert '/home' == response.headers['Location'] class _ExternalHandler(ExternalLoggingMixin): diff --git a/tests/www/views/test_views_mount.py b/tests/www/views/test_views_mount.py index a9fb8746657df..3f504e9b0f168 100644 --- a/tests/www/views/test_views_mount.py +++ b/tests/www/views/test_views_mount.py @@ -36,7 +36,7 @@ def factory(): @pytest.fixture() def client(app): - return werkzeug.test.Client(app, werkzeug.wrappers.BaseResponse) + return werkzeug.test.Client(app, werkzeug.wrappers.response.Response) def test_mount(client): @@ -54,4 +54,4 @@ def test_not_found(client): def test_index(client): resp = client.get('/test/') assert resp.status_code == 302 - assert resp.headers['Location'] == 'http://localhost/test/home' + assert resp.headers['Location'] == '/test/home'