Skip to content

Commit

Permalink
[misc] Replace XOR ^ conditions by exactly_one helper in providers (
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Dec 3, 2022
1 parent 51c70a5 commit 527b948
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 45 deletions.
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.secrets_masker import mask_secret

Expand Down Expand Up @@ -493,7 +494,7 @@ def conn(self) -> BaseAwsConnection:
:return: boto3.client or boto3.resource
"""
if not ((not self.client_type) ^ (not self.resource_type)):
if not exactly_one(self.client_type, self.resource_type):
raise ValueError(
f"Either client_type={self.client_type!r} or "
f"resource_type={self.resource_type!r} must be provided, not both."
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(
wait_for_completion: bool = False,
**kwargs,
):
if not (job_flow_id is None) ^ (job_flow_name is None):
if not exactly_one(job_flow_id is None, job_flow_name is None):
raise AirflowException("Exactly one of job_flow_id or job_flow_name must be specified.")
super().__init__(**kwargs)
cluster_states = cluster_states or []
Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -463,11 +464,11 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.verify = verify

if not bool(keys is None) ^ bool(prefix is None):
if not exactly_one(prefix is None, keys is None):
raise AirflowException("Either keys or prefix should be set.")

def execute(self, context: Context):
if not bool(self.keys is None) ^ bool(self.prefix is None):
if not exactly_one(self.keys is None, self.prefix is None):
raise AirflowException("Either keys or prefix should be set.")

if isinstance(self.keys, (list, str)) and not bool(self.keys):
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/operators/cloud_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from airflow.providers.google.cloud.triggers.cloud_build import CloudBuildCreateBuildTrigger
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.utils import yaml
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -971,7 +972,7 @@ def __init__(self, build: dict | Build) -> None:
self.build = deepcopy(build)

def _verify_source(self) -> None:
if not (("storage_source" in self.build["source"]) ^ ("repo_source" in self.build["source"])):
if not exactly_one("storage_source" in self.build["source"], "repo_source" in self.build["source"]):
raise AirflowException(
"The source could not be determined. Please choose one data source from: "
"storage_source and repo_source."
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/slack/hooks/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.providers.slack.utils import ConnectionExtraConfig
from airflow.utils.helpers import exactly_one
from airflow.utils.log.secrets_masker import mask_secret

if TYPE_CHECKING:
Expand Down Expand Up @@ -268,7 +269,7 @@ def send_file(
- `Slack API files.upload method <https://api.slack.com/methods/files.upload>`_
- `File types <https://api.slack.com/types/file#file_types>`_
"""
if not ((not file) ^ (not content)):
if not exactly_one(file, content):
raise ValueError("Either `file` or `content` must be provided, not both.")
elif file:
file = Path(file)
Expand Down
21 changes: 18 additions & 3 deletions tests/providers/amazon/aws/operators/test_emr_add_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import json
import os
import unittest
from datetime import timedelta
from unittest.mock import MagicMock, call, patch

Expand All @@ -41,7 +40,7 @@
)


class TestEmrAddStepsOperator(unittest.TestCase):
class TestEmrAddStepsOperator:
# When
_config = [
{
Expand All @@ -54,7 +53,7 @@ class TestEmrAddStepsOperator(unittest.TestCase):
}
]

def setUp(self):
def setup_method(self):
self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}

# Mock out the emr_client (moto has incorrect response)
Expand All @@ -79,6 +78,22 @@ def test_init(self):
assert self.operator.job_flow_id == "j-8989898989"
assert self.operator.aws_conn_id == "aws_default"

@pytest.mark.parametrize(
"job_flow_id, job_flow_name",
[
pytest.param("j-8989898989", "test_cluster", id="both-specified"),
pytest.param(None, None, id="both-none"),
],
)
def test_validate_mutually_exclusive_args(self, job_flow_id, job_flow_name):
error_message = r"Exactly one of job_flow_id or job_flow_name must be specified\."
with pytest.raises(AirflowException, match=error_message):
EmrAddStepsOperator(
task_id="test_validate_mutually_exclusive_args",
job_flow_id=job_flow_id,
job_flow_name=job_flow_name,
)

def test_render_template(self):
dag_run = DagRun(dag_id=self.operator.dag.dag_id, execution_date=DEFAULT_DATE, run_id="test")
ti = TaskInstance(task=self.operator)
Expand Down
74 changes: 39 additions & 35 deletions tests/providers/amazon/aws/operators/test_s3_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from unittest import mock

import boto3
import pytest
from moto import mock_s3

from airflow import AirflowException
Expand Down Expand Up @@ -95,8 +96,8 @@ def test_s3_copy_object_arg_combination_2(self):
assert objects_in_dest_bucket["Contents"][0]["Key"] == self.dest_key


class TestS3DeleteObjectsOperator(unittest.TestCase):
@mock_s3
@mock_s3
class TestS3DeleteObjectsOperator:
def test_s3_delete_single_object(self):
bucket = "testbucket"
key = "path/data.txt"
Expand All @@ -116,7 +117,6 @@ def test_s3_delete_single_object(self):
# There should be no object found in the bucket created earlier
assert "Contents" not in conn.list_objects(Bucket=bucket, Prefix=key)

@mock_s3
def test_s3_delete_multiple_objects(self):
bucket = "testbucket"
key_pattern = "path/data"
Expand All @@ -139,7 +139,6 @@ def test_s3_delete_multiple_objects(self):
# There should be no object found in the bucket created earlier
assert "Contents" not in conn.list_objects(Bucket=bucket, Prefix=key_pattern)

@mock_s3
def test_s3_delete_prefix(self):
bucket = "testbucket"
key_pattern = "path/data"
Expand All @@ -162,7 +161,6 @@ def test_s3_delete_prefix(self):
# There should be no object found in the bucket created earlier
assert "Contents" not in conn.list_objects(Bucket=bucket, Prefix=key_pattern)

@mock_s3
def test_s3_delete_empty_list(self):
bucket = "testbucket"
key_of_test = "path/data.txt"
Expand All @@ -185,7 +183,6 @@ def test_s3_delete_empty_list(self):
# the object found should be consistent with dest_key specified earlier
assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test

@mock_s3
def test_s3_delete_empty_string(self):
bucket = "testbucket"
key_of_test = "path/data.txt"
Expand All @@ -208,50 +205,57 @@ def test_s3_delete_empty_string(self):
# the object found should be consistent with dest_key specified earlier
assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test

@mock_s3
def test_assert_s3_both_keys_and_prifix_given(self):
bucket = "testbucket"
keys = "path/data.txt"
key_pattern = "path/data"

conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=keys, Fileobj=io.BytesIO(b"input"))

# The object should be detected before the DELETE action is tested
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=keys)
assert len(objects_in_dest_bucket["Contents"]) == 1
assert objects_in_dest_bucket["Contents"][0]["Key"] == keys
with self.assertRaises(AirflowException):
op = S3DeleteObjectsOperator(
task_id="test_assert_s3_both_keys_and_prifix_given",
bucket=bucket,
@pytest.mark.parametrize(
"keys, prefix",
[
pytest.param("path/data.txt", "path/data", id="single-key-and-prefix"),
pytest.param(["path/data.txt"], "path/data", id="multiple-keys-and-prefix"),
pytest.param(None, None, id="both-none"),
],
)
def test_validate_keys_and_prefix_in_constructor(self, keys, prefix):
with pytest.raises(AirflowException, match=r"Either keys or prefix should be set\."):
S3DeleteObjectsOperator(
task_id="test_validate_keys_and_prefix_in_constructor",
bucket="foo-bar-bucket",
keys=keys,
prefix=key_pattern,
prefix=prefix,
)
op.execute(None)

# The object found in the bucket created earlier should still be there
assert len(objects_in_dest_bucket["Contents"]) == 1
# the object found should be consistent with dest_key specified earlier
assert objects_in_dest_bucket["Contents"][0]["Key"] == keys

@mock_s3
def test_assert_s3_no_keys_or_prifix_given(self):
@pytest.mark.parametrize(
"keys, prefix",
[
pytest.param("path/data.txt", "path/data", id="single-key-and-prefix"),
pytest.param(["path/data.txt"], "path/data", id="multiple-keys-and-prefix"),
pytest.param(None, None, id="both-none"),
],
)
def test_validate_keys_and_prefix_in_execute(self, keys, prefix):
bucket = "testbucket"
key_of_test = "path/data.txt"

conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=key_of_test, Fileobj=io.BytesIO(b"input"))

# Set valid values for constructor, and change them later for emulate rendering template
op = S3DeleteObjectsOperator(
task_id="test_validate_keys_and_prefix_in_execute",
bucket=bucket,
keys="keys-exists",
prefix=None,
)
op.keys = keys
op.prefix = prefix

# The object should be detected before the DELETE action is tested
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_of_test)
assert len(objects_in_dest_bucket["Contents"]) == 1
assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test
with self.assertRaises(AirflowException):
op = S3DeleteObjectsOperator(task_id="test_assert_s3_no_keys_or_prifix_given", bucket=bucket)

with pytest.raises(AirflowException, match=r"Either keys or prefix should be set\."):
op.execute(None)

# The object found in the bucket created earlier should still be there
assert len(objects_in_dest_bucket["Contents"]) == 1
# the object found should be consistent with dest_key specified earlier
Expand Down
6 changes: 5 additions & 1 deletion tests/providers/google/cloud/operators/test_cloud_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,13 @@ def test_update_build_trigger(self, mock_hook):

class TestBuildProcessor(TestCase):
def test_verify_source(self):
with pytest.raises(AirflowException, match="The source could not be determined."):
error_message = r"The source could not be determined."
with pytest.raises(AirflowException, match=error_message):
BuildProcessor(build={"source": {"storage_source": {}, "repo_source": {}}}).process_body()

with pytest.raises(AirflowException, match=error_message):
BuildProcessor(build={"source": {}}).process_body()

@parameterized.expand(
[
(
Expand Down

0 comments on commit 527b948

Please sign in to comment.