diff --git a/airflow/www/utils.py b/airflow/www/utils.py index a0833ee0690c6..bc4a679e78dc9 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -13,7 +13,6 @@ # limitations under the License. # from future import standard_library -standard_library.install_aliases() from builtins import str from builtins import object @@ -30,11 +29,14 @@ from flask_login import current_user import wtforms from wtforms.compat import text_type +from flask_admin._compat import as_unicode +from flask_admin.tools import CHAR_ESCAPE, CHAR_SEPARATOR -from airflow import configuration, models, settings +from airflow import configuration, models from airflow.utils.db import create_session from airflow.utils import timezone from airflow.utils.json import AirflowJsonEncoder +standard_library.install_aliases() AUTHENTICATE = configuration.getboolean('webserver', 'AUTHENTICATE') @@ -177,9 +179,9 @@ def is_current(current, page): vals = { 'is_active': 'active' if is_current(current_page, page) else '', 'href_link': void_link if is_current(current_page, page) - else '?{}'.format(get_params(page=page, - search=search, - showPaused=showPaused)), + else '?{}'.format(get_params(page=page, + search=search, + showPaused=showPaused)), 'page_num': page + 1 } output.append(page_node.format(**vals)) @@ -338,9 +340,11 @@ def zipper(response): response.direct_passthrough = False - if (response.status_code < 200 or + if ( + response.status_code < 200 or response.status_code >= 300 or - 'Content-Encoding' in response.headers): + 'Content-Encoding' in response.headers + ): return response gzip_buffer = IO() gzip_file = gzip.GzipFile(mode='wb', @@ -389,6 +393,17 @@ def __call__(self, field, **kwargs): return wtforms.widgets.core.HTMLString(html) +def flask_admin_unescape(value): + """ + Function to back out the CHAR_ESCAPE, CHAR_SEPARATOR values created by flask_admin + :param value: flask_admin id or ids to unescape + :return: unicode + """ + return (as_unicode(value) + .replace(CHAR_ESCAPE + CHAR_ESCAPE, CHAR_ESCAPE) + .replace(CHAR_ESCAPE + CHAR_SEPARATOR, CHAR_SEPARATOR)) + + class UtcFilterConverter(FilterConverter): @filters.convert('utcdatetime') def conv_utcdatetime(self, column, name, **kwargs): diff --git a/airflow/www/views.py b/airflow/www/views.py index 252241ad21bf2..eb3d7fdab08f3 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1863,9 +1863,7 @@ def get_int_arg(value, default=0): if hide_paused: sql_query = sql_query.filter(~DM.is_paused) - orm_dags = {dag.dag_id: dag for dag - in sql_query - .all()} + orm_dags = {dag.dag_id: dag for dag in sql_query.all()} import_errors = session.query(models.ImportError).all() for ie in import_errors: @@ -2081,11 +2079,11 @@ class SlaMissModelView(wwwutils.SuperUserMixin, ModelViewOnly): @provide_session def _connection_ids(session=None): return [ - (c.conn_id, c.conn_id) - for c in ( - session.query(models.Connection.conn_id) - .group_by(models.Connection.conn_id) - ) + (c.conn_id, c.conn_id) + for c in ( + session.query(models.Connection.conn_id).group_by( + models.Connection.conn_id + )) ] @@ -2564,6 +2562,8 @@ def action_clear(self, ids, session=None): # Collect dags upfront as dagbag.get_dag() will reset the session for id_str in ids: task_id, dag_id, execution_date = id_str.split(',') + task_id = wwwutils.flask_admin_unescape(task_id) + dag_id = wwwutils.flask_admin_unescape(dag_id) dag = dagbag.get_dag(dag_id) task_details = dag_to_task_details.setdefault(dag, []) task_details.append((task_id, execution_date)) @@ -2599,6 +2599,8 @@ def set_task_instance_state(self, ids, target_state, session=None): for id in ids: task_id, dag_id, execution_date = id.split(',') execution_date = parse_execution_date(execution_date) + task_id = wwwutils.flask_admin_unescape(task_id) + dag_id = wwwutils.flask_admin_unescape(dag_id) ti = session.query(TI).filter(TI.task_id == task_id, TI.dag_id == dag_id, diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index d13a49db6280c..40872fe91a1da 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -104,6 +104,30 @@ def test_params_all(self): self.assertEqual('page=3&search=bash_&showPaused=False', utils.get_params(showPaused=False, page=3, search='bash_')) + def test_flask_admin_unescape(self): + flask_admin_dag_id_escape_char = 'dag{escape_char}{escape_char}sub_dag'.format( + escape_char=utils.CHAR_ESCAPE + ) + correct_dag_id_escape_char = 'dag{escape_char}sub_dag'.format( + escape_char=utils.CHAR_ESCAPE + ) + self.assertEqual(utils.flask_admin_unescape( + flask_admin_dag_id_escape_char), + correct_dag_id_escape_char + ) + flask_admin_dag_id_separator_char = 'dag{escape_char}{separator_char}sub_dag'\ + .format( + escape_char=utils.CHAR_ESCAPE, + separator_char=utils.CHAR_SEPARATOR + ) + correct_dag_id_separator_char = 'dag{separator_char}sub_dag'.format( + separator_char=utils.CHAR_SEPARATOR + ) + self.assertEqual( + utils.flask_admin_unescape(flask_admin_dag_id_separator_char), + correct_dag_id_separator_char + ) + if __name__ == '__main__': unittest.main() diff --git a/tests/www/test_views.py b/tests/www/test_views.py index ff20333303197..dcd23cea5205f 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -20,6 +20,7 @@ import tempfile import unittest import sys +import flask_admin from werkzeug.test import Client @@ -31,6 +32,7 @@ from airflow.utils.timezone import datetime from airflow.www import app as application from airflow import configuration as conf +from airflow.utils.state import State class TestChartModelView(unittest.TestCase): @@ -136,9 +138,9 @@ def test_can_handle_error_on_decrypt(self): self.assertEqual(response.status_code, 200) # update the variable with a wrong value, given that is encrypted - Var = models.Variable - (self.session.query(Var) - .filter(Var.key == self.variable['key']) + var = models.Variable + (self.session.query(var) + .filter(var.key == self.variable['key']) .update({ 'val': 'failed_value_not_encrypted' }, synchronize_session=False)) @@ -446,5 +448,88 @@ def test_mount(self): self.assertIn(b"DAGs", resp_html) +class TestTaskInstanceModelView(unittest.TestCase): + DAG_ID = 'dag_for_testing_setting_subdag_task_instance_state_view' + SUBDAG_ID = 'task_for_testing_setting_subdag_task_instance_state_view' + SUBDAG_NAME = '{dag_id}.{subdag_id}'.format(dag_id=DAG_ID, subdag_id=SUBDAG_ID) + DEFAULT_DATE = datetime(2018, 1, 1) + + @classmethod + def setUpClass(cls): + super(TestTaskInstanceModelView, cls).setUpClass() + session = Session() + session.query(TaskInstance).filter( + TaskInstance.dag_id == cls.DAG_ID and + TaskInstance.execution_date == cls.DEFAULT_DATE).delete() + session.commit() + session.close() + + def setUp(self): + super(TestTaskInstanceModelView, self).setUp() + self.app = application.create_app(testing=True) + self.client = self.app.test_client() + self.session = Session() + subdag_dag = DAG(self.SUBDAG_NAME, start_date=self.DEFAULT_DATE) + subdag_subtask = DummyOperator(task_id='test', dag=subdag_dag) + self.subdag_dag_id = subdag_dag.dag_id + self.task_id = subdag_subtask.task_id + self.row_ids = ['{task_id},{dag_id},{execution_date}'.format( + dag_id=flask_admin.tools.escape(self.subdag_dag_id), + task_id=flask_admin.tools.escape(self.task_id), + execution_date=self.DEFAULT_DATE + )] + from airflow.www.views import TaskInstanceModelView + self.task_instance_model_view = TaskInstanceModelView( + TaskInstance, self.session, name="Task Instances", category="Browse" + ) + ti = TaskInstance(task=subdag_subtask, execution_date=self.DEFAULT_DATE) + self.session.merge(ti) + self.session.commit() + + def tearDown(self): + self.session.query(TaskInstance).filter( + TaskInstance.dag_id == self.subdag_dag_id and + TaskInstance.execution_date == self.DEFAULT_DATE).delete() + self.session.commit() + self.session.close() + super(TestTaskInstanceModelView, self).tearDown() + + def test_set_subdag_state_failed(self): + with self.app.test_request_context(): + self.task_instance_model_view.action_set_failed(self.row_ids) + self.assertEqual(self.session.query(TaskInstance).filter( + TaskInstance.dag_id == self.subdag_dag_id and + TaskInstance.task_id == self.task_id and + TaskInstance.execution_date == self.DEFAULT_DATE + ).one().state, State.FAILED) + + def test_set_subdag_state_up_for_retry(self): + with self.app.test_request_context(): + self.task_instance_model_view.action_set_retry(self.row_ids) + self.assertEqual(self.session.query(TaskInstance).filter( + TaskInstance.dag_id == self.subdag_dag_id and + TaskInstance.task_id == self.task_id and + TaskInstance.execution_date == self.DEFAULT_DATE + ).one().state, State.UP_FOR_RETRY) + + def test_set_subdag_state_running(self): + with self.app.test_request_context(): + self.task_instance_model_view.action_set_running(self.row_ids) + self.assertEqual(self.session.query(TaskInstance).filter( + TaskInstance.dag_id == self.subdag_dag_id and + TaskInstance.task_id == self.task_id and + TaskInstance.execution_date == self.DEFAULT_DATE + ).one().state, State.RUNNING) + + def test_set_subdag_state_success(self): + with self.app.test_request_context(): + self.task_instance_model_view.action_set_success(self.row_ids) + self.assertEqual(self.session.query(TaskInstance).filter( + TaskInstance.dag_id == self.subdag_dag_id and + TaskInstance.task_id == self.task_id and + TaskInstance.execution_date == self.DEFAULT_DATE + ).one().state, State.SUCCESS) + + if __name__ == '__main__': unittest.main()