diff --git a/airflow/utils/jwt_signer.py b/airflow/utils/jwt_signer.py new file mode 100644 index 0000000000000..941a3d05981ce --- /dev/null +++ b/airflow/utils/jwt_signer.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import datetime, timedelta +from typing import Any, Dict + +import jwt + + +class JWTSigner: + """ + Signs and verifies JWT Token. Used to authorise and verify requests. + + :param secret_key: key used to sign the request + :param expiration_time_in_seconds: time after which the token becomes invalid (in seconds) + :param audience: audience that the request is expected to have + :param leeway_in_seconds: leeway that allows for a small clock skew between the two parties + :param algorithm: algorithm used for signing + """ + + def __init__( + self, + secret_key: str, + expiration_time_in_seconds: int, + audience: str, + leeway_in_seconds: int = 5, + algorithm: str = "HS512", + ): + self._secret_key = secret_key + self._expiration_time_in_seconds = expiration_time_in_seconds + self._audience = audience + self._leeway_in_seconds = leeway_in_seconds + self._algorithm = algorithm + + def generate_signed_token(self, extra_payload: Dict[str, Any]) -> str: + """ + Generate JWT with extra payload added. + :param extra_payload: extra payload that is added to the signed token + :return: signed token + """ + jwt_dict = { + "aud": self._audience, + "iat": datetime.utcnow(), + "nbf": datetime.utcnow(), + "exp": datetime.utcnow() + timedelta(seconds=self._expiration_time_in_seconds), + } + jwt_dict.update(extra_payload) + token = jwt.encode( + jwt_dict, + self._secret_key, + algorithm=self._algorithm, + ) + return token + + def verify_token(self, token: str) -> Dict[str, Any]: + payload = jwt.decode( + token, + self._secret_key, + leeway=timedelta(seconds=self._leeway_in_seconds), + algorithms=[self._algorithm], + options={ + "verify_signature": True, + "require_exp": True, + "require_iat": True, + "require_nbf": True, + }, + audience=self._audience, + ) + return payload diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index db34ea5f6b5b6..2c53529a72dc0 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -23,11 +23,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional, Tuple -from itsdangerous import TimedJSONWebSignatureSerializer - from airflow.configuration import AirflowConfigException, conf from airflow.utils.context import Context from airflow.utils.helpers import parse_template_string, render_template_to_string +from airflow.utils.jwt_signer import JWTSigner from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler from airflow.utils.session import create_session @@ -201,16 +200,17 @@ def _read(self, ti, try_number, metadata=None): except (AirflowConfigException, ValueError): pass - signer = TimedJSONWebSignatureSerializer( + signer = JWTSigner( secret_key=conf.get('webserver', 'secret_key'), - algorithm_name='HS512', - expires_in=conf.getint('webserver', 'log_request_clock_grace', fallback=30), - # This isn't really a "salt", more of a signing context - salt='task-instance-logs', + expiration_time_in_seconds=conf.getint( + 'webserver', 'log_request_clock_grace', fallback=30 + ), + audience="task-instance-logs", ) - response = httpx.get( - url, timeout=timeout, headers={'Authorization': signer.dumps(log_relative_path)} + url, + timeout=timeout, + headers={b'Authorization': signer.generate_signed_token({"filename": log_relative_path})}, ) response.encoding = "utf-8" diff --git a/airflow/utils/serve_logs.py b/airflow/utils/serve_logs.py index 50fdb47a024a6..e14162178b182 100644 --- a/airflow/utils/serve_logs.py +++ b/airflow/utils/serve_logs.py @@ -16,55 +16,89 @@ # under the License. """Serve logs process""" +import logging import os -import time import gunicorn.app.base from flask import Flask, abort, request, send_from_directory -from itsdangerous import TimedJSONWebSignatureSerializer +from jwt.exceptions import ( + ExpiredSignatureError, + ImmatureSignatureError, + InvalidAudienceError, + InvalidIssuedAtError, + InvalidSignatureError, +) from setproctitle import setproctitle from airflow.configuration import conf +from airflow.utils.docs import get_docs_url +from airflow.utils.jwt_signer import JWTSigner + +logger = logging.getLogger(__name__) def create_app(): flask_app = Flask(__name__, static_folder=None) - max_request_age = conf.getint('webserver', 'log_request_clock_grace', fallback=30) + expiration_time_in_seconds = conf.getint('webserver', 'log_request_clock_grace', fallback=30) log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) - signer = TimedJSONWebSignatureSerializer( + signer = JWTSigner( secret_key=conf.get('webserver', 'secret_key'), - algorithm_name='HS512', - expires_in=max_request_age, - # This isn't really a "salt", more of a signing context - salt='task-instance-logs', + expiration_time_in_seconds=expiration_time_in_seconds, + audience="task-instance-logs", ) # Prevent direct access to the logs port @flask_app.before_request def validate_pre_signed_url(): try: - auth = request.headers['Authorization'] - - # We don't actually care about the payload, just that the signature - # was valid and the `exp` claim is correct - filename, headers = signer.loads(auth, return_header=True) - - issued_at = int(headers['iat']) - expires_at = int(headers['exp']) - except Exception: + auth = request.headers.get('Authorization') + if auth is None: + logger.warning("The Authorization header is missing: %s.", request.headers) + abort(403) + payload = signer.verify_token(auth) + token_filename = payload.get("filename") + request_filename = request.view_args['filename'] + if token_filename is None: + logger.warning("The payload does not contain 'filename' key: %s.", payload) + abort(403) + if token_filename != request_filename: + logger.warning( + "The payload log_relative_path key is different than the one in token:" + "Request path: %s. Token path: %s.", + request_filename, + token_filename, + ) + abort(403) + except InvalidAudienceError: + logger.warning("Invalid audience for the request", exc_info=True) abort(403) - - if filename != request.view_args['filename']: + except InvalidSignatureError: + logger.warning("The signature of the request was wrong", exc_info=True) abort(403) - - # Validate the `iat` and `exp` are within `max_request_age` of now. - now = int(time.time()) - if abs(now - issued_at) > max_request_age: + except ImmatureSignatureError: + logger.warning("The signature of the request was sent from the future", exc_info=True) abort(403) - if abs(now - expires_at) > max_request_age: + except ExpiredSignatureError: + logger.warning( + "The signature of the request has expired. Make sure that all components " + "in your system have synchronized clocks. " + "See more at %s", + get_docs_url("configurations-ref.html#secret-key"), + exc_info=True, + ) abort(403) - if issued_at > expires_at or expires_at - issued_at > max_request_age: + except InvalidIssuedAtError: + logger.warning( + "The request was issues in the future. Make sure that all components " + "in your system have synchronized clocks. " + "See more at %s", + get_docs_url("configurations-ref.html#secret-key"), + exc_info=True, + ) + abort(403) + except Exception: + logger.warning("Unknown error", exc_info=True) abort(403) @flask_app.route('/log/') diff --git a/newsfragments/24519.misc.rst b/newsfragments/24519.misc.rst new file mode 100644 index 0000000000000..799d9141d2a0a --- /dev/null +++ b/newsfragments/24519.misc.rst @@ -0,0 +1 @@ +The JWT claims in the request to retrieve logs have been standardized: we use "nbf" and "aud" claims for maturity and audience of the requests. Also "filename" payload field is used to keep log name. diff --git a/tests/utils/test_serve_logs.py b/tests/utils/test_serve_logs.py index 168a43a012787..f8d38817592b8 100644 --- a/tests/utils/test_serve_logs.py +++ b/tests/utils/test_serve_logs.py @@ -14,12 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import datetime from typing import TYPE_CHECKING +import jwt import pytest -from itsdangerous import TimedJSONWebSignatureSerializer +from freezegun import freeze_time from airflow.configuration import conf +from airflow.utils.jwt_signer import JWTSigner from airflow.utils.serve_logs import create_app from tests.test_utils.config import conf_vars @@ -47,12 +50,19 @@ def sample_log(tmpdir): @pytest.fixture def signer(): - return TimedJSONWebSignatureSerializer( + return JWTSigner( secret_key=conf.get('webserver', 'secret_key'), - algorithm_name='HS512', - expires_in=30, - # This isn't really a "salt", more of a signing context - salt='task-instance-logs', + expiration_time_in_seconds=30, + audience="task-instance-logs", + ) + + +@pytest.fixture +def different_audience(): + return JWTSigner( + secret_key=conf.get('webserver', 'secret_key'), + expiration_time_in_seconds=30, + audience="different-audience", ) @@ -62,49 +72,134 @@ def test_forbidden_no_auth(self, client: "FlaskClient"): assert 403 == client.get('/log/sample.log').status_code def test_should_serve_file(self, client: "FlaskClient", signer): + response = client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.generate_signed_token({"filename": 'sample.log'}), + }, + ) + assert response.data.decode() == LOG_DATA + assert response.status_code == 200 + + def test_forbidden_different_logname(self, client: "FlaskClient", signer): + response = client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.generate_signed_token({"filename": 'different.log'}), + }, + ) + assert response.status_code == 403 + + def test_forbidden_expired(self, client: "FlaskClient", signer): + with freeze_time("2010-01-14"): + token = signer.generate_signed_token({"filename": 'sample.log'}) + assert ( + client.get( + '/log/sample.log', + headers={ + 'Authorization': token, + }, + ).status_code + == 403 + ) + + def test_forbidden_future(self, client: "FlaskClient", signer): + with freeze_time(datetime.datetime.utcnow() + datetime.timedelta(seconds=3600)): + token = signer.generate_signed_token({"filename": 'sample.log'}) assert ( - LOG_DATA - == client.get( + client.get( '/log/sample.log', headers={ - 'Authorization': signer.dumps('sample.log'), + 'Authorization': token, }, - ).data.decode() + ).status_code + == 403 ) - def test_forbidden_too_long_validity(self, client: "FlaskClient", signer): - signer.expires_in = 3600 + def test_ok_with_short_future_skew(self, client: "FlaskClient", signer): + with freeze_time(datetime.datetime.utcnow() + datetime.timedelta(seconds=1)): + token = signer.generate_signed_token({"filename": 'sample.log'}) assert ( - 403 - == client.get( + client.get( '/log/sample.log', headers={ - 'Authorization': signer.dumps('sample.log'), + 'Authorization': token, }, ).status_code + == 200 ) - def test_forbidden_expired(self, client: "FlaskClient", signer): - # Fake the time we think we are - signer.now = lambda: 0 + def test_ok_with_short_past_skew(self, client: "FlaskClient", signer): + with freeze_time(datetime.datetime.utcnow() - datetime.timedelta(seconds=31)): + token = signer.generate_signed_token({"filename": 'sample.log'}) + assert ( + client.get( + '/log/sample.log', + headers={ + 'Authorization': token, + }, + ).status_code + == 200 + ) + + def test_forbidden_with_long_future_skew(self, client: "FlaskClient", signer): + with freeze_time(datetime.datetime.utcnow() + datetime.timedelta(seconds=10)): + token = signer.generate_signed_token({"filename": 'sample.log'}) + assert ( + client.get( + '/log/sample.log', + headers={ + 'Authorization': token, + }, + ).status_code + == 403 + ) + + def test_forbidden_with_long_past_skew(self, client: "FlaskClient", signer): + with freeze_time(datetime.datetime.utcnow() - datetime.timedelta(seconds=40)): + token = signer.generate_signed_token({"filename": 'sample.log'}) + assert ( + client.get( + '/log/sample.log', + headers={ + 'Authorization': token, + }, + ).status_code + == 403 + ) + + def test_wrong_audience(self, client: "FlaskClient", different_audience): assert ( - 403 - == client.get( + client.get( '/log/sample.log', headers={ - 'Authorization': signer.dumps('sample.log'), + 'Authorization': different_audience.generate_signed_token({"filename": 'sample.log'}), }, ).status_code + == 403 ) - def test_wrong_context(self, client: "FlaskClient", signer): - signer.salt = None + @pytest.mark.parametrize("claim_to_remove", ["iat", "exp", "nbf", "aud"]) + def test_missing_claims(self, claim_to_remove: str, client: "FlaskClient"): + jwt_dict = { + "aud": "task-instance-logs", + "iat": datetime.datetime.utcnow(), + "nbf": datetime.datetime.utcnow(), + "exp": datetime.datetime.utcnow() + datetime.timedelta(seconds=30), + } + del jwt_dict[claim_to_remove] + jwt_dict.update({"filename": 'sample.log'}) + token = jwt.encode( + jwt_dict, + conf.get('webserver', 'secret_key'), + algorithm="HS512", + ) assert ( - 403 - == client.get( + client.get( '/log/sample.log', headers={ - 'Authorization': signer.dumps('sample.log'), + 'Authorization': token, }, ).status_code + == 403 )