[AIRFLOW-867] Enable and fix lots of untested unit tests #2078

Closed
wants to merge 1 commit
airflow/contrib/operators/dataflow_operator.py (2 changes: 1 addition & 1 deletion)

@@ -217,7 +217,7 @@ def google_cloud_to_local(self, file_name):
         # Extracts bucket_id and object_id by first removing 'gs://' prefix and
         # then split the remaining by path delimiter '/'.
         path_components = file_name[self.GCS_PREFIX_LENGTH:].split('/')
-        if path_components < 2:
+        if len(path_components) < 2:
             raise Exception(
                 'Invalid Google Cloud Storage (GCS) object path: {}.'
                 .format(file_name))
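The one-line fix above deserves a note: in Python 2, comparing a list against an int falls back to ordering by type name, so the original `path_components < 2` was always False and the invalid-path error was unreachable; Python 3 raises a TypeError for the same comparison. A minimal sketch of the failure and the corrected check:

```python
# 'gs://bucket' has no object component after the prefix.
path_components = 'gs://bucket'[len('gs://'):].split('/')  # ['bucket']

# Python 2: ['bucket'] < 2 orders by type name ('int' < 'list'), so the
# comparison is always False and the validation below could never fire.
# Python 3: the same expression raises
#   TypeError: '<' not supported between instances of 'list' and 'int'

if len(path_components) < 2:  # the fixed check compares the element count
    raise Exception('Invalid Google Cloud Storage (GCS) object path.')
```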
airflow/contrib/operators/ecs_operator.py (2 changes: 1 addition & 1 deletion)

@@ -88,7 +88,7 @@ def execute(self, context):
 
     def _wait_for_task_ended(self):
         waiter = self.client.get_waiter('tasks_stopped')
-        waiter.config.max_attempts = sys.maxint  # timeout is managed by airflow
+        waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
         waiter.wait(
             cluster=self.cluster,
             tasks=[self.arn]
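`sys.maxint` only exists on Python 2; Python 3 dropped it when ints became unbounded. `sys.maxsize` is available on both versions, which is why it works here as an effectively infinite attempt count:

```python
import sys

# sys.maxsize exists on Python 2 and 3; sys.maxint is Python 2 only.
print(sys.maxsize)             # e.g. 9223372036854775807 on 64-bit builds
print(hasattr(sys, 'maxint'))  # False on Python 3
```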
airflow/models.py (2 changes: 0 additions & 2 deletions)

@@ -1108,7 +1108,6 @@ def are_dependencies_met(
         :param verbose: whether or not to print details on failed dependencies
         :type verbose: boolean
         """
-        dep_context = dep_context or DepContext()
         failed = False

Contributor: Why remove this line?

         for dep_status in self.get_failed_dep_statuses(
                 dep_context=dep_context,

@@ -1131,7 +1130,6 @@ def get_failed_dep_statuses(
             self,
             dep_context=None,
             session=None):
-        dep_context = dep_context or DepContext()
         for dep in dep_context.deps | self.task.deps:
             for dep_status in dep.get_dep_statuses(

Contributor: Same as above (why remove this line?)

                     self,
airflow/operators/latest_only_operator.py (2 changes: 1 addition & 1 deletion)

@@ -34,7 +34,7 @@ class LatestOnlyOperator(BaseOperator):
     def execute(self, context):
         # If the DAG Run is externally triggered, then return without
         # skipping downstream tasks
-        if context['dag_run'].external_trigger:
+        if context['dag_run'] and context['dag_run'].external_trigger:
             logging.info("""Externally triggered DAG_Run:
                          allowing execution to proceed.""")
             return
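The extra truthiness check guards against `context['dag_run']` being None, which presumably is what the newly enabled tests hit when a task instance runs without a backing DagRun; the old code raised an AttributeError in that case. A small sketch of the two behaviors:

```python
# Hypothetical context for a task instance executed with no associated DagRun.
context = {'dag_run': None}

# Old guard: context['dag_run'].external_trigger
# -> AttributeError: 'NoneType' object has no attribute 'external_trigger'

# New guard short-circuits on None and falls through to the normal skip logic:
if context['dag_run'] and context['dag_run'].external_trigger:
    print('externally triggered: allow execution to proceed')
```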
airflow/ti_deps/deps/base_ti_dep.py (10 changes: 7 additions & 3 deletions)

@@ -69,7 +69,7 @@ def _get_dep_statuses(self, ti, session, dep_context):
         raise NotImplementedError
 
     @provide_session
-    def get_dep_statuses(self, ti, session, dep_context):
+    def get_dep_statuses(self, ti, session, dep_context=None):
         """
         Wrapper around the private _get_dep_statuses method that contains some global
         checks for all dependencies.

@@ -81,6 +81,10 @@ def get_dep_statuses(self, ti, session, dep_context):
         :param dep_context: the context for which this dependency should be evaluated for
         :type dep_context: DepContext
         """
+        from airflow.ti_deps.dep_context import DepContext

Contributor: We should avoid imports at local scope. I think the previous implementation was OK, but it seems you have some concerns?

Contributor (Author): The motivation is probably clearer in the original commit before the squash. While it's possible to make dep_context optional without a local import, it requires further refactoring, and this PR is already big enough, so I opted for the less invasive option. But feel free to refactor it or revert the previous signature (and update the dozens of previously broken tests that were passing None) if you feel strongly about it.

+        if dep_context is None:
+            dep_context = DepContext()
+
         if self.IGNOREABLE and dep_context.ignore_all_deps:
             yield self._passing_status(
                 reason="Context specified all dependencies should be ignored.")

@@ -95,7 +99,7 @@ def get_dep_statuses(self, ti, session, dep_context):
             yield dep_status
 
     @provide_session
-    def is_met(self, ti, session, dep_context):
+    def is_met(self, ti, session, dep_context=None):
         """
         Returns whether or not this dependency is met for a given task instance. A
         dependency is considered met if all of the dependency statuses it reports are

@@ -113,7 +117,7 @@
                    self.get_dep_statuses(ti, session, dep_context))
 
     @provide_session
-    def get_failure_reasons(self, ti, session, dep_context):
+    def get_failure_reasons(self, ti, session, dep_context=None):
         """
         Returns an iterable of strings that explain why this dependency wasn't met.
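To make the thread above concrete: `dep_context` now defaults to None and a fresh `DepContext` is substituted inside the body, so callers (including the previously broken tests) that pass None explicitly no longer crash on the first attribute access. The import sits inside the function presumably because a module-level import would close a cycle, since `dep_context` itself pulls in the deps package. A stripped-down, runnable sketch of the defaulting pattern, with simplified names:

```python
class DepContext(object):
    # Stand-in for airflow.ti_deps.dep_context.DepContext.
    ignore_all_deps = False

def get_dep_statuses(dep_context=None):
    # Callers may pass None explicitly; without this guard the attribute
    # access below would raise AttributeError on NoneType.
    if dep_context is None:
        dep_context = DepContext()
    return dep_context.ignore_all_deps

print(get_dep_statuses())      # False: default context supplied
print(get_dep_statuses(None))  # False: explicit None handled the same way
```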
scripts/ci/requirements.txt (5 changes: 5 additions & 0 deletions)

@@ -1,6 +1,7 @@
 alembic
 bcrypt
 boto
+boto3
 celery
 cgroupspy
 chartkick

@@ -9,6 +10,7 @@ coverage
 coveralls
 croniter
 cryptography
+datadog
 dill
 distributed
 docker-py

@@ -21,7 +23,9 @@ flask-cache
 flask-login==0.2.11
 Flask-WTF
 flower
+freezegun
 future
+google-api-python-client
 gunicorn
 hdfs
 hive-thrift-py

@@ -34,6 +38,7 @@ ldap3
 lxml
 markdown
 mock
+moto
 mysqlclient
 nose
 nose-exclude
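The additions line up with the tests this PR enables: `boto3` and `moto` back the AWS hook tests, `google-api-python-client` the GCP ones, and `datadog` and `freezegun` presumably serve their respective test modules. For illustration, freezegun pins "now" inside a decorated scope; a minimal sketch, assuming freezegun is installed:

```python
import datetime
from freezegun import freeze_time

@freeze_time('2017-01-01')
def frozen_now():
    # Inside the decorated scope, datetime.now() returns the pinned instant.
    return datetime.datetime.now()

print(frozen_now())  # 2017-01-01 00:00:00
```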
tests/__init__.py (13 changes: 0 additions & 13 deletions)

@@ -11,16 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from __future__ import absolute_import
-
-from .api import *
-from .configuration import *
-from .contrib import *
-from .core import *
-from .executors import *
-from .jobs import *
-from .impersonation import *
-from .models import *
-from .operators import *
-from .utils import *
tests/api/__init__.py (6 changes: 0 additions & 6 deletions)

@@ -11,9 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from __future__ import absolute_import
-
-from .client import *
-from .common import *
-
File renamed without changes.
File renamed without changes.
tests/contrib/__init__.py (4 changes: 0 additions & 4 deletions)

@@ -11,7 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from __future__ import absolute_import
-from .operators import *
-from .sensors import *
tests/contrib/hooks/test_aws_hook.py

@@ -14,24 +14,20 @@
 #
 
 import unittest
+
+import boto3
+from moto import mock_emr
 
 from airflow import configuration
 from airflow.contrib.hooks.aws_hook import AwsHook
 
 
-try:
-    from moto import mock_emr
-except ImportError:
-    mock_emr = None
-
-
 class TestAwsHook(unittest.TestCase):
 
+    @mock_emr
     def setUp(self):
         configuration.load_test_config()
 
-    @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
     @mock_emr
     def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
         client = boto3.client('emr', region_name='us-east-1')

@@ -42,6 +38,3 @@ def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
         client_from_hook = hook.get_client_type('emr')
 
         self.assertEqual(client_from_hook.list_clusters()['Clusters'], [])
-
-if __name__ == '__main__':
-    unittest.main()
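With moto now a hard CI dependency (added to scripts/ci/requirements.txt above), the try/except import and `skipIf` dance goes away and `@mock_emr` can be applied unconditionally. The decorator routes boto3's EMR calls to an in-memory fake, so the test needs no credentials or network; a minimal standalone sketch:

```python
import boto3
from moto import mock_emr

@mock_emr
def list_clusters_against_fake_emr():
    # Every boto3 EMR call inside this function hits moto's in-memory backend.
    client = boto3.client('emr', region_name='us-east-1')
    return client.list_clusters()['Clusters']

print(list_clusters_against_fake_emr())  # [] since the fake account is empty
```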
tests/contrib/hooks/test_bigquery_hook.py

@@ -104,13 +104,15 @@ def test_invalid_syntax_tiple_dot_var(self):
         self.assertIn('Format exception for var_x:',
                       str(context.exception), "")
 
+
 class TestBigQueryHookSourceFormat(unittest.TestCase):
     def test_invalid_source_format(self):
         with self.assertRaises(Exception) as context:
-            hook.BigQueryBaseCursor("test", "test").run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json")
-
-        # since we passed 'json' in, and it's not valid, make sure it's present in the error string.
-        self.assertIn("json", str(context.exception))
+            hook.BigQueryBaseCursor("test", "test").run_load("test.test",
+                                                             ["test_schema.json"],
+                                                             ["test_data.json"],
+                                                             source_format="json")
+        self.assertIn("JSON", str(context.exception))
 
 
 class TestBigQueryBaseCursor(unittest.TestCase):

@@ -134,6 +136,3 @@ def test_invalid_schema_update_and_write_disposition(self):
             write_disposition='WRITE_EMPTY'
         )
         self.assertIn("schema_update_options is only", str(context.exception))
-
-if __name__ == '__main__':
-    unittest.main()
tests/contrib/hooks/test_emr_hook.py

@@ -14,30 +14,25 @@
 #
 
 import unittest
+
+import boto3
+from moto import mock_emr
 
 from airflow import configuration
 from airflow.contrib.hooks.emr_hook import EmrHook
 
 
-try:
-    from moto import mock_emr
-except ImportError:
-    mock_emr = None
-
-
 class TestEmrHook(unittest.TestCase):
 
+    @mock_emr
     def setUp(self):
         configuration.load_test_config()
 
-    @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
     @mock_emr
     def test_get_conn_returns_a_boto3_connection(self):
         hook = EmrHook(aws_conn_id='aws_default')
         self.assertIsNotNone(hook.get_conn().list_clusters())
 
-    @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
     @mock_emr
     def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
         client = boto3.client('emr', region_name='us-east-1')

@@ -47,7 +42,5 @@ def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
         hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
         cluster = hook.create_job_flow({'Name': 'test_cluster'})
 
-        self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId'])
-
-if __name__ == '__main__':
-    unittest.main()
+        self.assertEqual(client.list_clusters()['Clusters'][0]['Id'],
+                         cluster['JobFlowId'])
tests/contrib/hooks/test_ftp_hook.py (4 changes: 0 additions & 4 deletions)

@@ -77,7 +77,3 @@ def test_rename(self):
 
         self.conn_mock.rename.assert_called_once_with(from_path, to_path)
         self.conn_mock.quit.assert_called_once_with()
-
-
-if __name__ == '__main__':
-    unittest.main()
tests/contrib/hooks/test_gcp_dataflow_hook.py

@@ -14,16 +14,11 @@
 #
 
 import unittest
-from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
 
-try:
-    from unittest import mock
-except ImportError:
-    try:
-        import mock
-    except ImportError:
-        mock = None
+import mock
 
+from airflow.models import Connection
+from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
+from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
 
 TASK_ID = 'test-python-dataflow'
 PY_FILE = 'apache_beam.examples.wordcount'

@@ -32,24 +27,15 @@
     'project': 'test',
     'staging_location': 'gs://test/staging'
 }
-BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
-DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}'
-
-
-def mock_init(self, gcp_conn_id, delegate_to=None):
-    pass
-
 
 class DataFlowHookTest(unittest.TestCase):
 
-    def setUp(self):
-        with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
-                        new=mock_init):
-            self.dataflow_hook = DataFlowHook(gcp_conn_id='test')
-
-    @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow'))
+    @mock.patch('airflow.contrib.hooks.gcp_dataflow_hook.DataFlowHook._start_dataflow')
+    @mock.patch.object(GoogleCloudBaseHook, 'get_connection', Connection)
     def test_start_python_dataflow(self, internal_dataflow_mock):
-        self.dataflow_hook.start_python_dataflow(
+        dataflow_hook = DataFlowHook(gcp_conn_id='test')
+        dataflow_hook.start_python_dataflow(
             task_id=TASK_ID, variables=OPTIONS,
             dataflow=PY_FILE, py_options=PY_OPTIONS)
         internal_dataflow_mock.assert_called_once_with(
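Two things happen in the rewritten test: `_start_dataflow` is patched by dotted path so the hook never launches a real Dataflow job, and `patch.object` swaps `GoogleCloudBaseHook.get_connection` for the `Connection` class itself, so calling it simply constructs a bare `Connection` instead of querying the metadata database. A small sketch of that second trick, using hypothetical stand-in classes:

```python
import mock  # the standalone mock package, as imported in these tests

class Connection(object):
    # Stand-in for airflow.models.Connection.
    def __init__(self, conn_id=None):
        self.conn_id = conn_id

class BaseHook(object):
    def get_connection(self, conn_id):
        raise RuntimeError('would hit the metadata database')

# Patching the method with a class (not a Mock) means the call site invokes
# Connection(...) directly and gets a plain Connection instance back.
with mock.patch.object(BaseHook, 'get_connection', Connection):
    conn = BaseHook().get_connection('test')
    print(conn.conn_id)  # 'test', and no database was involved
```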
tests/contrib/hooks/test_jira_hook.py (10 changes: 3 additions & 7 deletions)

@@ -23,12 +23,12 @@
 from airflow import models
 from airflow.utils import db
 
-jira_client_mock = Mock(
-    name="jira_client"
-)
+jira_client_mock = Mock(name="jira_client")
 
 
 class TestJiraHook(unittest.TestCase):
 
     def setUp(self):
         configuration.load_test_config()
         db.merge_conn(

@@ -45,7 +45,3 @@ def test_jira_client_connection(self, jira_mock):
         self.assertTrue(jira_mock.called)
         self.assertIsInstance(jira_hook.client, Mock)
         self.assertEqual(jira_hook.client.name, jira_mock.return_value.name)
-
-
-if __name__ == '__main__':
-    unittest.main()
tests/contrib/operators/__init__.py (5 changes: 0 additions & 5 deletions)

@@ -11,8 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
-
-from __future__ import absolute_import
-from .ssh_execute_operator import *
-from .fs_operator import *