diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index f995297b059ed..c838b4c21fecb 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -244,8 +244,14 @@ def get_dag_runs_batch(session):
 @provide_session
 def post_dag_run(dag_id, session):
     """Trigger a DAG."""
-    if not session.query(DagModel).filter(DagModel.dag_id == dag_id).first():
+    dm = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
+    if not dm:
         raise NotFound(title="DAG not found", detail=f"DAG with dag_id: '{dag_id}' not found")
+    if dm.has_import_errors:
+        raise BadRequest(
+            title="DAG cannot be triggered",
+            detail=f"DAG with dag_id: '{dag_id}' has import errors",
+        )
     try:
         post_body = dagrun_schema.load(request.json, session=session)
     except ValidationError as err:
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 63856e407e54a..e59c818f20534 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -544,6 +544,12 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None:

         # Add the errors of the processed files
         for filename, stacktrace in dagbag.import_errors.items():
+            (
+                session.query(DagModel)
+                .filter(DagModel.fileloc == filename)
+                .update({'has_import_errors': True}, synchronize_session='fetch')
+            )
+
             session.add(
                 errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace)
            )
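Side note on the `update_import_errors` change above: the flag is raised with a bulk UPDATE keyed on `DagModel.fileloc` rather than by mutating loaded ORM objects. A minimal standalone sketch of the same pattern, assuming an open SQLAlchemy session (the helper name is illustrative, not part of the patch):

```python
# Sketch of the flag-raising pattern used in update_import_errors above.
# Assumes an open SQLAlchemy session; `mark_import_errors` is a made-up name.
from airflow.models import DagModel


def mark_import_errors(session, import_errors):
    """Flag every DAG whose file produced an import error."""
    for fileloc in import_errors:  # mapping of fileloc -> stacktrace
        (
            session.query(DagModel)
            .filter(DagModel.fileloc == fileloc)
            # synchronize_session='fetch' re-selects the rows matched by the
            # bulk UPDATE so already-loaded DagModel instances stay in sync.
            .update({'has_import_errors': True}, synchronize_session='fetch')
        )
```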
diff --git a/airflow/migrations/versions/be2bfac3da23_add_has_import_errors_column_to_dagmodel.py b/airflow/migrations/versions/be2bfac3da23_add_has_import_errors_column_to_dagmodel.py
new file mode 100644
index 0000000000000..9edef0b7db828
--- /dev/null
+++ b/airflow/migrations/versions/be2bfac3da23_add_has_import_errors_column_to_dagmodel.py
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add has_import_errors column to DagModel
+
+Revision ID: be2bfac3da23
+Revises: 7b2661a43ba3
+Create Date: 2021-11-04 20:33:11.009547
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = 'be2bfac3da23'
+down_revision = '7b2661a43ba3'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    """Apply Add has_import_errors column to DagModel"""
+    op.add_column("dag", sa.Column("has_import_errors", sa.Boolean(), server_default='0'))
+
+
+def downgrade():
+    """Unapply Add has_import_errors column to DagModel"""
+    op.drop_column("dag", "has_import_errors")
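One reading of the defaults here (not spelled out in the patch): `server_default='0'` backfills the column at the database level for rows that exist before the upgrade, while the ORM-side `default=False` added to `DagModel` below covers objects created in Python. A tiny self-contained sketch of that split, using a throwaway in-memory SQLite table and SQLAlchemy 1.4-style imports rather than Airflow's real schema:

```python
# Sketch: server_default covers pre-existing rows, default covers new ORM
# objects. Illustrative only -- in-memory SQLite, not Airflow's metadata.
from sqlalchemy import Boolean, Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Dag(Base):
    __tablename__ = 'dag'
    dag_id = Column(String(250), primary_key=True)
    # default=False: applied by the ORM when a new object is flushed.
    # server_default='0': applied by the database, e.g. when the column is
    # added to a table that already has rows.
    has_import_errors = Column(Boolean(), default=False, server_default='0')


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Dag(dag_id='example'))
    session.commit()
    assert session.get(Dag, 'example').has_import_errors is False
```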
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index d0877d04b0de4..52d73b843ce7f 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2450,6 +2450,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
             orm_dag.fileloc = dag.fileloc
             orm_dag.owners = dag.owner
             orm_dag.is_active = True
+            orm_dag.has_import_errors = False
             orm_dag.last_parsed_time = timezone.utcnow()
             orm_dag.default_view = dag.default_view
             orm_dag.description = dag.description
@@ -2710,6 +2711,7 @@ class DagModel(Base):
     max_active_runs = Column(Integer, nullable=True)

     has_task_concurrency_limits = Column(Boolean, nullable=False)
+    has_import_errors = Column(Boolean(), default=False)

     # The logical date of the next dag run.
     next_dagrun = Column(UtcDateTime)
@@ -2744,8 +2746,10 @@ def __init__(self, concurrency=None, **kwargs):
             self.max_active_tasks = concurrency
         else:
             self.max_active_tasks = conf.getint('core', 'max_active_tasks_per_dag')
+
         if self.max_active_runs is None:
             self.max_active_runs = conf.getint('core', 'max_active_runs_per_dag')
+
         if self.has_task_concurrency_limits is None:
             # Be safe -- this will be updated later once the DAG is parsed
             self.has_task_concurrency_limits = True
@@ -2882,11 +2886,13 @@ def dags_needing_dagruns(cls, session: Session):
         # TODO[HA]: Bake this query, it is run _A lot_
         # We limit so that _one_ scheduler doesn't try to do all the creation
         # of dag runs
+
         query = (
             session.query(cls)
             .filter(
                 cls.is_paused == expression.false(),
                 cls.is_active == expression.true(),
+                cls.has_import_errors == expression.false(),
                 cls.next_dagrun_create_after <= func.now(),
             )
             .order_by(cls.next_dagrun_create_after)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 85338c59caad8..fee1dcbb6685b 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1704,6 +1704,10 @@ def trigger(self, session=None):
             flash(f"Cannot find dag {dag_id}")
             return redirect(origin)

+        if dag_orm.has_import_errors:
+            flash(f"Cannot create dagruns because the dag {dag_id} has import errors", "error")
+            return redirect(origin)
+
         if request.method == 'GET':
             # Populate conf textarea with conf requests parameter, or dag.params
             default_conf = ''
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index 552076be589b8..016c6243df6a2 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -23,7 +23,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
 | Revision ID                    | Revises ID       | Airflow Version | Description                                                                             |
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
-| ``7b2661a43ba3`` (head)        | ``142555e44c17`` | ``2.2.0``       | Change ``TaskInstance`` and ``TaskReschedule`` tables from execution_date to run_id.   |
+| ``be2bfac3da23`` (head)        | ``7b2661a43ba3`` | ``2.2.3``       | Add has_import_errors column to DagModel                                                |
++--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
+| ``7b2661a43ba3``               | ``142555e44c17`` | ``2.2.0``       | Change ``TaskInstance`` and ``TaskReschedule`` tables from execution_date to run_id.   |
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
 | ``142555e44c17``               | ``54bebd308c5f`` | ``2.2.0``       | Add ``data_interval_start`` and ``data_interval_end`` to ``DagRun``                     |
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index 54410b9d5a217..c342da214fc42 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -118,6 +118,7 @@ def _create_dag(self, dag_id):
             session.add(dag_instance)
         dag = DAG(dag_id=dag_id, schedule_interval=None)
         self.app.dag_bag.bag_dag(dag, root_dag=dag)
+        return dag_instance

     def _create_test_dag_run(self, state='running', extra_dag=False, commit=True):
         dag_runs = []
@@ -986,6 +987,24 @@ def test_should_respond_200(self, logical_date_field_name, dag_run_id, logical_d
             "state": "queued",
         } == response.json

+    def test_should_respond_400_if_a_dag_has_import_errors(self, session):
+        """Test that a DAG with import errors cannot be triggered"""
+        dm = self._create_dag("TEST_DAG_ID")
+        dm.has_import_errors = True
+        session.add(dm)
+        session.flush()
+        response = self.client.post(
+            "api/v1/dags/TEST_DAG_ID/dagRuns",
+            json={},
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert {
+            "detail": "DAG with dag_id: 'TEST_DAG_ID' has import errors",
+            "status": 400,
+            "title": 'DAG cannot be triggered',
+            "type": EXCEPTIONS_LINK_MAP[400],
+        } == response.json
+
     def test_should_response_200_for_matching_execution_date_logical_date(self):
         self._create_dag("TEST_DAG_ID")
         response = self.client.post(
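For reference, this is roughly what an API client sees once the guard in `post_dag_run` is in place; the host, DAG id, and credentials below are placeholders, not values from the patch:

```python
# Sketch: triggering a DAG whose file currently fails to import now returns
# HTTP 400 with a problem-details body instead of queuing a run.
import requests

resp = requests.post(
    "http://localhost:8080/api/v1/dags/my_dag/dagRuns",  # placeholder host/dag
    json={},
    auth=("admin", "admin"),  # placeholder basic-auth credentials
)
if resp.status_code == 400:
    body = resp.json()
    # Expected shape, per the endpoint change and the test above:
    #   title:  "DAG cannot be triggered"
    #   detail: "DAG with dag_id: 'my_dag' has import errors"
    print(body["title"], "-", body["detail"])
```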
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index 9b2cb7851f17d..6bfec8f4c21fa 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -26,7 +26,7 @@
 import pytest

 from airflow import settings
-from airflow.configuration import conf
+from airflow.configuration import TEST_DAGS_FOLDER, conf
 from airflow.dag_processing.processor import DagFileProcessor
 from airflow.models import DagBag, DagModel, SlaMiss, TaskInstance, errors
 from airflow.models.taskinstance import SimpleTaskInstance
@@ -481,6 +481,32 @@ def test_add_unparseable_zip_file_creates_import_error(self, tmpdir):
         assert import_error.stacktrace == f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)"
         session.rollback()

+    @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"})
+    def test_dag_model_has_import_error_is_true_when_import_error_exists(self, tmpdir, session):
+        dag_file = os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py")
+        temp_dagfile = os.path.join(tmpdir, TEMP_DAG_FILENAME)
+        with open(dag_file) as main_dag, open(temp_dagfile, 'w') as next_dag:
+            for line in main_dag:
+                next_dag.write(line)
+        # first we parse the dag
+        self._process_file(temp_dagfile, session)
+        # assert DagModel.has_import_errors is false
+        dm = session.query(DagModel).filter(DagModel.fileloc == temp_dagfile).first()
+        assert not dm.has_import_errors
+        # corrupt the file
+        with open(temp_dagfile, 'a') as file:
+            file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS)
+
+        self._process_file(temp_dagfile, session)
+        import_errors = session.query(errors.ImportError).all()
+
+        assert len(import_errors) == 1
+        import_error = import_errors[0]
+        assert import_error.filename == temp_dagfile
+        assert import_error.stacktrace
+        dm = session.query(DagModel).filter(DagModel.fileloc == temp_dagfile).first()
+        assert dm.has_import_errors
+
     def test_no_import_errors_with_parseable_dag(self, tmpdir):
         parseable_filename = os.path.join(tmpdir, TEMP_DAG_FILENAME)
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index efdff89e64445..55de5182916ab 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -822,6 +822,37 @@ def test_bulk_write_to_db_max_active_runs(self, state):
         model = session.query(DagModel).get((dag.dag_id,))
         assert model.next_dagrun_create_after is None

+    def test_bulk_write_to_db_has_import_error(self):
+        """
+        Test that DagModel.has_import_errors is set to false if there are no import errors.
+        """
+        dag = DAG(dag_id='test_has_import_error', start_date=DEFAULT_DATE)
+
+        DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+        session = settings.Session()
+        dag.clear()
+        DAG.bulk_write_to_db([dag], session)
+
+        model = session.query(DagModel).get((dag.dag_id,))
+
+        assert not model.has_import_errors
+
+        # Simulate the DagFileProcessor setting the import error to true
+        model.has_import_errors = True
+        session.merge(model)
+        session.flush()
+        model = session.query(DagModel).get((dag.dag_id,))
+        # assert the flag was persisted
+        assert model.has_import_errors
+        # re-parse the dag
+        DAG.bulk_write_to_db([dag])
+
+        model = session.query(DagModel).get((dag.dag_id,))
+        # assert that has_import_errors is now false
+        assert not model.has_import_errors
+        session.close()
+
     def test_sync_to_db(self):
         dag = DAG(
             'dag',
@@ -1863,9 +1894,10 @@ def test_max_active_runs_not_none(self):
             next_dagrun_create_after=None,
             is_active=True,
         )
+        # assert max_active_runs was populated from the config default
+        assert orm_dag.max_active_runs == 16
         session.add(orm_dag)
         session.flush()
-        assert orm_dag.max_active_runs is not None

         session.rollback()

@@ -1901,6 +1933,33 @@ def test_dags_needing_dagruns_only_unpaused(self):
         session.rollback()
         session.close()

+    def test_dags_needing_dagruns_doesnot_send_dagmodel_with_import_errors(self, session):
+        """
+        Check that dags_needing_dagruns() does not return DAGs with import
+        errors, so the scheduler will not create dag runs for them.
+        """
+        dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
+        DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+        orm_dag = DagModel(
+            dag_id=dag.dag_id,
+            has_task_concurrency_limits=False,
+            next_dagrun=DEFAULT_DATE,
+            next_dagrun_create_after=DEFAULT_DATE + timedelta(days=1),
+            is_active=True,
+        )
+        assert not orm_dag.has_import_errors
+        session.add(orm_dag)
+        session.flush()
+
+        needed = DagModel.dags_needing_dagruns(session).all()
+        assert needed == [orm_dag]
+        orm_dag.has_import_errors = True
+        session.merge(orm_dag)
+        session.flush()
+        needed = DagModel.dags_needing_dagruns(session).all()
+        assert needed == []
+
 @pytest.mark.parametrize(
     ('fileloc', 'expected_relative'),
     [
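Taken together, the patch closes all three trigger paths (scheduler, REST API, and UI) for broken DAG files. A condensed sketch of the scheduler-side effect, mirroring `test_dags_needing_dagruns_doesnot_send_dagmodel_with_import_errors` above and assuming an open session with DagModel rows already registered (the helper name is mine):

```python
# Sketch of the scheduler-side effect; assumes an open SQLAlchemy session.
from airflow.models import DagModel


def schedulable_dag_ids(session):
    """DAG ids the scheduler would still consider for new dag runs."""
    return [dm.dag_id for dm in DagModel.dags_needing_dagruns(session).all()]

# Once the processor flips has_import_errors to True, the DAG drops out of
# this list; a subsequent successful parse (bulk_write_to_db) resets the
# flag to False and the DAG becomes schedulable again.
```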