Add a check for trailing slash in webserver base_url #31833

Merged
merged 4 commits into from
Jun 19, 2023
Changes from 3 commits
18 changes: 9 additions & 9 deletions airflow/models/taskinstance.py
@@ -33,7 +33,7 @@
 from pathlib import PurePath
 from types import TracebackType
 from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Tuple
-from urllib.parse import quote
+from urllib.parse import quote, urljoin

 import dill
 import jinja2
@@ -759,26 +759,26 @@ def log_url(self) -> str:
         """Log URL for TaskInstance."""
         iso = quote(self.execution_date.isoformat())
         base_url = conf.get_mandatory_value("webserver", "BASE_URL")
-        return (
-            f"{base_url}/log"
-            f"?execution_date={iso}"
+        return urljoin(
+            base_url,
+            f"log?execution_date={iso}"
             f"&task_id={self.task_id}"
             f"&dag_id={self.dag_id}"
-            f"&map_index={self.map_index}"
+            f"&map_index={self.map_index}",
         )

     @property
     def mark_success_url(self) -> str:
         """URL to mark TI success."""
         base_url = conf.get_mandatory_value("webserver", "BASE_URL")
-        return (
-            f"{base_url}/confirm"
-            f"?task_id={self.task_id}"
+        return urljoin(
+            base_url,
+            f"confirm?task_id={self.task_id}"
             f"&dag_id={self.dag_id}"
             f"&dag_run_id={quote(self.run_id)}"
             "&upstream=false"
             "&downstream=false"
-            "&state=success"
+            "&state=success",
         )

     @provide_session
6 changes: 5 additions & 1 deletion airflow/www/extensions/init_wsgi_middlewares.py
@@ -25,6 +25,7 @@
 from werkzeug.middleware.proxy_fix import ProxyFix

 from airflow.configuration import conf
+from airflow.exceptions import AirflowConfigException

 if TYPE_CHECKING:
     from _typeshed.wsgi import StartResponse, WSGIEnvironment
@@ -37,8 +38,11 @@ def _root_app(env: WSGIEnvironment, resp: StartResponse) -> Iterable[bytes]:

 def init_wsgi_middleware(flask_app: Flask) -> None:
     """Handle X-Forwarded-* headers and base_url support."""
+    webserver_base_url = conf.get("webserver", "BASE_URL")
+    if webserver_base_url.endswith("/"):
+        raise AirflowConfigException("webserver.base_url conf cannot have a trailing slash.")
Member

@pankajkoti pankajkoti Aug 1, 2023

I am wondering why the base URL cannot have a trailing /.

Member

@uranusjr uranusjr Aug 1, 2023

This is due to how web servers rewrite URLs when you deploy a webapp under a prefix. Say you deploy under https://mydomain/airflow. When you visit https://mydomain/airflow/home in the browser (as an example; any client applies), the proxy server (e.g. NGINX) would strip the prefix, and forward the rest of the URL to the app, so in Python code (e.g. Airflow) we can just handle /home without needing to consider what prefix the entirety of Airflow is deployed under.
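
The prefix stripping described above can be sketched as a small function. This is only an illustration of the idea, not how NGINX or Airflow implement it; `strip_prefix` and its arguments are hypothetical names.

```python
def strip_prefix(request_path: str, prefix: str) -> str:
    """Return the path the app would see once the proxy removes ``prefix``.

    Hypothetical helper for illustration only; real proxies have their own rules.
    """
    if request_path == prefix:
        # Visiting the prefix itself maps to the app's root.
        return "/"
    if request_path.startswith(prefix + "/"):
        # Drop the prefix and keep the leading slash of the remainder.
        return request_path[len(prefix):]
    raise ValueError(f"{request_path!r} is not under {prefix!r}")


app_path = strip_prefix("/airflow/home", "/airflow")
print(app_path)  # /home
```

With the prefix written as `/airflow`, the app only ever has to handle `/home`.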

Now you may say, hey, why can’t either the proxy server, the gateway protocol (WSGI in Python’s case), or the web framework (Flask for Airflow’s case) be smarter and just strip or add the slash in between as appropriate? And you would be right! Some of them actually can do this (NGINX has merge_slashes, for example), but not everyone does since the slash is technically significant and does change the URL’s meaning (even if that meaning can be nonsensical in the prefix), and that extra check is not free. So it’s better to avoid fighting the tools and just stick to the technically correct configuration.

Member

@pankajkoti pankajkoti Aug 1, 2023

Okay, I am coming at this from how Python's urllib.parse.urljoin handles joining a base URL and relative URLs.

Wouldn't https://my.astronomer.run/path qualify as a valid base URL? But it looks like urljoin just drops the path when the base URL does not end with a trailing /. If it were https://my.astronomer.run/path/, urljoin would not strip the path and would make the join as expected.
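
The behavior in question can be reproduced with the standard library alone. This is a minimal sketch; the query string is a made-up value used only for illustration.

```python
from urllib.parse import urljoin

relative = "log?execution_date=2023-06-19"  # illustrative value, not from the PR

# Base without a trailing slash: the last path segment is replaced.
without_slash = urljoin("https://my.astronomer.run/path", relative)

# Base with a trailing slash: the path is kept and the relative part is appended.
with_slash = urljoin("https://my.astronomer.run/path/", relative)

print(without_slash)  # https://my.astronomer.run/log?execution_date=2023-06-19
print(with_slash)     # https://my.astronomer.run/path/log?execution_date=2023-06-19
```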

Member

I think we can handle more cases if users need it. This PR was a quick fix to prevent circular redirections.

I’m sure we can make this work with a trailing slash, but it requires a bit more work.

Member

(Sorry, forgot to reply.)

Python’s urljoin implements the logic in <a href="..."> tags; I don’t recall the correct term for this, but it is the URL resolution logic used when you click on links in a browser. URLs come in two kinds, folder and file (not the right terms, but easier to understand). The URL would not have a trailing slash if you are in a file view; the trailing slash indicates a folder view. When you’re in a file view, a relative path like foo indicates another file in the same folder. When you’re in a folder view, however, foo means the foo entry in the folder. This means the resolution logic changes depending on whether the URL has a trailing slash or not.

Note that Python’s local path joining logic (both pathlib and os.path) is different and does not consider the trailing slash; it instead matches the common logic used to join paths in shell scripts.
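
A small comparison, using only the standard library, shows the difference. The paths and URLs below are illustrative; `posixpath` is used instead of `os.path` so the result does not depend on the operating system.

```python
import posixpath
from pathlib import PurePosixPath
from urllib.parse import urljoin

# Local path joining: the trailing slash on the first part makes no difference.
os_path_results = {
    posixpath.join("/internal-client", "log"),
    posixpath.join("/internal-client/", "log"),
}
pathlib_results = {
    str(PurePosixPath("/internal-client") / "log"),
    str(PurePosixPath("/internal-client/") / "log"),
}

# URL resolution: the trailing slash on the base changes the result.
url_results = {
    urljoin("http://localhost:8080/internal-client", "log"),
    urljoin("http://localhost:8080/internal-client/", "log"),
}

print(os_path_results)   # one entry: /internal-client/log
print(pathlib_results)   # one entry: /internal-client/log
print(sorted(url_results))  # two different URLs
```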

But the prefix is an entirely different thing, and is intended to only be joined with simple string concatenation and skips all the folder/file logic because it is not semantically possible to support combining the prefix with absolute paths and relative paths. The only proper way to prepend a prefix is prefix + path (or other string concatenation methods such as f"{prefix}{path}", of course).
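
As a rough sketch of the prefix + path approach, the helper below concatenates the two and rejects the inputs that would produce a doubled or missing slash. `prepend_prefix` is a hypothetical name, and the checks are an assumption modeled on the "no trailing slash" rule in this PR, not Airflow code.

```python
def prepend_prefix(prefix: str, path: str) -> str:
    """Prepend a deployment prefix to an app-level absolute path.

    Hypothetical helper for illustration; not part of Airflow.
    """
    if prefix.endswith("/"):
        # A trailing slash here would yield "//" after concatenation.
        raise ValueError("prefix cannot have a trailing slash")
    if not path.startswith("/"):
        # The app-level path is expected to be absolute.
        raise ValueError("path must start with a slash")
    return f"{prefix}{path}"


full_url = prepend_prefix("http://localhost:8080/internal-client", "/log?task_id=example")
print(full_url)  # http://localhost:8080/internal-client/log?task_id=example
```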

Member

@pankajkoti pankajkoti Aug 3, 2023

Understood, thank you @uranusjr for the detailed explanation and proposal.

     # Apply DispatcherMiddleware
-    base_url = urlsplit(conf.get("webserver", "base_url"))[2]
+    base_url = urlsplit(webserver_base_url)[2]
     if not base_url or base_url == "/":
         base_url = ""
     if base_url:
51 changes: 40 additions & 11 deletions tests/www/test_app.py
@@ -18,6 +18,7 @@
 from __future__ import annotations

 import hashlib
+import re
 import runpy
 import sys
 from datetime import timedelta
@@ -86,15 +87,26 @@ def debug_view():
         assert b"success" == response.get_data()
         assert response.status_code == 200

-    @conf_vars(
-        {
-            ("webserver", "base_url"): "http://localhost:8080/internal-client",
-        }
+    @pytest.mark.parametrize(
+        "base_url, expected_exception",
+        [
+            ("http://localhost:8080/internal-client", None),
+            (
+                "http://localhost:8080/internal-client/",
+                AirflowConfigException("webserver.base_url conf cannot have a trailing slash."),
+            ),
+        ],
     )
     @dont_initialize_flask_app_submodules
-    def test_should_respect_base_url_ignore_proxy_headers(self):
-        app = application.cached_app(testing=True)
-        app.url_map.add(Rule("/debug", endpoint="debug"))
+    def test_should_respect_base_url_ignore_proxy_headers(self, base_url, expected_exception):
+        with conf_vars({("webserver", "base_url"): base_url}):
+            if expected_exception:
+                with pytest.raises(expected_exception.__class__, match=re.escape(str(expected_exception))):
+                    app = application.cached_app(testing=True)
+                    app.url_map.add(Rule("/debug", endpoint="debug"))
+                return
+            app = application.cached_app(testing=True)
+            app.url_map.add(Rule("/debug", endpoint="debug"))

         def debug_view():
             from flask import request
@@ -126,9 +138,18 @@ def debug_view():
         assert b"success" == response.get_data()
         assert response.status_code == 200

+    @pytest.mark.parametrize(
+        "base_url, expected_exception",
+        [
+            ("http://localhost:8080/internal-client", None),
+            (
+                "http://localhost:8080/internal-client/",
+                AirflowConfigException("webserver.base_url conf cannot have a trailing slash."),
+            ),
+        ],
+    )
     @conf_vars(
         {
-            ("webserver", "base_url"): "http://localhost:8080/internal-client",
             ("webserver", "enable_proxy_fix"): "True",
             ("webserver", "proxy_fix_x_for"): "1",
             ("webserver", "proxy_fix_x_proto"): "1",
@@ -138,9 +159,17 @@ def debug_view():
         }
     )
     @dont_initialize_flask_app_submodules
-    def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(self):
-        app = application.cached_app(testing=True)
-        app.url_map.add(Rule("/debug", endpoint="debug"))
+    def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(
+        self, base_url, expected_exception
+    ):
+        with conf_vars({("webserver", "base_url"): base_url}):
+            if expected_exception:
+                with pytest.raises(expected_exception.__class__, match=re.escape(str(expected_exception))):
+                    app = application.cached_app(testing=True)
+                    app.url_map.add(Rule("/debug", endpoint="debug"))
+                return
+            app = application.cached_app(testing=True)
+            app.url_map.add(Rule("/debug", endpoint="debug"))

         def debug_view():
             from flask import request