Skip to content

Commit

Permalink
[AIRFLOW-1246] Setting a Subdag Task State Throws Exception
Browse files Browse the repository at this point in the history
    # This is a combination of 16 commits.
    # This is the 1st commit message:
    Add flask_admin_unescape function to airfow.www.utils

    # This is the commit message apache#2:

    Add flask_admin_unescape to TaskInstanceModelView for task_id and dag_id

    # This is the commit message apache#3:

    Add test_flask_admin_unescape to tests.www.test_utils

    # This is the commit message apache#4:

    Add TestTaskInstanceModelView to tests.www.test_views

        This test for the proper integration between Subdag Tasks and the
        flask_admin package. Because the flask_admin package uses '.'
        characters as it's escape character and airflow enforces the
        dag_id.subdag_id format for the dag_id's for task_id's in subdags,
        the escaping that flask_admin preforms causes a duplication of the
        '.' character. When perform the actions that are available for those
        task_instances via the ui, this causes an error. These tests are
        here to protect against that integration issue.

    # This is the commit message apache#5:

    Update airflow.www.utils to conform to flake8 standards

    # This is the commit message apache#6:

    Update tests.www.test_utils to conform to flake8 standards

    # This is the commit message apache#7:

    Update TestTaskInstanceModelView to Leverage assertEqual

    # This is the commit message apache#8:

    Remove Superfluous SubDAG Task Var TestTaskInstanceModelView

    # This is the commit message apache#9:

    Update airflow.www.utils.gizipped.view_func to match with flake8

    # This is the commit message apache#11:

    Update airflow.www._connection_ids to align with flake8 spacing

    # This is the commit message apache#12:

    Update tests.www.test_utils.test_flask_admin_unescape with proper spaces

    # This is the commit message apache#13:

    Align tests.www.test_views with flake8 standards

    # This is the commit message apache#14:

    Remove unnecessary SubDagOperator in tests.www.test_views

    # This is the commit message apache#15:

    Change Position on install_aliases call in www.utils for flake8

    # This is the commit message apache#16:

    Separate setting subdag state via flask_admin model tests
  • Loading branch information
Zack Lawson committed Jan 25, 2018
1 parent cbc02da commit 432fa19
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 18 deletions.
29 changes: 22 additions & 7 deletions airflow/www/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
#
from future import standard_library
standard_library.install_aliases()
from builtins import str
from builtins import object

Expand All @@ -30,11 +29,14 @@
from flask_login import current_user
import wtforms
from wtforms.compat import text_type
from flask_admin._compat import as_unicode
from flask_admin.tools import CHAR_ESCAPE, CHAR_SEPARATOR

from airflow import configuration, models, settings
from airflow import configuration, models
from airflow.utils.db import create_session
from airflow.utils import timezone
from airflow.utils.json import AirflowJsonEncoder
standard_library.install_aliases()

AUTHENTICATE = configuration.getboolean('webserver', 'AUTHENTICATE')

Expand Down Expand Up @@ -177,9 +179,9 @@ def is_current(current, page):
vals = {
'is_active': 'active' if is_current(current_page, page) else '',
'href_link': void_link if is_current(current_page, page)
else '?{}'.format(get_params(page=page,
search=search,
showPaused=showPaused)),
else '?{}'.format(get_params(page=page,
search=search,
showPaused=showPaused)),
'page_num': page + 1
}
output.append(page_node.format(**vals))
Expand Down Expand Up @@ -338,9 +340,11 @@ def zipper(response):

response.direct_passthrough = False

if (response.status_code < 200 or
if (
response.status_code < 200 or
response.status_code >= 300 or
'Content-Encoding' in response.headers):
'Content-Encoding' in response.headers
):
return response
gzip_buffer = IO()
gzip_file = gzip.GzipFile(mode='wb',
Expand Down Expand Up @@ -389,6 +393,17 @@ def __call__(self, field, **kwargs):
return wtforms.widgets.core.HTMLString(html)


def flask_admin_unescape(value):
"""
Function to back out the CHAR_ESCAPE, CHAR_SEPARATOR values created by flask_admin
:param value: flask_admin id or ids to unescape
:return: unicode
"""
return (as_unicode(value)
.replace(CHAR_ESCAPE + CHAR_ESCAPE, CHAR_ESCAPE)
.replace(CHAR_ESCAPE + CHAR_SEPARATOR, CHAR_SEPARATOR))


class UtcFilterConverter(FilterConverter):
@filters.convert('utcdatetime')
def conv_utcdatetime(self, column, name, **kwargs):
Expand Down
18 changes: 10 additions & 8 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,9 +1863,7 @@ def get_int_arg(value, default=0):
if hide_paused:
sql_query = sql_query.filter(~DM.is_paused)

orm_dags = {dag.dag_id: dag for dag
in sql_query
.all()}
orm_dags = {dag.dag_id: dag for dag in sql_query.all()}

import_errors = session.query(models.ImportError).all()
for ie in import_errors:
Expand Down Expand Up @@ -2081,11 +2079,11 @@ class SlaMissModelView(wwwutils.SuperUserMixin, ModelViewOnly):
@provide_session
def _connection_ids(session=None):
return [
(c.conn_id, c.conn_id)
for c in (
session.query(models.Connection.conn_id)
.group_by(models.Connection.conn_id)
)
(c.conn_id, c.conn_id)
for c in (
session.query(models.Connection.conn_id).group_by(
models.Connection.conn_id
))
]


Expand Down Expand Up @@ -2564,6 +2562,8 @@ def action_clear(self, ids, session=None):
# Collect dags upfront as dagbag.get_dag() will reset the session
for id_str in ids:
task_id, dag_id, execution_date = id_str.split(',')
task_id = wwwutils.flask_admin_unescape(task_id)
dag_id = wwwutils.flask_admin_unescape(dag_id)
dag = dagbag.get_dag(dag_id)
task_details = dag_to_task_details.setdefault(dag, [])
task_details.append((task_id, execution_date))
Expand Down Expand Up @@ -2599,6 +2599,8 @@ def set_task_instance_state(self, ids, target_state, session=None):
for id in ids:
task_id, dag_id, execution_date = id.split(',')
execution_date = parse_execution_date(execution_date)
task_id = wwwutils.flask_admin_unescape(task_id)
dag_id = wwwutils.flask_admin_unescape(dag_id)

ti = session.query(TI).filter(TI.task_id == task_id,
TI.dag_id == dag_id,
Expand Down
24 changes: 24 additions & 0 deletions tests/www/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,30 @@ def test_params_all(self):
self.assertEqual('page=3&search=bash_&showPaused=False',
utils.get_params(showPaused=False, page=3, search='bash_'))

def test_flask_admin_unescape(self):
flask_admin_dag_id_escape_char = 'dag{escape_char}{escape_char}sub_dag'.format(
escape_char=utils.CHAR_ESCAPE
)
correct_dag_id_escape_char = 'dag{escape_char}sub_dag'.format(
escape_char=utils.CHAR_ESCAPE
)
self.assertEqual(utils.flask_admin_unescape(
flask_admin_dag_id_escape_char),
correct_dag_id_escape_char
)
flask_admin_dag_id_separator_char = 'dag{escape_char}{separator_char}sub_dag'\
.format(
escape_char=utils.CHAR_ESCAPE,
separator_char=utils.CHAR_SEPARATOR
)
correct_dag_id_separator_char = 'dag{separator_char}sub_dag'.format(
separator_char=utils.CHAR_SEPARATOR
)
self.assertEqual(
utils.flask_admin_unescape(flask_admin_dag_id_separator_char),
correct_dag_id_separator_char
)


if __name__ == '__main__':
unittest.main()
91 changes: 88 additions & 3 deletions tests/www/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tempfile
import unittest
import sys
import flask_admin

from werkzeug.test import Client

Expand All @@ -31,6 +32,7 @@
from airflow.utils.timezone import datetime
from airflow.www import app as application
from airflow import configuration as conf
from airflow.utils.state import State


class TestChartModelView(unittest.TestCase):
Expand Down Expand Up @@ -136,9 +138,9 @@ def test_can_handle_error_on_decrypt(self):
self.assertEqual(response.status_code, 200)

# update the variable with a wrong value, given that is encrypted
Var = models.Variable
(self.session.query(Var)
.filter(Var.key == self.variable['key'])
var = models.Variable
(self.session.query(var)
.filter(var.key == self.variable['key'])
.update({
'val': 'failed_value_not_encrypted'
}, synchronize_session=False))
Expand Down Expand Up @@ -446,5 +448,88 @@ def test_mount(self):
self.assertIn(b"DAGs", resp_html)


class TestTaskInstanceModelView(unittest.TestCase):
DAG_ID = 'dag_for_testing_setting_subdag_task_instance_state_view'
SUBDAG_ID = 'task_for_testing_setting_subdag_task_instance_state_view'
SUBDAG_NAME = '{dag_id}.{subdag_id}'.format(dag_id=DAG_ID, subdag_id=SUBDAG_ID)
DEFAULT_DATE = datetime(2018, 1, 1)

@classmethod
def setUpClass(cls):
super(TestTaskInstanceModelView, cls).setUpClass()
session = Session()
session.query(TaskInstance).filter(
TaskInstance.dag_id == cls.DAG_ID and
TaskInstance.execution_date == cls.DEFAULT_DATE).delete()
session.commit()
session.close()

def setUp(self):
super(TestTaskInstanceModelView, self).setUp()
self.app = application.create_app(testing=True)
self.client = self.app.test_client()
self.session = Session()
subdag_dag = DAG(self.SUBDAG_NAME, start_date=self.DEFAULT_DATE)
subdag_subtask = DummyOperator(task_id='test', dag=subdag_dag)
self.subdag_dag_id = subdag_dag.dag_id
self.task_id = subdag_subtask.task_id
self.row_ids = ['{task_id},{dag_id},{execution_date}'.format(
dag_id=flask_admin.tools.escape(self.subdag_dag_id),
task_id=flask_admin.tools.escape(self.task_id),
execution_date=self.DEFAULT_DATE
)]
from airflow.www.views import TaskInstanceModelView
self.task_instance_model_view = TaskInstanceModelView(
TaskInstance, self.session, name="Task Instances", category="Browse"
)
ti = TaskInstance(task=subdag_subtask, execution_date=self.DEFAULT_DATE)
self.session.merge(ti)
self.session.commit()

def tearDown(self):
self.session.query(TaskInstance).filter(
TaskInstance.dag_id == self.subdag_dag_id and
TaskInstance.execution_date == self.DEFAULT_DATE).delete()
self.session.commit()
self.session.close()
super(TestTaskInstanceModelView, self).tearDown()

def test_set_subdag_state_failed(self):
with self.app.test_request_context():
self.task_instance_model_view.action_set_failed(self.row_ids)
self.assertEqual(self.session.query(TaskInstance).filter(
TaskInstance.dag_id == self.subdag_dag_id and
TaskInstance.task_id == self.task_id and
TaskInstance.execution_date == self.DEFAULT_DATE
).one().state, State.FAILED)

def test_set_subdag_state_up_for_retry(self):
with self.app.test_request_context():
self.task_instance_model_view.action_set_retry(self.row_ids)
self.assertEqual(self.session.query(TaskInstance).filter(
TaskInstance.dag_id == self.subdag_dag_id and
TaskInstance.task_id == self.task_id and
TaskInstance.execution_date == self.DEFAULT_DATE
).one().state, State.UP_FOR_RETRY)

def test_set_subdag_state_running(self):
with self.app.test_request_context():
self.task_instance_model_view.action_set_running(self.row_ids)
self.assertEqual(self.session.query(TaskInstance).filter(
TaskInstance.dag_id == self.subdag_dag_id and
TaskInstance.task_id == self.task_id and
TaskInstance.execution_date == self.DEFAULT_DATE
).one().state, State.RUNNING)

def test_set_subdag_state_success(self):
with self.app.test_request_context():
self.task_instance_model_view.action_set_success(self.row_ids)
self.assertEqual(self.session.query(TaskInstance).filter(
TaskInstance.dag_id == self.subdag_dag_id and
TaskInstance.task_id == self.task_id and
TaskInstance.execution_date == self.DEFAULT_DATE
).one().state, State.SUCCESS)


if __name__ == '__main__':
unittest.main()

0 comments on commit 432fa19

Please sign in to comment.