Fix bug in prune_dict where empty dict and list would be removed even in strict mode #32573

Merged: 9 commits, Jul 19, 2023
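For context: prune_dict (in airflow.utils.helpers) prunes entries from nested dicts and lists, dropping only None values in its default "strict" mode and all falsy values in "truthy" mode. The bug fixed here caused empty dicts and lists to be dropped even in strict mode. A minimal sketch of the intended post-fix behavior (illustrative values, not taken from this PR's tests):

from airflow.utils.helpers import prune_dict

payload = {
    "name": "job",
    "timeout": None,                 # None is pruned in both modes
    "tags": [],                      # empty containers must survive strict mode
    "config": {"retries": None, "env": {}},
}

# Strict mode (the default) removes only the None entries.
assert prune_dict(payload) == {"name": "job", "tags": [], "config": {"env": {}}}

# Truthy mode still drops every falsy value, including empty containers.
assert prune_dict(payload, mode="truthy") == {"name": "job"}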
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/hooks/lambda_function.py
@@ -21,7 +21,7 @@
 from typing import Any

 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.utils.helpers import prune_dict


 class LambdaHook(AwsBaseHook):
@@ -76,7 +76,7 @@ def invoke_lambda(
             "Payload": payload,
             "Qualifier": qualifier,
         }
-        return self.conn.invoke(**trim_none_values(invoke_args))
+        return self.conn.invoke(**prune_dict(invoke_args))

     def create_lambda(
         self,
@@ -178,4 +178,4 @@ def create_lambda(
             "CodeSigningConfigArn": code_signing_config_arn,
             "Architectures": architectures,
         }
-        return self.conn.create_function(**trim_none_values(create_function_args))
+        return self.conn.create_function(**prune_dict(create_function_args))
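The call sites touched in this provider all pass flat dicts of boto3 keyword arguments, where only None filtering is needed, so strict-mode prune_dict is a drop-in replacement for the old helper. A small illustration with made-up values (the function name and payload are hypothetical):

from airflow.utils.helpers import prune_dict

invoke_args = {
    "FunctionName": "my-function",   # hypothetical value for illustration
    "InvocationType": None,
    "LogType": None,
    "ClientContext": None,
    "Payload": b'{"key": "value"}',
    "Qualifier": None,
}

# For a flat dict like this, strict-mode prune_dict drops exactly the None
# entries, matching what trim_none_values did before its deprecation.
assert prune_dict(invoke_args) == {
    "FunctionName": "my-function",
    "Payload": b'{"key": "value"}',
}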
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -22,7 +22,7 @@
 from typing import TYPE_CHECKING, Any, Iterable

 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
-from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.utils.helpers import prune_dict

 if TYPE_CHECKING:
     from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient  # noqa
@@ -88,10 +88,10 @@ def execute_query(
         }
         if isinstance(sql, list):
             kwargs["Sqls"] = sql
-            resp = self.conn.batch_execute_statement(**trim_none_values(kwargs))
+            resp = self.conn.batch_execute_statement(**prune_dict(kwargs))
         else:
             kwargs["Sql"] = sql
-            resp = self.conn.execute_statement(**trim_none_values(kwargs))
+            resp = self.conn.execute_statement(**prune_dict(kwargs))

         statement_id = resp["Id"]

6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/operators/batch.py
@@ -43,8 +43,8 @@
     BatchCreateComputeEnvironmentTrigger,
     BatchJobTrigger,
 )
-from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
+from airflow.utils.helpers import prune_dict

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -279,7 +279,7 @@ def submit_job(self, context: Context):
         }

         try:
-            response = self.hook.client.submit_job(**trim_none_values(args))
+            response = self.hook.client.submit_job(**prune_dict(args))
         except Exception as e:
             self.log.error(
                 "AWS Batch job failed submission - job definition: %s - on queue %s",
@@ -484,7 +484,7 @@ def execute(self, context: Context):
             "serviceRole": self.service_role,
             "tags": self.tags,
         }
-        response = self.hook.client.create_compute_environment(**trim_none_values(kwargs))
+        response = self.hook.client.create_compute_environment(**prune_dict(kwargs))
         arn = response["computeEnvironmentArn"]

         if self.deferrable:
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/sagemaker.py
@@ -31,9 +31,9 @@
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
-from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
+from airflow.utils.helpers import prune_dict
 from airflow.utils.json import AirflowJsonEncoder

 if TYPE_CHECKING:
@@ -1325,7 +1325,7 @@ def execute(self, context: Context) -> str:
             "Description": self.description,
             "Tags": format_tags(self.tags),
         }
-        ans = sagemaker_hook.conn.create_experiment(**trim_none_values(params))
+        ans = sagemaker_hook.conn.create_experiment(**prune_dict(params))
         arn = ans["ExperimentArn"]
         self.log.info("Experiment %s created successfully with ARN %s.", self.name, arn)
         return arn
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/secrets/secrets_manager.py
@@ -26,8 +26,8 @@
 from urllib.parse import unquote

 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.secrets import BaseSecretsBackend
+from airflow.utils.helpers import prune_dict
 from airflow.utils.log.logging_mixin import LoggingMixin


@@ -184,7 +184,7 @@ def client(self):

         conn_id = f"{self.__class__.__name__}__connection"
         conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs)
-        client_kwargs = trim_none_values(
+        client_kwargs = prune_dict(
             {
                 "region_name": conn_config.region_name,
                 "verify": conn_config.verify,
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/secrets/systems_manager.py
@@ -21,8 +21,8 @@
 import re
 from functools import cached_property

-from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.secrets import BaseSecretsBackend
+from airflow.utils.helpers import prune_dict
 from airflow.utils.log.logging_mixin import LoggingMixin


@@ -118,7 +118,7 @@ def client(self):

         conn_id = f"{self.__class__.__name__}__connection"
         conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs)
-        client_kwargs = trim_none_values(
+        client_kwargs = prune_dict(
             {
                 "region_name": conn_config.region_name,
                 "verify": conn_config.verify,
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/sensors/lambda_function.py
@@ -21,7 +21,7 @@
 from typing import TYPE_CHECKING, Any, Sequence

 from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
-from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.utils.helpers import prune_dict

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -71,7 +71,7 @@ def poke(self, context: Context) -> bool:
             "FunctionName": self.function_name,
             "Qualifier": self.qualifier,
         }
-        state = self.hook.conn.get_function(**trim_none_values(get_function_args))["Configuration"]["State"]
+        state = self.hook.conn.get_function(**prune_dict(get_function_args))["Configuration"]["State"]

         if state in self.FAILURE_STATES:
             raise AirflowException(
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/utils/__init__.py
@@ -21,13 +21,17 @@
 from datetime import datetime
 from enum import Enum

+from deprecated import deprecated
+
+from airflow.utils.helpers import prune_dict
 from airflow.version import version

 log = logging.getLogger(__name__)


+@deprecated(reason="use prune_dict() instead")
 def trim_none_values(obj: dict):
-    return {key: val for key, val in obj.items() if val is not None}
+    return prune_dict(obj)


 def datetime_to_epoch(date_time: datetime) -> int:
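trim_none_values is kept as a thin deprecated alias: it now delegates to prune_dict and, via the deprecated decorator, warns callers at call time. A hedged sketch of how existing callers would observe the change (assuming the Deprecated package's default warning category):

import warnings

from airflow.providers.amazon.aws.utils import trim_none_values

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = trim_none_values({"a": 1, "b": None})

# The old helper still returns the same pruned dict...
assert result == {"a": 1}
# ...but now emits a DeprecationWarning pointing at prune_dict().
assert any(issubclass(w.category, DeprecationWarning) for w in caught)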
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/utils/connection_wrapper.py
@@ -27,7 +27,7 @@
 from botocore.config import Config

 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
-from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.utils.helpers import prune_dict
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.log.secrets_masker import mask_secret
 from airflow.utils.types import NOTSET, ArgNotSet
@@ -293,7 +293,7 @@ def extra_dejson(self):
     @property
     def session_kwargs(self) -> dict[str, Any]:
         """Additional kwargs passed to boto3.session.Session."""
-        return trim_none_values(
+        return prune_dict(
             {
                 "aws_access_key_id": self.aws_access_key_id,
                 "aws_secret_access_key": self.aws_secret_access_key,
12 changes: 0 additions & 12 deletions tests/providers/amazon/aws/utils/test_utils.py
@@ -26,7 +26,6 @@
     datetime_to_epoch_ms,
     datetime_to_epoch_us,
     get_airflow_version,
-    trim_none_values,
 )

 DT = datetime(2000, 1, 1, tzinfo=pytz.UTC)
@@ -37,17 +36,6 @@ class EnumTest(_StringCompareEnum):
     FOO = "FOO"


-def test_trim_none_values():
-    input_object = {
-        "test": "test",
-        "empty": None,
-    }
-    expected_output_object = {
-        "test": "test",
-    }
-    assert trim_none_values(input_object) == expected_output_object
-
-
 def test_datetime_to_epoch():
     assert datetime_to_epoch(DT) == EPOCH

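The dedicated test for trim_none_values goes away with the duplicate implementation; equivalent coverage presumably lives with the core prune_dict tests. An illustrative parametrized case (not taken from this PR) for the behavior the title describes:

import pytest

from airflow.utils.helpers import prune_dict


@pytest.mark.parametrize(
    "mode, expected",
    [
        ("strict", {"a": "a", "c": {}, "d": []}),  # only None removed; empties kept
        ("truthy", {"a": "a"}),                    # all falsy values removed
    ],
)
def test_prune_dict_modes(mode, expected):
    obj = {"a": "a", "b": None, "c": {"x": None}, "d": []}
    assert prune_dict(obj, mode=mode) == expected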