diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py
index 39d809e5167ed..86482d3d05468 100644
--- a/airflow/config_templates/airflow_local_settings.py
+++ b/airflow/config_templates/airflow_local_settings.py
@@ -203,6 +203,7 @@
     # WASB buckets should start with "wasb"
     # just to help Airflow select correct handler
     REMOTE_BASE_LOG_FOLDER: str = conf.get_mandatory_value("logging", "REMOTE_BASE_LOG_FOLDER")
+    REMOTE_TASK_HANDLER_KWARGS = conf.getjson("logging", "REMOTE_TASK_HANDLER_KWARGS", fallback={})
 
     if REMOTE_BASE_LOG_FOLDER.startswith("s3://"):
         S3_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = {
@@ -252,7 +253,6 @@
                 "wasb_log_folder": REMOTE_BASE_LOG_FOLDER,
                 "wasb_container": "airflow-logs",
                 "filename_template": FILENAME_TEMPLATE,
-                "delete_local_copy": False,
             },
         }
 
@@ -315,3 +315,4 @@
             "section 'elasticsearch' if you are using Elasticsearch. In the other case, "
             "'remote_base_log_folder' option in the 'logging' section."
         )
+    DEFAULT_LOGGING_CONFIG["handlers"]["task"].update(REMOTE_TASK_HANDLER_KWARGS)
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 274838bd3911b..38e5600f85757 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -606,6 +606,14 @@ logging:
       type: string
      example: ~
       default: ""
+    delete_local_logs:
+      description: |
+        Whether the local log files for GCS, S3, WASB and OSS remote logging should be deleted after
+        they are uploaded to the remote location.
+      version_added: 2.6.0
+      type: string
+      example: ~
+      default: "False"
     google_key_path:
       description: |
         Path to Google Credential JSON file. If omitted, authorization based on `the Application Default
@@ -628,6 +636,16 @@ logging:
       type: string
       example: ~
       default: ""
+    remote_task_handler_kwargs:
+      description: |
+        The remote_task_handler_kwargs param is loaded into a dictionary and passed to the ``__init__`` of
+        the remote task handler, and it overrides the values provided by the Airflow config. For example, if
+        you set ``delete_local_logs=False`` and you provide ``{{"delete_local_copy": true}}``, the local log
+        files will still be deleted after they are uploaded to the remote location.
+      version_added: 2.6.0
+      type: string
+      example: '{"delete_local_copy": true}'
+      default: ""
     encrypt_s3_logs:
       description: |
         Use server-side encryption for logs stored in S3
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 87bf519b74b97..ee9715fd7f396 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -345,6 +345,10 @@ remote_logging = False
 # reading logs, not writing them.
 remote_log_conn_id =
 
+# Whether the local log files for GCS, S3, WASB and OSS remote logging should be deleted after
+# they are uploaded to the remote location.
+delete_local_logs = False
+
 # Path to Google Credential JSON file. If omitted, authorization based on `the Application Default
 # Credentials
 # `__ will
@@ -359,6 +363,13 @@ google_key_path =
 # Stackdriver logs should start with "stackdriver://"
 remote_base_log_folder =
 
+# The remote_task_handler_kwargs param is loaded into a dictionary and passed to the ``__init__`` of
+# the remote task handler, and it overrides the values provided by the Airflow config. For example, if
+# you set ``delete_local_logs=False`` and you provide ``{{"delete_local_copy": true}}``, the local log
+# files will still be deleted after they are uploaded to the remote location.
+# Example: remote_task_handler_kwargs = {{"delete_local_copy": true}}
+remote_task_handler_kwargs =
+
 # Use server-side encryption for logs stored in S3
 encrypt_s3_logs = False
diff --git a/airflow/providers/alibaba/cloud/log/oss_task_handler.py b/airflow/providers/alibaba/cloud/log/oss_task_handler.py
index c443b4e014392..512eda90c69e1 100644
--- a/airflow/providers/alibaba/cloud/log/oss_task_handler.py
+++ b/airflow/providers/alibaba/cloud/log/oss_task_handler.py
@@ -20,6 +20,9 @@
 import contextlib
 import os
 import pathlib
+import shutil
+
+from packaging.version import Version
 
 from airflow.compat.functools import cached_property
 from airflow.configuration import conf
@@ -28,6 +31,17 @@
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 
+def get_default_delete_local_copy():
+    """Load delete_local_logs conf if Airflow version >= 2.6, and return False if not.
+    TODO: delete this function when min airflow version >= 2.6
+    """
+    from airflow.version import version
+
+    if Version(version) < Version("2.6"):
+        return False
+    return conf.getboolean("logging", "delete_local_logs")
+
+
 class OSSTaskHandler(FileTaskHandler, LoggingMixin):
     """
     OSSTaskHandler is a python log handler that handles and reads
@@ -35,7 +49,7 @@ class OSSTaskHandler(FileTaskHandler, LoggingMixin):
     uploads to and reads from OSS remote storage.
     """
 
-    def __init__(self, base_log_folder, oss_log_folder, filename_template=None):
+    def __init__(self, base_log_folder, oss_log_folder, filename_template=None, **kwargs):
         self.log.info("Using oss_task_handler for remote logging...")
         super().__init__(base_log_folder, filename_template)
         (self.bucket_name, self.base_folder) = OSSHook.parse_oss_url(oss_log_folder)
@@ -43,6 +57,9 @@ def __init__(self, base_log_folder, oss_log_folder, filename_template=None):
         self._hook = None
         self.closed = False
         self.upload_on_close = True
+        self.delete_local_copy = (
+            kwargs["delete_local_copy"] if "delete_local_copy" in kwargs else get_default_delete_local_copy()
+        )
 
     @cached_property
     def hook(self):
@@ -92,7 +109,9 @@ def close(self):
         if os.path.exists(local_loc):
             # read log and remove old logs to get just the latest additions
             log = pathlib.Path(local_loc).read_text()
-            self.oss_write(log, remote_loc)
+            oss_write = self.oss_write(log, remote_loc)
+            if oss_write and self.delete_local_copy:
+                shutil.rmtree(os.path.dirname(local_loc))
 
         # Mark closed so we don't double write if close is called twice
         self.closed = True
@@ -154,15 +173,16 @@ def oss_read(self, remote_log_location, return_error=False):
             if return_error:
                 return msg
 
-    def oss_write(self, log, remote_log_location, append=True):
+    def oss_write(self, log, remote_log_location, append=True) -> bool:
         """
-        Writes the log to the remote_log_location. Fails silently if no hook
-        was created.
+        Writes the log to the remote_log_location and returns `True` when done. Fails silently
+        and returns `False` if the log could not be written.
 
         :param log: the log to write to the remote_log_location
         :param remote_log_location: the log's location in remote storage
         :param append: if False, any existing log file is overwritten. If True, the new log
             is appended to any existing logs.
+        :return: whether the log was successfully written to the remote location or not.
         """
         oss_remote_log_location = f"{self.base_folder}/{remote_log_location}"
         pos = 0
@@ -180,3 +200,5 @@ def oss_write(self, log, remote_log_location, append=True):
                 str(pos),
                 str(append),
             )
+            return False
+        return True
diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py b/airflow/providers/amazon/aws/log/s3_task_handler.py
index 098f17a28a95c..20754075a2dac 100644
--- a/airflow/providers/amazon/aws/log/s3_task_handler.py
+++ b/airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -19,6 +19,9 @@
 
 import os
 import pathlib
+import shutil
+
+from packaging.version import Version
 
 from airflow.compat.functools import cached_property
 from airflow.configuration import conf
@@ -27,6 +30,17 @@
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 
+def get_default_delete_local_copy():
+    """Load delete_local_logs conf if Airflow version >= 2.6, and return False if not.
+    TODO: delete this function when min airflow version >= 2.6
+    """
+    from airflow.version import version
+
+    if Version(version) < Version("2.6"):
+        return False
+    return conf.getboolean("logging", "delete_local_logs")
+
+
 class S3TaskHandler(FileTaskHandler, LoggingMixin):
     """
     S3TaskHandler is a python log handler that handles and reads
@@ -36,13 +50,18 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
 
     trigger_should_wrap = True
 
-    def __init__(self, base_log_folder: str, s3_log_folder: str, filename_template: str | None = None):
+    def __init__(
+        self, base_log_folder: str, s3_log_folder: str, filename_template: str | None = None, **kwargs
+    ):
         super().__init__(base_log_folder, filename_template)
         self.remote_base = s3_log_folder
         self.log_relative_path = ""
         self._hook = None
         self.closed = False
         self.upload_on_close = True
+        self.delete_local_copy = (
+            kwargs["delete_local_copy"] if "delete_local_copy" in kwargs else get_default_delete_local_copy()
+        )
 
     @cached_property
     def hook(self):
@@ -84,7 +103,9 @@ def close(self):
         if os.path.exists(local_loc):
             # read log and remove old logs to get just the latest additions
             log = pathlib.Path(local_loc).read_text()
-            self.s3_write(log, remote_loc)
+            write_to_s3 = self.s3_write(log, remote_loc)
+            if write_to_s3 and self.delete_local_copy:
+                shutil.rmtree(os.path.dirname(local_loc))
 
         # Mark closed so we don't double write if close is called twice
         self.closed = True
@@ -164,16 +185,17 @@ def s3_read(self, remote_log_location: str, return_error: bool = False) -> str:
                 return msg
         return ""
 
-    def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_retry: int = 1):
+    def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_retry: int = 1) -> bool:
         """
-        Writes the log to the remote_log_location. Fails silently if no hook
-        was created.
+        Writes the log to the remote_log_location and returns `True` when done. Fails silently
+        and returns `False` if the log could not be written.
 
         :param log: the log to write to the remote_log_location
         :param remote_log_location: the log's location in remote storage
         :param append: if False, any existing log file is overwritten. If True, the new log
             is appended to any existing logs.
         :param max_retry: Maximum number of times to retry on upload failure
+        :return: whether the log was successfully written to the remote location or not.
         """
         try:
             if append and self.s3_log_exists(remote_log_location):
@@ -181,6 +203,7 @@ def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_
                 log = "\n".join([old_log, log]) if old_log else log
         except Exception:
             self.log.exception("Could not verify previous log to append")
+            return False
 
         # Default to a single retry attempt because s3 upload failures are
         # rare but occasionally occur. Multiple retry attempts are unlikely
@@ -199,3 +222,5 @@ def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_
                     self.log.warning("Failed attempt to write logs to %s, will retry", remote_log_location)
                 else:
                     self.log.exception("Could not write logs to %s", remote_log_location)
+                    return False
+        return True
diff --git a/airflow/providers/google/cloud/log/gcs_task_handler.py b/airflow/providers/google/cloud/log/gcs_task_handler.py
index 4523cddc5f442..303145310f10a 100644
--- a/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -19,11 +19,13 @@
 
 import logging
 import os
+import shutil
 from pathlib import Path
 from typing import Collection
 
 # not sure why but mypy complains on missing `storage` but it is clearly there and is importable
 from google.cloud import storage  # type: ignore[attr-defined]
+from packaging.version import Version
 
 from airflow.compat.functools import cached_property
 from airflow.configuration import conf
@@ -43,6 +45,17 @@
 logger = logging.getLogger(__name__)
 
 
+def get_default_delete_local_copy():
+    """Load delete_local_logs conf if Airflow version >= 2.6, and return False if not.
+    TODO: delete this function when min airflow version >= 2.6
+    """
+    from airflow.version import version
+
+    if Version(version) < Version("2.6"):
+        return False
+    return conf.getboolean("logging", "delete_local_logs")
+
+
 class GCSTaskHandler(FileTaskHandler, LoggingMixin):
     """
     GCSTaskHandler is a python log handler that handles and reads
@@ -63,6 +76,8 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
     :param gcp_scopes: Comma-separated string containing OAuth2 scopes
     :param project_id: Project ID to read the secrets from. If not passed, the project ID from credentials
         will be used.
+    :param delete_local_copy: Whether local log files should be deleted after they are uploaded to remote
+        storage when using remote logging
     """
 
     trigger_should_wrap = True
@@ -77,6 +92,7 @@ def __init__(
         gcp_keyfile_dict: dict | None = None,
         gcp_scopes: Collection[str] | None = _DEFAULT_SCOPESS,
         project_id: str | None = None,
+        **kwargs,
     ):
         super().__init__(base_log_folder, filename_template)
         self.remote_base = gcs_log_folder
@@ -87,6 +103,9 @@ def __init__(
         self.gcp_keyfile_dict = gcp_keyfile_dict
         self.scopes = gcp_scopes
         self.project_id = project_id
+        self.delete_local_copy = (
+            kwargs["delete_local_copy"] if "delete_local_copy" in kwargs else get_default_delete_local_copy()
+        )
 
     @cached_property
     def hook(self) -> GCSHook | None:
@@ -147,7 +166,9 @@ def close(self):
             # read log and remove old logs to get just the latest additions
             with open(local_loc) as logfile:
                 log = logfile.read()
-            self.gcs_write(log, remote_loc)
+            gcs_write = self.gcs_write(log, remote_loc)
+            if gcs_write and self.delete_local_copy:
+                shutil.rmtree(os.path.dirname(local_loc))
 
         # Mark closed so we don't double write if close is called twice
         self.closed = True
@@ -207,13 +228,14 @@ def _read(self, ti, try_number, metadata=None):
 
         return "".join([f"*** {x}\n" for x in messages]) + "\n".join(logs), {"end_of_log": True}
 
-    def gcs_write(self, log, remote_log_location):
+    def gcs_write(self, log, remote_log_location) -> bool:
         """
-        Writes the log to the remote_log_location. Fails silently if no log
-        was created.
+        Writes the log to the remote_log_location and returns `True` when done. Fails silently
+        and returns `False` if the log could not be written.
 
         :param log: the log to write to the remote_log_location
         :param remote_log_location: the log's location in remote storage
+        :return: whether the log was successfully written to the remote location or not.
         """
         try:
             blob = storage.Blob.from_string(remote_log_location, self.client)
@@ -232,6 +254,8 @@ def gcs_write(self, log, remote_log_location):
             blob.upload_from_string(log, content_type="text/plain")
         except Exception as e:
             self.log.error("Could not write logs to %s: %s", remote_log_location, e)
+            return False
+        return True
 
     @staticmethod
     def no_log_found(exc):
diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
index 52af3171c2d6c..aab84a04c5499 100644
--- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py
+++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
@@ -23,6 +23,7 @@
 from typing import TYPE_CHECKING, Any
 
 from azure.core.exceptions import HttpResponseError
+from packaging.version import Version
 
 from airflow.compat.functools import cached_property
 from airflow.configuration import conf
@@ -30,6 +31,17 @@
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 
+def get_default_delete_local_copy():
+    """Load delete_local_logs conf if Airflow version >= 2.6, and return False if not.
+    TODO: delete this function when min airflow version >= 2.6
+    """
+    from airflow.version import version
+
+    if Version(version) < Version("2.6"):
+        return False
+    return conf.getboolean("logging", "delete_local_logs")
+
+
 class WasbTaskHandler(FileTaskHandler, LoggingMixin):
     """
     WasbTaskHandler is a python log handler that handles and reads
@@ -44,9 +56,9 @@ def __init__(
         base_log_folder: str,
         wasb_log_folder: str,
         wasb_container: str,
-        delete_local_copy: str,
         *,
         filename_template: str | None = None,
+        **kwargs,
     ) -> None:
         super().__init__(base_log_folder, filename_template)
         self.wasb_container = wasb_container
@@ -55,7 +67,9 @@ def __init__(
         self._hook = None
         self.closed = False
         self.upload_on_close = True
-        self.delete_local_copy = delete_local_copy
+        self.delete_local_copy = (
+            kwargs["delete_local_copy"] if "delete_local_copy" in kwargs else get_default_delete_local_copy()
+        )
 
     @cached_property
     def hook(self):
@@ -107,9 +121,9 @@ def close(self) -> None:
             # read log and remove old logs to get just the latest additions
             with open(local_loc) as logfile:
                 log = logfile.read()
-            self.wasb_write(log, remote_loc, append=True)
+            wasb_write = self.wasb_write(log, remote_loc, append=True)
 
-            if self.delete_local_copy:
+            if wasb_write and self.delete_local_copy:
                 shutil.rmtree(os.path.dirname(local_loc))
         # Mark closed so we don't double write if close is called twice
         self.closed = True
@@ -209,7 +223,7 @@ def wasb_read(self, remote_log_location: str, return_error: bool = False):
             return msg
         return ""
 
-    def wasb_write(self, log: str, remote_log_location: str, append: bool = True) -> None:
+    def wasb_write(self, log: str, remote_log_location: str, append: bool = True) -> bool:
         """
         Writes the log to the remote_log_location. Fails silently if no hook
         was created.
@@ -227,3 +241,5 @@ def wasb_write(self, log: str, remote_log_location: str, append: bool = True) ->
             self.hook.load_string(log, self.wasb_container, remote_log_location, overwrite=True)
         except Exception:
             self.log.exception("Could not write logs to %s", remote_log_location)
+            return False
+        return True
diff --git a/docs/apache-airflow/administration-and-deployment/logging-monitoring/logging-tasks.rst b/docs/apache-airflow/administration-and-deployment/logging-monitoring/logging-tasks.rst
index c7d5819b31b83..f3da6dbabede7 100644
--- a/docs/apache-airflow/administration-and-deployment/logging-monitoring/logging-tasks.rst
+++ b/docs/apache-airflow/administration-and-deployment/logging-monitoring/logging-tasks.rst
@@ -25,6 +25,16 @@ Core Airflow provides an interface FileTaskHandler, which writes task logs to fi
 services (:doc:`apache-airflow-providers:index`) and some of them provide handlers that extend the logging
 capability of Apache Airflow. You can see all of these providers in :doc:`apache-airflow-providers:core-extensions/logging`.
 
+When using the S3, GCS, WASB or OSS remote logging service, you can delete the local log files after
+they are uploaded to the remote location by setting the config:
+
+.. code-block:: ini
+
+    [logging]
+    remote_logging = True
+    remote_base_log_folder = schema://path/to/remote/log
+    delete_local_logs = True
+
 Configuring logging
 -------------------
diff --git a/tests/core/test_logging_config.py b/tests/core/test_logging_config.py
index 70636ca67ec23..0d8fbbed214ce 100644
--- a/tests/core/test_logging_config.py
+++ b/tests/core/test_logging_config.py
@@ -316,3 +316,24 @@ def test_log_group_arns_remote_logging_with_cloudwatch_handler(
             airflow_local_settings.DEFAULT_LOGGING_CONFIG["handlers"]["task"]["log_group_arn"]
             == log_group_arn
         )
+
+    def test_loading_remote_logging_with_kwargs(self):
+        """Test if logging can be configured successfully with kwargs"""
+        from airflow.config_templates import airflow_local_settings
+        from airflow.logging_config import configure_logging
+        from airflow.utils.log.s3_task_handler import S3TaskHandler
+
+        with conf_vars(
+            {
+                ("logging", "remote_logging"): "True",
+                ("logging", "remote_log_conn_id"): "some_s3",
+                ("logging", "remote_base_log_folder"): "s3://some-folder",
+                ("logging", "remote_task_handler_kwargs"): '{"delete_local_copy": true}',
+            }
+        ):
+            importlib.reload(airflow_local_settings)
+            configure_logging()
+
+        logger = logging.getLogger("airflow.task")
+        assert isinstance(logger.handlers[0], S3TaskHandler)
+        assert getattr(logger.handlers[0], "delete_local_copy") is True
diff --git a/tests/providers/alibaba/cloud/log/test_oss_task_handler.py b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
index 2cf999849143f..0d0348d8aa07a 100644
--- a/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
+++ b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
@@ -17,10 +17,17 @@
 # under the License.
 
 from __future__ import annotations
 
+import os
 from unittest import mock
 from unittest.mock import PropertyMock
 
+import pytest
+
 from airflow.providers.alibaba.cloud.log.oss_task_handler import OSSTaskHandler
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.timezone import datetime
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_dags, clear_db_runs
 
 OSS_TASK_HANDLER_STRING = "airflow.providers.alibaba.cloud.log.oss_task_handler.{}"
 MOCK_OSS_CONN_ID = "mock_id"
@@ -37,6 +44,20 @@ def setup_method(self):
         self.oss_log_folder = f"oss://{MOCK_BUCKET_NAME}/airflow/logs"
         self.oss_task_handler = OSSTaskHandler(self.base_log_folder, self.oss_log_folder)
 
+    @pytest.fixture(autouse=True)
+    def task_instance(self, create_task_instance):
+        self.ti = ti = create_task_instance(
+            dag_id="dag_for_testing_oss_task_handler",
+            task_id="task_for_testing_oss_task_handler",
+            execution_date=datetime(2020, 1, 1),
+            state=TaskInstanceState.RUNNING,
+        )
+        ti.try_number = 1
+        ti.raw = False
+        yield
+        clear_db_runs()
+        clear_db_dags()
+
     @mock.patch(OSS_TASK_HANDLER_STRING.format("conf.get"))
     @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSHook"))
     def test_hook(self, mock_service, mock_conf_get):
@@ -130,3 +151,29 @@ def test_oss_write_into_remote_non_existing_file_not_via_append(self, mock_servi
         mock_service.return_value.append_string.assert_called_once_with(
             MOCK_BUCKET_NAME, MOCK_CONTENT, "airflow/logs/1.log", 0
         )
+
+    @pytest.mark.parametrize(
+        "delete_local_copy, expected_existence_of_local_copy, airflow_version",
+        [(True, False, "2.6.0"), (False, True, "2.6.0"), (True, True, "2.5.0"), (False, True, "2.5.0")],
+    )
+    @mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.hook"), new_callable=PropertyMock)
+    def test_close_with_delete_local_copy_conf(
+        self,
+        mock_service,
+        tmp_path_factory,
+        delete_local_copy,
+        expected_existence_of_local_copy,
+        airflow_version,
+    ):
+        local_log_path = str(tmp_path_factory.mktemp("local-oss-log-location"))
+        with conf_vars({("logging", "delete_local_logs"): str(delete_local_copy)}), mock.patch(
+            "airflow.version.version", airflow_version
+        ):
+            handler = OSSTaskHandler(local_log_path, self.oss_log_folder)
+
+            handler.log.info("test")
+            handler.set_context(self.ti)
+            assert handler.upload_on_close
+
+            handler.close()
+            assert os.path.exists(handler.handler.baseFilename) == expected_existence_of_local_copy
diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py b/tests/providers/amazon/aws/log/test_s3_task_handler.py
index aeca09d36d6a7..c5223bd21b3c1 100644
--- a/tests/providers/amazon/aws/log/test_s3_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -236,3 +236,22 @@ def test_close_no_upload(self):
 
         with pytest.raises(ClientError):
             boto3.resource("s3").Object("bucket", self.remote_log_key).get()
+
+    @pytest.mark.parametrize(
+        "delete_local_copy, expected_existence_of_local_copy, airflow_version",
+        [(True, False, "2.6.0"), (False, True, "2.6.0"), (True, True, "2.5.0"), (False, True, "2.5.0")],
+    )
+    def test_close_with_delete_local_logs_conf(
+        self, delete_local_copy, expected_existence_of_local_copy, airflow_version
+    ):
+        with conf_vars({("logging", "delete_local_logs"): str(delete_local_copy)}), mock.patch(
+            "airflow.version.version", airflow_version
+        ):
+            handler = S3TaskHandler(self.local_log_location, self.remote_log_base)
+
+            handler.log.info("test")
+            handler.set_context(self.ti)
+            assert handler.upload_on_close
+
+            handler.close()
+            assert os.path.exists(handler.handler.baseFilename) == expected_existence_of_local_copy
diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler.py b/tests/providers/google/cloud/log/test_gcs_task_handler.py
index 690ae9dc335e4..ead895a3340db 100644
--- a/tests/providers/google/cloud/log/test_gcs_task_handler.py
+++ b/tests/providers/google/cloud/log/test_gcs_task_handler.py
@@ -18,7 +18,7 @@
 
 import copy
 import logging
-import tempfile
+import os
 from unittest import mock
 from unittest.mock import MagicMock
@@ -48,10 +48,8 @@ def task_instance(self, create_task_instance):
         clear_db_dags()
 
     @pytest.fixture(autouse=True)
-    def local_log_location(self):
-        with tempfile.TemporaryDirectory() as td:
-            self.local_log_location = td
-            yield td
+    def local_log_location(self, tmp_path_factory):
+        return str(tmp_path_factory.mktemp("local-gcs-log-location"))
 
     @pytest.fixture(autouse=True)
     def gcs_task_handler(self, create_log_template, local_log_location):
@@ -127,7 +125,7 @@ def test_should_read_from_local_on_logs_read_error(self, mock_blob, mock_client,
             "*** * gs://bucket/remote/log/location/1.log\n"
             "*** Unable to read remote log Failed to connect\n"
             "*** Found local files:\n"
-            f"*** * {self.local_log_location}/1.log\n"
+            f"*** * {self.gcs_task_handler.local_base}/1.log\n"
         )
         assert metadata == {"end_of_log": True, "log_pos": 0}
         mock_blob.from_string.assert_called_once_with(
@@ -246,3 +244,39 @@ def test_write_to_remote_on_close_failed_read_old_logs(self, mock_blob, mock_cli
             ],
             any_order=False,
         )
+
+    @pytest.mark.parametrize(
+        "delete_local_copy, expected_existence_of_local_copy, airflow_version",
+        [(True, False, "2.6.0"), (False, True, "2.6.0"), (True, True, "2.5.0"), (False, True, "2.5.0")],
+    )
+    @mock.patch(
+        "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+        return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+    )
+    @mock.patch("google.cloud.storage.Client")
+    @mock.patch("google.cloud.storage.Blob")
+    def test_close_with_delete_local_copy_conf(
+        self,
+        mock_blob,
+        mock_client,
+        mock_creds,
+        local_log_location,
+        delete_local_copy,
+        expected_existence_of_local_copy,
+        airflow_version,
+    ):
+        mock_blob.from_string.return_value.download_as_bytes.return_value = b"CONTENT"
+        with conf_vars({("logging", "delete_local_logs"): str(delete_local_copy)}), mock.patch(
+            "airflow.version.version", airflow_version
+        ):
+            handler = GCSTaskHandler(
+                base_log_folder=local_log_location,
+                gcs_log_folder="gs://bucket/remote/log/location",
+            )
+
+            handler.log.info("test")
+            handler.set_context(self.ti)
+            assert handler.upload_on_close
+
+            handler.close()
+            assert os.path.exists(handler.handler.baseFilename) == expected_existence_of_local_copy
diff --git a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
index 60b6947f619ae..73001ae21bfc2 100644
--- a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
+++ b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import copy
+import os
 import tempfile
 from pathlib import Path
 from unittest import mock
@@ -65,9 +66,6 @@ def setup_method(self):
             delete_local_copy=True,
         )
 
-    def teardown_method(self):
-        self.wasb_task_handler.close()
-
     @conf_vars({("logging", "remote_log_conn_id"): "wasb_default"})
     @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
     def test_hook(self, mock_service):
@@ -175,3 +173,33 @@ def test_write_raises(self):
         mock_error.assert_called_once_with(
             "Could not write logs to %s", "remote/log/location/1.log", exc_info=True
         )
+
+    @pytest.mark.parametrize(
+        "delete_local_copy, expected_existence_of_local_copy, airflow_version",
+        [(True, False, "2.6.0"), (False, True, "2.6.0"), (True, True, "2.5.0"), (False, True, "2.5.0")],
+    )
+    @mock.patch("airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler.wasb_write")
+    def test_close_with_delete_local_logs_conf(
+        self,
+        wasb_write_mock,
+        ti,
+        tmp_path_factory,
+        delete_local_copy,
+        expected_existence_of_local_copy,
+        airflow_version,
+    ):
+        with conf_vars({("logging", "delete_local_logs"): str(delete_local_copy)}), mock.patch(
+            "airflow.version.version", airflow_version
+        ):
+            handler = WasbTaskHandler(
+                base_log_folder=str(tmp_path_factory.mktemp("local-wasb-log-location")),
+                wasb_log_folder=self.wasb_log_folder,
+                wasb_container=self.container_name,
+            )
+            wasb_write_mock.return_value = True
+            handler.log.info("test")
+            handler.set_context(ti)
+            assert handler.upload_on_close
+
+            handler.close()
+            assert os.path.exists(handler.handler.baseFilename) == expected_existence_of_local_copy
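
For anyone trying the change out, the following is a minimal configuration sketch of how the two new options are expected to combine in a user's ``airflow.cfg``. It only uses keys introduced in this diff plus existing logging options; the bucket URL is a placeholder, and the ``remote_task_handler_kwargs`` value must be valid JSON because it is parsed with ``conf.getjson``. As described in the config descriptions above, the kwargs are passed to the remote task handler's ``__init__`` and take precedence over ``delete_local_logs``:

.. code-block:: ini

    [logging]
    remote_logging = True
    remote_base_log_folder = s3://my-bucket/path/to/logs
    # Keep local copies by default for this deployment...
    delete_local_logs = False
    # ...but explicitly override the handler: local files are removed after a successful upload.
    remote_task_handler_kwargs = {"delete_local_copy": true}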