From 601649c1405050d444e04efaf6045d2b24d8d569 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 10 Jun 2023 16:30:24 +0200 Subject: [PATCH 1/4] Remove right trailing / from webserver base_url Signed-off-by: Hussein Awala --- airflow/models/taskinstance.py | 4 +-- .../www/extensions/init_wsgi_middlewares.py | 2 +- tests/www/test_app.py | 34 +++++++++++++------ 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 28dc168ec3a25..258a1067dd993 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -758,7 +758,7 @@ def generate_command( def log_url(self) -> str: """Log URL for TaskInstance.""" iso = quote(self.execution_date.isoformat()) - base_url = conf.get_mandatory_value("webserver", "BASE_URL") + base_url = conf.get_mandatory_value("webserver", "BASE_URL").rstrip("/") return ( f"{base_url}/log" f"?execution_date={iso}" @@ -770,7 +770,7 @@ def log_url(self) -> str: @property def mark_success_url(self) -> str: """URL to mark TI success.""" - base_url = conf.get_mandatory_value("webserver", "BASE_URL") + base_url = conf.get_mandatory_value("webserver", "BASE_URL").rstrip("/") return ( f"{base_url}/confirm" f"?task_id={self.task_id}" diff --git a/airflow/www/extensions/init_wsgi_middlewares.py b/airflow/www/extensions/init_wsgi_middlewares.py index 3ea47e92c86f6..3def382c3158c 100644 --- a/airflow/www/extensions/init_wsgi_middlewares.py +++ b/airflow/www/extensions/init_wsgi_middlewares.py @@ -38,7 +38,7 @@ def _root_app(env: WSGIEnvironment, resp: StartResponse) -> Iterable[bytes]: def init_wsgi_middleware(flask_app: Flask) -> None: """Handle X-Forwarded-* headers and base_url support.""" # Apply DispatcherMiddleware - base_url = urlsplit(conf.get("webserver", "base_url"))[2] + base_url = urlsplit(conf.get("webserver", "base_url").rstrip("/"))[2] if not base_url or base_url == "/": base_url = "" if base_url: diff --git a/tests/www/test_app.py b/tests/www/test_app.py index 61bfae3e12a87..5f5b54260e7c5 100644 --- a/tests/www/test_app.py +++ b/tests/www/test_app.py @@ -86,15 +86,18 @@ def debug_view(): assert b"success" == response.get_data() assert response.status_code == 200 - @conf_vars( - { - ("webserver", "base_url"): "http://localhost:8080/internal-client", - } + @pytest.mark.parametrize( + "base_url", + [ + "http://localhost:8080/internal-client", + "http://localhost:8080/internal-client/", + ], ) @dont_initialize_flask_app_submodules - def test_should_respect_base_url_ignore_proxy_headers(self): - app = application.cached_app(testing=True) - app.url_map.add(Rule("/debug", endpoint="debug")) + def test_should_respect_base_url_ignore_proxy_headers(self, base_url): + with conf_vars({("webserver", "base_url"): base_url}): + app = application.cached_app(testing=True) + app.url_map.add(Rule("/debug", endpoint="debug")) def debug_view(): from flask import request @@ -126,9 +129,15 @@ def debug_view(): assert b"success" == response.get_data() assert response.status_code == 200 + @pytest.mark.parametrize( + "base_url", + [ + "http://localhost:8080/internal-client", + "http://localhost:8080/internal-client/", + ], + ) @conf_vars( { - ("webserver", "base_url"): "http://localhost:8080/internal-client", ("webserver", "enable_proxy_fix"): "True", ("webserver", "proxy_fix_x_for"): "1", ("webserver", "proxy_fix_x_proto"): "1", @@ -138,9 +147,12 @@ def debug_view(): } ) @dont_initialize_flask_app_submodules - def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(self): - app = application.cached_app(testing=True) - app.url_map.add(Rule("/debug", endpoint="debug")) + def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing( + self, base_url + ): + with conf_vars({("webserver", "base_url"): base_url}): + app = application.cached_app(testing=True) + app.url_map.add(Rule("/debug", endpoint="debug")) def debug_view(): from flask import request From 0480977a7171e772e0dc52b129cb0c1cb6f7ffe9 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 17 Jun 2023 17:22:49 +0200 Subject: [PATCH 2/4] use url join instead of removing trailing slash Signed-off-by: Hussein Awala --- airflow/models/taskinstance.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 258a1067dd993..f429e27e1546f 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -33,7 +33,7 @@ from pathlib import PurePath from types import TracebackType from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Tuple -from urllib.parse import quote +from urllib.parse import quote, urljoin import dill import jinja2 @@ -758,27 +758,27 @@ def generate_command( def log_url(self) -> str: """Log URL for TaskInstance.""" iso = quote(self.execution_date.isoformat()) - base_url = conf.get_mandatory_value("webserver", "BASE_URL").rstrip("/") - return ( - f"{base_url}/log" - f"?execution_date={iso}" + base_url = conf.get_mandatory_value("webserver", "BASE_URL") + return urljoin( + base_url, + f"log?execution_date={iso}" f"&task_id={self.task_id}" f"&dag_id={self.dag_id}" - f"&map_index={self.map_index}" + f"&map_index={self.map_index}", ) @property def mark_success_url(self) -> str: """URL to mark TI success.""" - base_url = conf.get_mandatory_value("webserver", "BASE_URL").rstrip("/") - return ( - f"{base_url}/confirm" - f"?task_id={self.task_id}" + base_url = conf.get_mandatory_value("webserver", "BASE_URL") + return urljoin( + base_url, + f"confirm?task_id={self.task_id}" f"&dag_id={self.dag_id}" f"&dag_run_id={quote(self.run_id)}" "&upstream=false" "&downstream=false" - "&state=success" + "&state=success", ) @provide_session From e75175fc286c9075306fb5ba6abe53cc18a48cc6 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 17 Jun 2023 21:45:26 +0200 Subject: [PATCH 3/4] raise an exception when base_url contains a trailing slash Signed-off-by: Hussein Awala --- .../www/extensions/init_wsgi_middlewares.py | 6 +++- tests/www/test_app.py | 33 ++++++++++++++----- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/airflow/www/extensions/init_wsgi_middlewares.py b/airflow/www/extensions/init_wsgi_middlewares.py index 3def382c3158c..434a872f657e2 100644 --- a/airflow/www/extensions/init_wsgi_middlewares.py +++ b/airflow/www/extensions/init_wsgi_middlewares.py @@ -25,6 +25,7 @@ from werkzeug.middleware.proxy_fix import ProxyFix from airflow.configuration import conf +from airflow.exceptions import AirflowConfigException if TYPE_CHECKING: from _typeshed.wsgi import StartResponse, WSGIEnvironment @@ -37,8 +38,11 @@ def _root_app(env: WSGIEnvironment, resp: StartResponse) -> Iterable[bytes]: def init_wsgi_middleware(flask_app: Flask) -> None: """Handle X-Forwarded-* headers and base_url support.""" + webserver_base_url = conf.get("webserver", "BASE_URL") + if webserver_base_url.endswith("/"): + raise AirflowConfigException("webserver.base_url conf cannot have a trailing slash.") # Apply DispatcherMiddleware - base_url = urlsplit(conf.get("webserver", "base_url").rstrip("/"))[2] + base_url = urlsplit(webserver_base_url)[2] if not base_url or base_url == "/": base_url = "" if base_url: diff --git a/tests/www/test_app.py b/tests/www/test_app.py index 5f5b54260e7c5..8dd3b57b2e394 100644 --- a/tests/www/test_app.py +++ b/tests/www/test_app.py @@ -18,6 +18,7 @@ from __future__ import annotations import hashlib +import re import runpy import sys from datetime import timedelta @@ -87,15 +88,23 @@ def debug_view(): assert response.status_code == 200 @pytest.mark.parametrize( - "base_url", + "base_url, expected_exception", [ - "http://localhost:8080/internal-client", - "http://localhost:8080/internal-client/", + ("http://localhost:8080/internal-client", None), + ( + "http://localhost:8080/internal-client/", + AirflowConfigException("webserver.base_url conf cannot have a trailing slash."), + ), ], ) @dont_initialize_flask_app_submodules - def test_should_respect_base_url_ignore_proxy_headers(self, base_url): + def test_should_respect_base_url_ignore_proxy_headers(self, base_url, expected_exception): with conf_vars({("webserver", "base_url"): base_url}): + if expected_exception: + with pytest.raises(expected_exception.__class__, match=re.escape(str(expected_exception))): + app = application.cached_app(testing=True) + app.url_map.add(Rule("/debug", endpoint="debug")) + return app = application.cached_app(testing=True) app.url_map.add(Rule("/debug", endpoint="debug")) @@ -130,10 +139,13 @@ def debug_view(): assert response.status_code == 200 @pytest.mark.parametrize( - "base_url", + "base_url, expected_exception", [ - "http://localhost:8080/internal-client", - "http://localhost:8080/internal-client/", + ("http://localhost:8080/internal-client", None), + ( + "http://localhost:8080/internal-client/", + AirflowConfigException("webserver.base_url conf cannot have a trailing slash."), + ), ], ) @conf_vars( @@ -148,9 +160,14 @@ def debug_view(): ) @dont_initialize_flask_app_submodules def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing( - self, base_url + self, base_url, expected_exception ): with conf_vars({("webserver", "base_url"): base_url}): + if expected_exception: + with pytest.raises(expected_exception.__class__, match=re.escape(str(expected_exception))): + app = application.cached_app(testing=True) + app.url_map.add(Rule("/debug", endpoint="debug")) + return app = application.cached_app(testing=True) app.url_map.add(Rule("/debug", endpoint="debug")) From f697a33735adeba9b3d0bdb3aac6f7d3be237a65 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Mon, 19 Jun 2023 08:40:11 +0200 Subject: [PATCH 4/4] Update airflow/www/extensions/init_wsgi_middlewares.py Co-authored-by: Tzu-ping Chung --- airflow/www/extensions/init_wsgi_middlewares.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/www/extensions/init_wsgi_middlewares.py b/airflow/www/extensions/init_wsgi_middlewares.py index 434a872f657e2..37d4a074b4d18 100644 --- a/airflow/www/extensions/init_wsgi_middlewares.py +++ b/airflow/www/extensions/init_wsgi_middlewares.py @@ -38,7 +38,7 @@ def _root_app(env: WSGIEnvironment, resp: StartResponse) -> Iterable[bytes]: def init_wsgi_middleware(flask_app: Flask) -> None: """Handle X-Forwarded-* headers and base_url support.""" - webserver_base_url = conf.get("webserver", "BASE_URL") + webserver_base_url = conf.get_mandatory_value("webserver", "BASE_URL", fallback="") if webserver_base_url.endswith("/"): raise AirflowConfigException("webserver.base_url conf cannot have a trailing slash.") # Apply DispatcherMiddleware