From 79026c4332b169a7ede5c1fc3db698da0d10426d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Wed, 12 Jul 2023 16:08:58 -0700 Subject: [PATCH 1/7] remove provider specific helper in favor of equivalent one in core --- .../providers/amazon/aws/hooks/lambda_function.py | 6 +++--- airflow/providers/amazon/aws/hooks/redshift_data.py | 6 +++--- airflow/providers/amazon/aws/operators/batch.py | 6 +++--- airflow/providers/amazon/aws/operators/sagemaker.py | 4 ++-- .../providers/amazon/aws/secrets/secrets_manager.py | 4 ++-- .../providers/amazon/aws/secrets/systems_manager.py | 4 ++-- .../providers/amazon/aws/sensors/lambda_function.py | 4 ++-- airflow/providers/amazon/aws/utils/__init__.py | 6 +++++- .../providers/amazon/aws/utils/connection_wrapper.py | 4 ++-- tests/providers/amazon/aws/utils/test_utils.py | 12 ------------ 10 files changed, 24 insertions(+), 32 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py index 58ecac8bcccb..bfc709871a52 100644 --- a/airflow/providers/amazon/aws/hooks/lambda_function.py +++ b/airflow/providers/amazon/aws/hooks/lambda_function.py @@ -21,7 +21,7 @@ from typing import Any from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.amazon.aws.utils import prune_dict class LambdaHook(AwsBaseHook): @@ -76,7 +76,7 @@ def invoke_lambda( "Payload": payload, "Qualifier": qualifier, } - return self.conn.invoke(**trim_none_values(invoke_args)) + return self.conn.invoke(**prune_dict(invoke_args)) def create_lambda( self, @@ -178,4 +178,4 @@ def create_lambda( "CodeSigningConfigArn": code_signing_config_arn, "Architectures": architectures, } - return self.conn.create_function(**trim_none_values(create_function_args)) + return self.conn.create_function(**prune_dict(create_function_args)) diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index fddd42bd61df..46295484838a 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any, Iterable from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook -from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.amazon.aws.utils import prune_dict if TYPE_CHECKING: from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa @@ -88,10 +88,10 @@ def execute_query( } if isinstance(sql, list): kwargs["Sqls"] = sql - resp = self.conn.batch_execute_statement(**trim_none_values(kwargs)) + resp = self.conn.batch_execute_statement(**prune_dict(kwargs)) else: kwargs["Sql"] = sql - resp = self.conn.execute_statement(**trim_none_values(kwargs)) + resp = self.conn.execute_statement(**prune_dict(kwargs)) statement_id = resp["Id"] diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index e6221ae3e0af..6edb7818a7b7 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -43,7 +43,7 @@ BatchCreateComputeEnvironmentTrigger, BatchJobTrigger, ) -from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.amazon.aws.utils import prune_dict from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher if TYPE_CHECKING: @@ -279,7 +279,7 @@ def submit_job(self, context: Context): } try: - response = self.hook.client.submit_job(**trim_none_values(args)) + response = self.hook.client.submit_job(**prune_dict(args)) except Exception as e: self.log.error( "AWS Batch job failed submission - job definition: %s - on queue %s", @@ -484,7 +484,7 @@ def execute(self, context: Context): "serviceRole": self.service_role, "tags": self.tags, } - response = self.hook.client.create_compute_environment(**trim_none_values(kwargs)) + response = self.hook.client.create_compute_environment(**prune_dict(kwargs)) arn = response["computeEnvironmentArn"] if self.deferrable: diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index ac1b7a73d2de..934cd51a6e3c 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -31,7 +31,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger -from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.amazon.aws.utils import prune_dict from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus from airflow.providers.amazon.aws.utils.tags import format_tags from airflow.utils.json import AirflowJsonEncoder @@ -1325,7 +1325,7 @@ def execute(self, context: Context) -> str: "Description": self.description, "Tags": format_tags(self.tags), } - ans = sagemaker_hook.conn.create_experiment(**trim_none_values(params)) + ans = sagemaker_hook.conn.create_experiment(**prune_dict(params)) arn = ans["ExperimentArn"] self.log.info("Experiment %s created successfully with ARN %s.", self.name, arn) return arn diff --git a/airflow/providers/amazon/aws/secrets/secrets_manager.py b/airflow/providers/amazon/aws/secrets/secrets_manager.py index a5acf19a372e..08da4956e9a2 100644 --- a/airflow/providers/amazon/aws/secrets/secrets_manager.py +++ b/airflow/providers/amazon/aws/secrets/secrets_manager.py @@ -26,7 +26,7 @@ from urllib.parse import unquote from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.amazon.aws.utils import prune_dict from airflow.secrets import BaseSecretsBackend from airflow.utils.log.logging_mixin import LoggingMixin @@ -184,7 +184,7 @@ def client(self): conn_id = f"{self.__class__.__name__}__connection" conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs) - client_kwargs = trim_none_values( + client_kwargs = prune_dict( { "region_name": conn_config.region_name, "verify": conn_config.verify, diff --git a/airflow/providers/amazon/aws/secrets/systems_manager.py b/airflow/providers/amazon/aws/secrets/systems_manager.py index 8b1daca1f7dc..d9771dbdb842 100644 --- a/airflow/providers/amazon/aws/secrets/systems_manager.py +++ b/airflow/providers/amazon/aws/secrets/systems_manager.py @@ -21,7 +21,7 @@ import re from functools import cached_property -from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.amazon.aws.utils import prune_dict from airflow.secrets import BaseSecretsBackend from airflow.utils.log.logging_mixin import LoggingMixin @@ -118,7 +118,7 @@ def client(self): conn_id = f"{self.__class__.__name__}__connection" conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs) - client_kwargs = trim_none_values( + client_kwargs = prune_dict( { "region_name": conn_config.region_name, "verify": conn_config.verify, diff --git a/airflow/providers/amazon/aws/sensors/lambda_function.py b/airflow/providers/amazon/aws/sensors/lambda_function.py index 772ed0689a3e..7a9f72b9e993 100644 --- a/airflow/providers/amazon/aws/sensors/lambda_function.py +++ b/airflow/providers/amazon/aws/sensors/lambda_function.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook -from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.amazon.aws.utils import prune_dict if TYPE_CHECKING: from airflow.utils.context import Context @@ -71,7 +71,7 @@ def poke(self, context: Context) -> bool: "FunctionName": self.function_name, "Qualifier": self.qualifier, } - state = self.hook.conn.get_function(**trim_none_values(get_function_args))["Configuration"]["State"] + state = self.hook.conn.get_function(**prune_dict(get_function_args))["Configuration"]["State"] if state in self.FAILURE_STATES: raise AirflowException( diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py index 8418e204815b..306135c88012 100644 --- a/airflow/providers/amazon/aws/utils/__init__.py +++ b/airflow/providers/amazon/aws/utils/__init__.py @@ -21,13 +21,17 @@ from datetime import datetime from enum import Enum +from deprecated import deprecated + +from airflow.utils.helpers import prune_dict from airflow.version import version log = logging.getLogger(__name__) +@deprecated(reason="use prune_dict() instead") def trim_none_values(obj: dict): - return {key: val for key, val in obj.items() if val is not None} + return prune_dict(obj) def datetime_to_epoch(date_time: datetime) -> int: diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py index 1520c3e1fc06..d10de361b7fc 100644 --- a/airflow/providers/amazon/aws/utils/connection_wrapper.py +++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py @@ -27,7 +27,7 @@ from botocore.config import Config from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.amazon.aws.utils import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.log.secrets_masker import mask_secret from airflow.utils.types import NOTSET, ArgNotSet @@ -293,7 +293,7 @@ def extra_dejson(self): @property def session_kwargs(self) -> dict[str, Any]: """Additional kwargs passed to boto3.session.Session.""" - return trim_none_values( + return prune_dict( { "aws_access_key_id": self.aws_access_key_id, "aws_secret_access_key": self.aws_secret_access_key, diff --git a/tests/providers/amazon/aws/utils/test_utils.py b/tests/providers/amazon/aws/utils/test_utils.py index bf404aa4fc51..66d5f734dc7d 100644 --- a/tests/providers/amazon/aws/utils/test_utils.py +++ b/tests/providers/amazon/aws/utils/test_utils.py @@ -26,7 +26,6 @@ datetime_to_epoch_ms, datetime_to_epoch_us, get_airflow_version, - trim_none_values, ) DT = datetime(2000, 1, 1, tzinfo=pytz.UTC) @@ -37,17 +36,6 @@ class EnumTest(_StringCompareEnum): FOO = "FOO" -def test_trim_none_values(): - input_object = { - "test": "test", - "empty": None, - } - expected_output_object = { - "test": "test", - } - assert trim_none_values(input_object) == expected_output_object - - def test_datetime_to_epoch(): assert datetime_to_epoch(DT) == EPOCH From 2f55714cc055161dd6616b76c937165440965def Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Wed, 12 Jul 2023 16:12:56 -0700 Subject: [PATCH 2/7] fix imports --- airflow/providers/amazon/aws/hooks/lambda_function.py | 2 +- airflow/providers/amazon/aws/hooks/redshift_data.py | 2 +- airflow/providers/amazon/aws/operators/batch.py | 2 +- airflow/providers/amazon/aws/operators/sagemaker.py | 2 +- airflow/providers/amazon/aws/secrets/secrets_manager.py | 2 +- airflow/providers/amazon/aws/secrets/systems_manager.py | 2 +- airflow/providers/amazon/aws/sensors/lambda_function.py | 2 +- airflow/providers/amazon/aws/utils/connection_wrapper.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py index bfc709871a52..63ae14e88371 100644 --- a/airflow/providers/amazon/aws/hooks/lambda_function.py +++ b/airflow/providers/amazon/aws/hooks/lambda_function.py @@ -21,7 +21,7 @@ from typing import Any from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.providers.amazon.aws.utils import prune_dict +from airflow.utils.helpers import prune_dict class LambdaHook(AwsBaseHook): diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index 46295484838a..0601e19f0499 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any, Iterable from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook -from airflow.providers.amazon.aws.utils import prune_dict +from airflow.utils.helpers import prune_dict if TYPE_CHECKING: from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 6edb7818a7b7..ca536cbe07f6 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -43,8 +43,8 @@ BatchCreateComputeEnvironmentTrigger, BatchJobTrigger, ) -from airflow.providers.amazon.aws.utils import prune_dict from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher +from airflow.utils.helpers import prune_dict if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 934cd51a6e3c..2fc0f02850b0 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -31,9 +31,9 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger -from airflow.providers.amazon.aws.utils import prune_dict from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus from airflow.providers.amazon.aws.utils.tags import format_tags +from airflow.utils.helpers import prune_dict from airflow.utils.json import AirflowJsonEncoder if TYPE_CHECKING: diff --git a/airflow/providers/amazon/aws/secrets/secrets_manager.py b/airflow/providers/amazon/aws/secrets/secrets_manager.py index 08da4956e9a2..f2127259728f 100644 --- a/airflow/providers/amazon/aws/secrets/secrets_manager.py +++ b/airflow/providers/amazon/aws/secrets/secrets_manager.py @@ -26,8 +26,8 @@ from urllib.parse import unquote from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.providers.amazon.aws.utils import prune_dict from airflow.secrets import BaseSecretsBackend +from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/providers/amazon/aws/secrets/systems_manager.py b/airflow/providers/amazon/aws/secrets/systems_manager.py index d9771dbdb842..09e6ff90313b 100644 --- a/airflow/providers/amazon/aws/secrets/systems_manager.py +++ b/airflow/providers/amazon/aws/secrets/systems_manager.py @@ -21,8 +21,8 @@ import re from functools import cached_property -from airflow.providers.amazon.aws.utils import prune_dict from airflow.secrets import BaseSecretsBackend +from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/providers/amazon/aws/sensors/lambda_function.py b/airflow/providers/amazon/aws/sensors/lambda_function.py index 7a9f72b9e993..eec0f56ac724 100644 --- a/airflow/providers/amazon/aws/sensors/lambda_function.py +++ b/airflow/providers/amazon/aws/sensors/lambda_function.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook -from airflow.providers.amazon.aws.utils import prune_dict +from airflow.utils.helpers import prune_dict if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py index d10de361b7fc..b554f34fdb1f 100644 --- a/airflow/providers/amazon/aws/utils/connection_wrapper.py +++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py @@ -27,7 +27,7 @@ from botocore.config import Config from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.providers.amazon.aws.utils import prune_dict +from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.log.secrets_masker import mask_secret from airflow.utils.types import NOTSET, ArgNotSet From 87bf4873691e6d12969521b191333d94d9ec929a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 14 Jul 2023 12:57:15 -0700 Subject: [PATCH 3/7] fix existing helper --- airflow/utils/helpers.py | 4 ++-- tests/utils/test_helpers.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index a29cf07a5919..e26237fc96d5 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -349,7 +349,7 @@ def is_empty(x): continue elif isinstance(v, (list, dict)): new_val = prune_dict(v, mode=mode) - if new_val: + if not is_empty(new_val): new_dict[k] = new_val else: new_dict[k] = v @@ -361,7 +361,7 @@ def is_empty(x): continue elif isinstance(v, (list, dict)): new_val = prune_dict(v, mode=mode) - if new_val: + if not is_empty(new_val): new_list.append(new_val) else: new_list.append(v) diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index f7ca52190910..3b72768fc75f 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -309,6 +309,8 @@ def assert_at_most_one(true=0, truthy=0, false=0, falsy=0, notset=0): "c": {"b": "", "c": "hi", "d": ["", 0, "1"]}, "d": ["", 0, "1"], "e": ["", 0, {"b": "", "c": "hi", "d": ["", 0, "1"]}, ["", 0, "1"], [""]], + "f": {}, + "g": [""], }, ), ( @@ -324,7 +326,7 @@ def assert_at_most_one(true=0, truthy=0, false=0, falsy=0, notset=0): def test_prune_dict(self, mode, expected): l1 = ["", 0, "1", None] d1 = {"a": None, "b": "", "c": "hi", "d": l1} - d2 = {"a": None, "b": "", "c": d1, "d": l1, "e": [None, "", 0, d1, l1, [""]]} + d2 = {"a": None, "b": "", "c": d1, "d": l1, "e": [None, "", 0, d1, l1, [""]], "f": {}, "g": [""]} assert prune_dict(d2, mode=mode) == expected From d02a804eb57622b27e24d087e90a0b75f1e5694f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 14 Jul 2023 14:29:58 -0700 Subject: [PATCH 4/7] replace method only if core version > 2.7) --- airflow/providers/amazon/aws/utils/__init__.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py index 306135c88012..b9699d7ef68f 100644 --- a/airflow/providers/amazon/aws/utils/__init__.py +++ b/airflow/providers/amazon/aws/utils/__init__.py @@ -18,20 +18,29 @@ import logging import re +import warnings from datetime import datetime from enum import Enum -from deprecated import deprecated - +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.utils.helpers import prune_dict from airflow.version import version log = logging.getLogger(__name__) -@deprecated(reason="use prune_dict() instead") def trim_none_values(obj: dict): - return prune_dict(obj) + from packaging.version import Version + + from airflow.version import version + + if Version(version) < Version("2.7"): + return {key: val for key, val in obj.items() if val is not None} + else: + warnings.warn( + "use airflow.utils.helpers.prune_dict() instead", AirflowProviderDeprecationWarning, stacklevel=2 + ) + return prune_dict(obj) def datetime_to_epoch(date_time: datetime) -> int: From a7022df344f15f3dde5d66844f026e2d20c60492 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 14 Jul 2023 14:36:23 -0700 Subject: [PATCH 5/7] un-replace existing function because it's not currently equivalent --- airflow/providers/amazon/aws/hooks/lambda_function.py | 6 +++--- airflow/providers/amazon/aws/hooks/redshift_data.py | 6 +++--- airflow/providers/amazon/aws/operators/batch.py | 6 +++--- airflow/providers/amazon/aws/operators/sagemaker.py | 4 ++-- airflow/providers/amazon/aws/secrets/secrets_manager.py | 4 ++-- airflow/providers/amazon/aws/secrets/systems_manager.py | 4 ++-- airflow/providers/amazon/aws/sensors/lambda_function.py | 4 ++-- airflow/providers/amazon/aws/utils/connection_wrapper.py | 4 ++-- 8 files changed, 19 insertions(+), 19 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py index 63ae14e88371..58ecac8bcccb 100644 --- a/airflow/providers/amazon/aws/hooks/lambda_function.py +++ b/airflow/providers/amazon/aws/hooks/lambda_function.py @@ -21,7 +21,7 @@ from typing import Any from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.utils.helpers import prune_dict +from airflow.providers.amazon.aws.utils import trim_none_values class LambdaHook(AwsBaseHook): @@ -76,7 +76,7 @@ def invoke_lambda( "Payload": payload, "Qualifier": qualifier, } - return self.conn.invoke(**prune_dict(invoke_args)) + return self.conn.invoke(**trim_none_values(invoke_args)) def create_lambda( self, @@ -178,4 +178,4 @@ def create_lambda( "CodeSigningConfigArn": code_signing_config_arn, "Architectures": architectures, } - return self.conn.create_function(**prune_dict(create_function_args)) + return self.conn.create_function(**trim_none_values(create_function_args)) diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index 0601e19f0499..fddd42bd61df 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any, Iterable from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook -from airflow.utils.helpers import prune_dict +from airflow.providers.amazon.aws.utils import trim_none_values if TYPE_CHECKING: from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa @@ -88,10 +88,10 @@ def execute_query( } if isinstance(sql, list): kwargs["Sqls"] = sql - resp = self.conn.batch_execute_statement(**prune_dict(kwargs)) + resp = self.conn.batch_execute_statement(**trim_none_values(kwargs)) else: kwargs["Sql"] = sql - resp = self.conn.execute_statement(**prune_dict(kwargs)) + resp = self.conn.execute_statement(**trim_none_values(kwargs)) statement_id = resp["Id"] diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index ca536cbe07f6..e6221ae3e0af 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -43,8 +43,8 @@ BatchCreateComputeEnvironmentTrigger, BatchJobTrigger, ) +from airflow.providers.amazon.aws.utils import trim_none_values from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher -from airflow.utils.helpers import prune_dict if TYPE_CHECKING: from airflow.utils.context import Context @@ -279,7 +279,7 @@ def submit_job(self, context: Context): } try: - response = self.hook.client.submit_job(**prune_dict(args)) + response = self.hook.client.submit_job(**trim_none_values(args)) except Exception as e: self.log.error( "AWS Batch job failed submission - job definition: %s - on queue %s", @@ -484,7 +484,7 @@ def execute(self, context: Context): "serviceRole": self.service_role, "tags": self.tags, } - response = self.hook.client.create_compute_environment(**prune_dict(kwargs)) + response = self.hook.client.create_compute_environment(**trim_none_values(kwargs)) arn = response["computeEnvironmentArn"] if self.deferrable: diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 2fc0f02850b0..ac1b7a73d2de 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -31,9 +31,9 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger +from airflow.providers.amazon.aws.utils import trim_none_values from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus from airflow.providers.amazon.aws.utils.tags import format_tags -from airflow.utils.helpers import prune_dict from airflow.utils.json import AirflowJsonEncoder if TYPE_CHECKING: @@ -1325,7 +1325,7 @@ def execute(self, context: Context) -> str: "Description": self.description, "Tags": format_tags(self.tags), } - ans = sagemaker_hook.conn.create_experiment(**prune_dict(params)) + ans = sagemaker_hook.conn.create_experiment(**trim_none_values(params)) arn = ans["ExperimentArn"] self.log.info("Experiment %s created successfully with ARN %s.", self.name, arn) return arn diff --git a/airflow/providers/amazon/aws/secrets/secrets_manager.py b/airflow/providers/amazon/aws/secrets/secrets_manager.py index f2127259728f..a5acf19a372e 100644 --- a/airflow/providers/amazon/aws/secrets/secrets_manager.py +++ b/airflow/providers/amazon/aws/secrets/secrets_manager.py @@ -26,8 +26,8 @@ from urllib.parse import unquote from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.amazon.aws.utils import trim_none_values from airflow.secrets import BaseSecretsBackend -from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin @@ -184,7 +184,7 @@ def client(self): conn_id = f"{self.__class__.__name__}__connection" conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs) - client_kwargs = prune_dict( + client_kwargs = trim_none_values( { "region_name": conn_config.region_name, "verify": conn_config.verify, diff --git a/airflow/providers/amazon/aws/secrets/systems_manager.py b/airflow/providers/amazon/aws/secrets/systems_manager.py index 09e6ff90313b..8b1daca1f7dc 100644 --- a/airflow/providers/amazon/aws/secrets/systems_manager.py +++ b/airflow/providers/amazon/aws/secrets/systems_manager.py @@ -21,8 +21,8 @@ import re from functools import cached_property +from airflow.providers.amazon.aws.utils import trim_none_values from airflow.secrets import BaseSecretsBackend -from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin @@ -118,7 +118,7 @@ def client(self): conn_id = f"{self.__class__.__name__}__connection" conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs) - client_kwargs = prune_dict( + client_kwargs = trim_none_values( { "region_name": conn_config.region_name, "verify": conn_config.verify, diff --git a/airflow/providers/amazon/aws/sensors/lambda_function.py b/airflow/providers/amazon/aws/sensors/lambda_function.py index eec0f56ac724..772ed0689a3e 100644 --- a/airflow/providers/amazon/aws/sensors/lambda_function.py +++ b/airflow/providers/amazon/aws/sensors/lambda_function.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook -from airflow.utils.helpers import prune_dict +from airflow.providers.amazon.aws.utils import trim_none_values if TYPE_CHECKING: from airflow.utils.context import Context @@ -71,7 +71,7 @@ def poke(self, context: Context) -> bool: "FunctionName": self.function_name, "Qualifier": self.qualifier, } - state = self.hook.conn.get_function(**prune_dict(get_function_args))["Configuration"]["State"] + state = self.hook.conn.get_function(**trim_none_values(get_function_args))["Configuration"]["State"] if state in self.FAILURE_STATES: raise AirflowException( diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py index b554f34fdb1f..1520c3e1fc06 100644 --- a/airflow/providers/amazon/aws/utils/connection_wrapper.py +++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py @@ -27,7 +27,7 @@ from botocore.config import Config from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.utils.helpers import prune_dict +from airflow.providers.amazon.aws.utils import trim_none_values from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.log.secrets_masker import mask_secret from airflow.utils.types import NOTSET, ArgNotSet @@ -293,7 +293,7 @@ def extra_dejson(self): @property def session_kwargs(self) -> dict[str, Any]: """Additional kwargs passed to boto3.session.Session.""" - return prune_dict( + return trim_none_values( { "aws_access_key_id": self.aws_access_key_id, "aws_secret_access_key": self.aws_secret_access_key, From 63d4bf643f52cc1b6fcf6d553bcb5c4d8ccd4447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 14 Jul 2023 14:39:24 -0700 Subject: [PATCH 6/7] add comments --- airflow/providers/amazon/aws/utils/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py index b9699d7ef68f..2466c7227131 100644 --- a/airflow/providers/amazon/aws/utils/__init__.py +++ b/airflow/providers/amazon/aws/utils/__init__.py @@ -35,8 +35,12 @@ def trim_none_values(obj: dict): from airflow.version import version if Version(version) < Version("2.7"): + # before version 2.7, the behavior is not the same. + # Empty dict and lists are removed from the given dict. return {key: val for key, val in obj.items() if val is not None} else: + # once airflow 2.6 rolls out of compatibility support for provider packages, + # we can remove this method and replace all usages in aws code. warnings.warn( "use airflow.utils.helpers.prune_dict() instead", AirflowProviderDeprecationWarning, stacklevel=2 ) From 335bb60e700f61e18d3f9d4c07cef5b904a32ec8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Mon, 17 Jul 2023 14:06:05 -0700 Subject: [PATCH 7/7] remove warning for now --- airflow/providers/amazon/aws/utils/__init__.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py index 2466c7227131..312366df26ae 100644 --- a/airflow/providers/amazon/aws/utils/__init__.py +++ b/airflow/providers/amazon/aws/utils/__init__.py @@ -18,11 +18,9 @@ import logging import re -import warnings from datetime import datetime from enum import Enum -from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.utils.helpers import prune_dict from airflow.version import version @@ -40,10 +38,10 @@ def trim_none_values(obj: dict): return {key: val for key, val in obj.items() if val is not None} else: # once airflow 2.6 rolls out of compatibility support for provider packages, - # we can remove this method and replace all usages in aws code. - warnings.warn( - "use airflow.utils.helpers.prune_dict() instead", AirflowProviderDeprecationWarning, stacklevel=2 - ) + # we can replace usages of this method with the core one in our code, + # and uncomment this warning for users who may use it. + # warnings.warn("use airflow.utils.helpers.prune_dict() instead", + # AirflowProviderDeprecationWarning, stacklevel=2) return prune_dict(obj)