diff --git a/airflow/operators/subdag_operator.py b/airflow/operators/subdag_operator.py index c56e7afc54066..a44e210de480b 100644 --- a/airflow/operators/subdag_operator.py +++ b/airflow/operators/subdag_operator.py @@ -1,6 +1,21 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from airflow.exceptions import AirflowException -from airflow.models import BaseOperator +from airflow.models import BaseOperator, Pool from airflow.utils.decorators import apply_defaults +from airflow.utils.db import provide_session from airflow.executors import DEFAULT_EXECUTOR @@ -10,6 +25,7 @@ class SubDagOperator(BaseOperator): ui_color = '#555' ui_fgcolor = '#fff' + @provide_session @apply_defaults def __init__( self, @@ -28,13 +44,40 @@ def __init__( if 'dag' not in kwargs: raise AirflowException("Please pass in the `dag` param") dag = kwargs['dag'] + session = kwargs.pop('session') super(SubDagOperator, self).__init__(*args, **kwargs) + + # validate subdag name if dag.dag_id + '.' + kwargs['task_id'] != subdag.dag_id: raise AirflowException( "The subdag's dag_id should have the form " "'{{parent_dag_id}}.{{this_task_id}}'. Expected " "'{d}.{t}'; received '{rcvd}'.".format( d=dag.dag_id, t=kwargs['task_id'], rcvd=subdag.dag_id)) + + # validate that subdag operator and subdag tasks don't have a + # pool conflict + if self.pool: + pool = ( + session + .query(Pool) + .filter(Pool.slots == 1) + .filter(Pool.pool == self.pool) + .first() + ) + conflicts = [t for t in subdag.tasks if t.pool == self.pool] + if pool and any(t.pool == self.pool for t in subdag.tasks): + raise AirflowException( + 'SubDagOperator {sd} and subdag task{plural} {t} both use ' + 'pool {p}, but the pool only has 1 slot. The subdag tasks' + 'will never run.'.format( + sd=self.task_id, + plural=len(conflicts) > 1, + t=', '.join(t.task_id for t in conflicts), + p=self.pool + ) + ) + self.subdag = subdag self.executor = executor diff --git a/airflow/utils/db.py b/airflow/utils/db.py index d3a69db57e312..430026ea2b2bc 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -199,6 +199,7 @@ def initdb(): "GROUP BY state"), ) session.add(chart) + session.commit() def upgradedb(): diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py index b4e61123f7514..98a17a788a8a9 100644 --- a/tests/operators/__init__.py +++ b/tests/operators/__init__.py @@ -1 +1,2 @@ from .docker_operator import * +from .subdag_operator import * diff --git a/tests/operators/subdag_operator.py b/tests/operators/subdag_operator.py new file mode 100644 index 0000000000000..2de1ee1b25b56 --- /dev/null +++ b/tests/operators/subdag_operator.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from datetime import datetime + +import airflow +from airflow import DAG +from airflow.operators import DummyOperator +from airflow.operators.subdag_operator import SubDagOperator +from airflow.exceptions import AirflowException + +default_args = dict( + owner='airflow', + start_date=datetime(2016, 1, 1), +) + +class SubDagOperatorTests(unittest.TestCase): + + def test_subdag_name(self): + """ + Subdag names must be {parent_dag}.{subdag task} + """ + dag = DAG('parent', default_args=default_args) + subdag_good = DAG('parent.test', default_args=default_args) + subdag_bad1 = DAG('parent.bad', default_args=default_args) + subdag_bad2 = DAG('bad.test', default_args=default_args) + subdag_bad3 = DAG('bad.bad', default_args=default_args) + + SubDagOperator(task_id='test', dag=dag, subdag=subdag_good) + self.assertRaises( + AirflowException, + SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad1) + self.assertRaises( + AirflowException, + SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad2) + self.assertRaises( + AirflowException, + SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad3) + + def test_subdag_pools(self): + """ + Subdags and subdag tasks can't both have a pool with 1 slot + """ + dag = DAG('parent', default_args=default_args) + subdag = DAG('parent.test', default_args=default_args) + + session = airflow.settings.Session() + pool_1 = airflow.models.Pool(pool='test_pool_1', slots=1) + pool_10 = airflow.models.Pool(pool='test_pool_10', slots=10) + session.add(pool_1) + session.add(pool_10) + session.commit() + + dummy_1 = DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_1') + + self.assertRaises( + AirflowException, + SubDagOperator, + task_id='test', dag=dag, subdag=subdag, pool='test_pool_1') + + # recreate dag because failed subdagoperator was already added + dag = DAG('parent', default_args=default_args) + SubDagOperator( + task_id='test', dag=dag, subdag=subdag, pool='test_pool_10') + + session.delete(pool_1) + session.delete(pool_10) + session.commit() + + +if __name__ == "__main__": + unittest.main()