diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py
index 7ce0d1655da6e..c0050672a8e47 100644
--- a/airflow/api/client/local_client.py
+++ b/airflow/api/client/local_client.py
@@ -18,8 +18,10 @@
 """Local client API"""
 
 from airflow.api.client import api_client
-from airflow.api.common.experimental import delete_dag, pool, trigger_dag
+from airflow.api.common import delete_dag, trigger_dag
 from airflow.api.common.experimental.get_lineage import get_lineage as get_lineage_api
+from airflow.exceptions import AirflowBadRequest, PoolNotFound
+from airflow.models.pool import Pool
 
 
 class Client(api_client.Client):
@@ -36,19 +38,30 @@ def delete_dag(self, dag_id):
         return f"Removed {count} record(s)"
 
     def get_pool(self, name):
-        the_pool = pool.get_pool(name=name)
-        return the_pool.pool, the_pool.slots, the_pool.description
+        pool = Pool.get_pool(pool_name=name)
+        if not pool:
+            raise PoolNotFound(f"Pool {name} not found")
+        return pool.pool, pool.slots, pool.description
 
     def get_pools(self):
-        return [(p.pool, p.slots, p.description) for p in pool.get_pools()]
+        return [(p.pool, p.slots, p.description) for p in Pool.get_pools()]
 
     def create_pool(self, name, slots, description):
-        the_pool = pool.create_pool(name=name, slots=slots, description=description)
-        return the_pool.pool, the_pool.slots, the_pool.description
+        if not (name and name.strip()):
+            raise AirflowBadRequest("Pool name shouldn't be empty")
+        pool_name_length = Pool.pool.property.columns[0].type.length
+        if len(name) > pool_name_length:
+            raise AirflowBadRequest(f"pool name cannot be more than {pool_name_length} characters")
+        try:
+            slots = int(slots)
+        except ValueError:
+            raise AirflowBadRequest(f"Bad value for `slots`: {slots}")
+        pool = Pool.create_or_update_pool(name=name, slots=slots, description=description)
+        return pool.pool, pool.slots, pool.description
 
     def delete_pool(self, name):
-        the_pool = pool.delete_pool(name=name)
-        return the_pool.pool, the_pool.slots, the_pool.description
+        pool = Pool.delete_pool(name=name)
+        return pool.pool, pool.slots, pool.description
 
     def get_lineage(self, dag_id, execution_date):
         lineage = get_lineage_api(dag_id=dag_id, execution_date=execution_date)
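The local client now validates its input and talks to the `Pool` model directly. A minimal sketch of the pool round-trip this enables; the pool name and the `Client(None, None)` construction (URL and auth are unused by the local client) are illustrative assumptions, and an initialized Airflow metadata database is assumed:

```python
from airflow.api.client.local_client import Client

client = Client(None, None)  # assumption: local client ignores URL/auth

# create_pool is create-or-update: repeating a name updates its slots.
client.create_pool(name="etl", slots=4, description="ETL tasks")
name, slots, description = client.get_pool(name="etl")
assert (name, slots) == ("etl", 4)

client.delete_pool(name="etl")  # raises PoolNotFound for unknown names
```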
+"""Delete DAGs APIs.""" +import logging + +from sqlalchemy import or_ + +from airflow import models +from airflow.exceptions import AirflowException, DagNotFound +from airflow.models import DagModel, TaskFail +from airflow.models.serialized_dag import SerializedDagModel +from airflow.utils.db import get_sqla_model_classes +from airflow.utils.session import provide_session +from airflow.utils.state import State + +log = logging.getLogger(__name__) + + +@provide_session +def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int: + """ + :param dag_id: the dag_id of the DAG to delete + :param keep_records_in_log: whether keep records of the given dag_id + in the Log table in the backend database (for reasons like auditing). + The default value is True. + :param session: session used + :return count of deleted dags + """ + log.info("Deleting DAG: %s", dag_id) + running_tis = ( + session.query(models.TaskInstance.state) + .filter(models.TaskInstance.dag_id == dag_id) + .filter(models.TaskInstance.state == State.RUNNING) + .first() + ) + if running_tis: + raise AirflowException("TaskInstances still running") + dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first() + if dag is None: + raise DagNotFound(f"Dag id {dag_id} not found") + + # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval. + # There may be a lag, so explicitly removes serialized DAG here. + if SerializedDagModel.has_dag(dag_id=dag_id, session=session): + SerializedDagModel.remove_dag(dag_id=dag_id, session=session) + + count = 0 + + for model in get_sqla_model_classes(): + if hasattr(model, "dag_id"): + if keep_records_in_log and model.__name__ == 'Log': + continue + cond = or_(model.dag_id == dag_id, model.dag_id.like(dag_id + ".%")) + count += session.query(model).filter(cond).delete(synchronize_session='fetch') + if dag.is_subdag: + parent_dag_id, task_id = dag_id.rsplit(".", 1) + for model in TaskFail, models.TaskInstance: + count += ( + session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete() + ) + + # Delete entries in Import Errors table for a deleted DAG + # This handles the case when the dag_id is changed in the file + session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete( + synchronize_session='fetch' + ) + + return count diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py index 44e54e3738349..36bf7dd8c46a7 100644 --- a/airflow/api/common/experimental/delete_dag.py +++ b/airflow/api/common/experimental/delete_dag.py @@ -15,68 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Delete DAGs APIs.""" -import logging +import warnings -from sqlalchemy import or_ +from airflow.api.common.delete_dag import * # noqa -from airflow import models -from airflow.exceptions import AirflowException, DagNotFound -from airflow.models import DagModel, TaskFail -from airflow.models.serialized_dag import SerializedDagModel -from airflow.utils.session import provide_session -from airflow.utils.state import State - -log = logging.getLogger(__name__) - - -@provide_session -def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int: - """ - :param dag_id: the dag_id of the DAG to delete - :param keep_records_in_log: whether keep records of the given dag_id - in the Log table in the backend database (for reasons like auditing). - The default value is True. - :param session: session used - :return count of deleted dags - """ - log.info("Deleting DAG: %s", dag_id) - running_tis = ( - session.query(models.TaskInstance.state) - .filter(models.TaskInstance.dag_id == dag_id) - .filter(models.TaskInstance.state == State.RUNNING) - .first() - ) - if running_tis: - raise AirflowException("TaskInstances still running") - dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first() - if dag is None: - raise DagNotFound(f"Dag id {dag_id} not found") - - # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval. - # There may be a lag, so explicitly removes serialized DAG here. - if SerializedDagModel.has_dag(dag_id=dag_id, session=session): - SerializedDagModel.remove_dag(dag_id=dag_id, session=session) - - count = 0 - - for model in models.base.Base._decl_class_registry.values(): - if hasattr(model, "dag_id"): - if keep_records_in_log and model.__name__ == 'Log': - continue - cond = or_(model.dag_id == dag_id, model.dag_id.like(dag_id + ".%")) - count += session.query(model).filter(cond).delete(synchronize_session='fetch') - if dag.is_subdag: - parent_dag_id, task_id = dag_id.rsplit(".", 1) - for model in TaskFail, models.TaskInstance: - count += ( - session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete() - ) - - # Delete entries in Import Errors table for a deleted DAG - # This handles the case when the dag_id is changed in the file - session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete( - synchronize_session='fetch' - ) - - return count +warnings.warn( + "This module is deprecated. Please use `airflow.api.common.delete_dag` instead.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/api/common/experimental/get_code.py b/airflow/api/common/experimental/get_code.py index 79b0b9f492654..1a1fb621dbe48 100644 --- a/airflow/api/common/experimental/get_code.py +++ b/airflow/api/common/experimental/get_code.py @@ -16,11 +16,14 @@ # specific language governing permissions and limitations # under the License. """Get code APIs.""" +from deprecated import deprecated + from airflow.api.common.experimental import check_and_get_dag from airflow.exceptions import AirflowException, DagCodeNotFound from airflow.models.dagcode import DagCode +@deprecated(reason="Use DagCode().get_code_by_fileloc() instead", version="2.2.3") def get_code(dag_id: str) -> str: """Return python code of a given dag_id. 
diff --git a/airflow/api/common/experimental/get_dag_run_state.py b/airflow/api/common/experimental/get_dag_run_state.py
index ca71a9afb3853..b2dedd5113ae9 100644
--- a/airflow/api/common/experimental/get_dag_run_state.py
+++ b/airflow/api/common/experimental/get_dag_run_state.py
@@ -19,9 +19,12 @@
 from datetime import datetime
 from typing import Dict
 
+from deprecated import deprecated
+
 from airflow.api.common.experimental import check_and_get_dag, check_and_get_dagrun
 
 
+@deprecated(reason="Use DagRun().get_state() instead", version="2.2.3")
 def get_dag_run_state(dag_id: str, execution_date: datetime) -> Dict[str, str]:
     """Return the Dag Run state identified by the given dag_id and execution_date.
diff --git a/airflow/api/common/experimental/get_task.py b/airflow/api/common/experimental/get_task.py
index 302ad6430efe9..fae5fd7ef1851 100644
--- a/airflow/api/common/experimental/get_task.py
+++ b/airflow/api/common/experimental/get_task.py
@@ -16,10 +16,13 @@
 # specific language governing permissions and limitations
 # under the License.
 """Task APIs.."""
+from deprecated import deprecated
+
 from airflow.api.common.experimental import check_and_get_dag
 from airflow.models import TaskInstance
 
 
+@deprecated(reason="Use DAG().get_task", version="2.2.3")
 def get_task(dag_id: str, task_id: str) -> TaskInstance:
     """Return the task object identified by the given dag_id and task_id."""
     dag = check_and_get_dag(dag_id, task_id)
diff --git a/airflow/api/common/experimental/get_task_instance.py b/airflow/api/common/experimental/get_task_instance.py
index f3ca1cf2f6380..137f8a3aef9e7 100644
--- a/airflow/api/common/experimental/get_task_instance.py
+++ b/airflow/api/common/experimental/get_task_instance.py
@@ -18,11 +18,14 @@
 """Task Instance APIs."""
 from datetime import datetime
 
+from deprecated import deprecated
+
 from airflow.api.common.experimental import check_and_get_dag, check_and_get_dagrun
 from airflow.exceptions import TaskInstanceNotFound
 from airflow.models import TaskInstance
 
 
+@deprecated(version="2.2.3", reason="Use DagRun.get_task_instance instead")
 def get_task_instance(dag_id: str, task_id: str, execution_date: datetime) -> TaskInstance:
     """Return the task instance identified by the given dag_id, task_id and execution_date."""
     dag = check_and_get_dag(dag_id, task_id)
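Each decorator names its model-level replacement. As a hedged illustration of the direction callers are being pointed in, here is a rough equivalent of the deprecated `get_task_instance` using the ORM directly; the DAG id and task id are placeholder values:

```python
from airflow.models import DagRun
from airflow.utils.session import create_session

with create_session() as session:
    dag_run = (
        session.query(DagRun)
        .filter(DagRun.dag_id == "example_bash_operator")  # placeholder dag_id
        .order_by(DagRun.execution_date.desc())
        .first()
    )
    if dag_run is not None:
        # The replacement suggested by the decorator above.
        ti = dag_run.get_task_instance("run_this_last", session=session)
```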
"""Pool APIs.""" +from deprecated import deprecated + from airflow.exceptions import AirflowBadRequest, PoolNotFound from airflow.models import Pool from airflow.utils.session import provide_session +@deprecated(reason="Use Pool.get_pool() instead", version="2.2.3") @provide_session def get_pool(name, session=None): """Get pool by a given name.""" @@ -34,12 +37,14 @@ def get_pool(name, session=None): return pool +@deprecated(reason="Use Pool.get_pools() instead", version="2.2.3") @provide_session def get_pools(session=None): """Get all pools.""" return session.query(Pool).all() +@deprecated(reason="Use Pool.create_pool() instead", version="2.2.3") @provide_session def create_pool(name, slots, description, session=None): """Create a pool with a given parameters.""" @@ -70,6 +75,7 @@ def create_pool(name, slots, description, session=None): return pool +@deprecated(reason="Use Pool.delete_pool() instead", version="2.2.3") @provide_session def delete_pool(name, session=None): """Delete pool by a given name.""" diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py index 38a873ce2e013..d52631281f534 100644 --- a/airflow/api/common/experimental/trigger_dag.py +++ b/airflow/api/common/experimental/trigger_dag.py @@ -15,114 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Triggering DAG runs APIs.""" -import json -from datetime import datetime -from typing import List, Optional, Union -from airflow.exceptions import DagNotFound, DagRunAlreadyExists -from airflow.models import DagBag, DagModel, DagRun -from airflow.utils import timezone -from airflow.utils.state import State -from airflow.utils.types import DagRunType +import warnings +from airflow.api.common.trigger_dag import * # noqa -def _trigger_dag( - dag_id: str, - dag_bag: DagBag, - run_id: Optional[str] = None, - conf: Optional[Union[dict, str]] = None, - execution_date: Optional[datetime] = None, - replace_microseconds: bool = True, -) -> List[DagRun]: - """Triggers DAG run. 
diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py
index 38a873ce2e013..d52631281f534 100644
--- a/airflow/api/common/experimental/trigger_dag.py
+++ b/airflow/api/common/experimental/trigger_dag.py
@@ -15,114 +15,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Triggering DAG runs APIs."""
-import json
-from datetime import datetime
-from typing import List, Optional, Union
 
-from airflow.exceptions import DagNotFound, DagRunAlreadyExists
-from airflow.models import DagBag, DagModel, DagRun
-from airflow.utils import timezone
-from airflow.utils.state import State
-from airflow.utils.types import DagRunType
+import warnings
 
+from airflow.api.common.trigger_dag import *  # noqa
 
-def _trigger_dag(
-    dag_id: str,
-    dag_bag: DagBag,
-    run_id: Optional[str] = None,
-    conf: Optional[Union[dict, str]] = None,
-    execution_date: Optional[datetime] = None,
-    replace_microseconds: bool = True,
-) -> List[DagRun]:
-    """Triggers DAG run.
-
-    :param dag_id: DAG ID
-    :param dag_bag: DAG Bag model
-    :param run_id: ID of the dag_run
-    :param conf: configuration
-    :param execution_date: date of execution
-    :param replace_microseconds: whether microseconds should be zeroed
-    :return: list of triggered dags
-    """
-    dag = dag_bag.get_dag(dag_id)  # prefetch dag if it is stored serialized
-
-    if dag_id not in dag_bag.dags:
-        raise DagNotFound(f"Dag id {dag_id} not found")
-
-    execution_date = execution_date if execution_date else timezone.utcnow()
-
-    if not timezone.is_localized(execution_date):
-        raise ValueError("The execution_date should be localized")
-
-    if replace_microseconds:
-        execution_date = execution_date.replace(microsecond=0)
-
-    if dag.default_args and 'start_date' in dag.default_args:
-        min_dag_start_date = dag.default_args["start_date"]
-        if min_dag_start_date and execution_date < min_dag_start_date:
-            raise ValueError(
-                "The execution_date [{}] should be >= start_date [{}] from DAG's default_args".format(
-                    execution_date.isoformat(), min_dag_start_date.isoformat()
-                )
-            )
-
-    run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
-    dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id)
-
-    if dag_run:
-        raise DagRunAlreadyExists(
-            f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}"
-        )
-
-    run_conf = None
-    if conf:
-        run_conf = conf if isinstance(conf, dict) else json.loads(conf)
-
-    dag_runs = []
-    dags_to_run = [dag] + dag.subdags
-    for _dag in dags_to_run:
-        dag_run = _dag.create_dagrun(
-            run_id=run_id,
-            execution_date=execution_date,
-            state=State.QUEUED,
-            conf=run_conf,
-            external_trigger=True,
-            dag_hash=dag_bag.dags_hash.get(dag_id),
-        )
-        dag_runs.append(dag_run)
-
-    return dag_runs
-
-
-def trigger_dag(
-    dag_id: str,
-    run_id: Optional[str] = None,
-    conf: Optional[Union[dict, str]] = None,
-    execution_date: Optional[datetime] = None,
-    replace_microseconds: bool = True,
-) -> Optional[DagRun]:
-    """Triggers execution of DAG specified by dag_id
-
-    :param dag_id: DAG ID
-    :param run_id: ID of the dag_run
-    :param conf: configuration
-    :param execution_date: date of execution
-    :param replace_microseconds: whether microseconds should be zeroed
-    :return: first dag run triggered - even if more than one Dag Runs were triggered or None
-    """
-    dag_model = DagModel.get_current(dag_id)
-    if dag_model is None:
-        raise DagNotFound(f"Dag id {dag_id} not found in DagModel")
-
-    dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
-    triggers = _trigger_dag(
-        dag_id=dag_id,
-        dag_bag=dagbag,
-        run_id=run_id,
-        conf=conf,
-        execution_date=execution_date,
-        replace_microseconds=replace_microseconds,
-    )
-
-    return triggers[0] if triggers else None
+warnings.warn(
+    "This module is deprecated. Please use `airflow.api.common.trigger_dag` instead.",
+    DeprecationWarning,
+    stacklevel=2,
+)
diff --git a/airflow/api/common/trigger_dag.py b/airflow/api/common/trigger_dag.py
new file mode 100644
index 0000000000000..70bbb78312209
--- /dev/null
+++ b/airflow/api/common/trigger_dag.py
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Triggering DAG runs APIs."""
+import json
+from datetime import datetime
+from typing import List, Optional, Union
+
+from airflow.exceptions import DagNotFound, DagRunAlreadyExists
+from airflow.models import DagBag, DagModel, DagRun
+from airflow.utils import timezone
+from airflow.utils.state import State
+from airflow.utils.types import DagRunType
+
+
+def _trigger_dag(
+    dag_id: str,
+    dag_bag: DagBag,
+    run_id: Optional[str] = None,
+    conf: Optional[Union[dict, str]] = None,
+    execution_date: Optional[datetime] = None,
+    replace_microseconds: bool = True,
+) -> List[DagRun]:
+    """Triggers DAG run.
+
+    :param dag_id: DAG ID
+    :param dag_bag: DAG Bag model
+    :param run_id: ID of the dag_run
+    :param conf: configuration
+    :param execution_date: date of execution
+    :param replace_microseconds: whether microseconds should be zeroed
+    :return: list of triggered dags
+    """
+    dag = dag_bag.get_dag(dag_id)  # prefetch dag if it is stored serialized
+
+    if dag_id not in dag_bag.dags:
+        raise DagNotFound(f"Dag id {dag_id} not found")
+
+    execution_date = execution_date if execution_date else timezone.utcnow()
+
+    if not timezone.is_localized(execution_date):
+        raise ValueError("The execution_date should be localized")
+
+    if replace_microseconds:
+        execution_date = execution_date.replace(microsecond=0)
+
+    if dag.default_args and 'start_date' in dag.default_args:
+        min_dag_start_date = dag.default_args["start_date"]
+        if min_dag_start_date and execution_date < min_dag_start_date:
+            raise ValueError(
+                f"The execution_date [{execution_date.isoformat()}] should be >= start_date "
+                f"[{min_dag_start_date.isoformat()}] from DAG's default_args"
+            )
+
+    run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
+    dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id)
+
+    if dag_run:
+        raise DagRunAlreadyExists(
+            f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}"
+        )
+
+    run_conf = None
+    if conf:
+        run_conf = conf if isinstance(conf, dict) else json.loads(conf)
+
+    dag_runs = []
+    dags_to_run = [dag] + dag.subdags
+    for _dag in dags_to_run:
+        dag_run = _dag.create_dagrun(
+            run_id=run_id,
+            execution_date=execution_date,
+            state=State.QUEUED,
+            conf=run_conf,
+            external_trigger=True,
+            dag_hash=dag_bag.dags_hash.get(dag_id),
+        )
+        dag_runs.append(dag_run)
+
+    return dag_runs
+
+
+def trigger_dag(
+    dag_id: str,
+    run_id: Optional[str] = None,
+    conf: Optional[Union[dict, str]] = None,
+    execution_date: Optional[datetime] = None,
+    replace_microseconds: bool = True,
+) -> Optional[DagRun]:
+    """Triggers execution of the DAG specified by dag_id.
+
+    :param dag_id: DAG ID
+    :param run_id: ID of the dag_run
+    :param conf: configuration
+    :param execution_date: date of execution
+    :param replace_microseconds: whether microseconds should be zeroed
+    :return: the first dag run triggered, even if more than one dag run was triggered, or None
+    """
+    dag_model = DagModel.get_current(dag_id)
+    if dag_model is None:
+        raise DagNotFound(f"Dag id {dag_id} not found in DagModel")
+
+    dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
+    triggers = _trigger_dag(
+        dag_id=dag_id,
+        dag_bag=dagbag,
+        run_id=run_id,
+        conf=conf,
+        execution_date=execution_date,
+        replace_microseconds=replace_microseconds,
+    )
+
+    return triggers[0] if triggers else None
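A hedged sketch of the public entry point defined above; the DAG id and conf are placeholders. `conf` may be a dict or a JSON string, and any explicit `execution_date` must be timezone-aware:

```python
from airflow.api.common.trigger_dag import trigger_dag
from airflow.utils import timezone

run = trigger_dag(
    dag_id="example_bash_operator",  # placeholder DAG id
    run_id=None,  # None -> an auto-generated manual run id
    conf={"source": "api"},
    execution_date=timezone.utcnow(),
    replace_microseconds=True,
)
if run is not None:
    print(run.run_id, run.state)  # created in the QUEUED state
```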
diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py
index c164fccc37dbf..286b191601caf 100644
--- a/airflow/api_connexion/endpoints/dag_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_endpoint.py
@@ -110,13 +110,10 @@ def patch_dag(session, dag_id, update_mask=None):
 @provide_session
 def delete_dag(dag_id: str, session: Session):
     """Delete the specific DAG."""
-    # TODO: This function is shared with the /delete endpoint used by the web
-    # UI, so we're reusing it to simplify maintenance. Refactor the function to
-    # another place when the experimental/legacy API is removed.
-    from airflow.api.common.experimental import delete_dag
+    from airflow.api.common import delete_dag as delete_dag_module
 
     try:
-        delete_dag.delete_dag(dag_id, session=session)
+        delete_dag_module.delete_dag(dag_id, session=session)
     except DagNotFound:
         raise NotFound(f"Dag with id: '{dag_id}' not found")
     except AirflowException:
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 6f217c4b025a2..8ae88aabcd45f 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -21,11 +21,11 @@
 from sqlalchemy import Column, Integer, String, Text, func
 from sqlalchemy.orm.session import Session
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, PoolNotFound
 from airflow.models.base import Base
 from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.typing_compat import TypedDict
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import nowait, with_row_locks
 from airflow.utils.state import State
 
@@ -57,7 +57,13 @@ def __repr__(self):
 
     @staticmethod
     @provide_session
-    def get_pool(pool_name, session: Session = None):
+    def get_pools(session: Session = NEW_SESSION):
+        """Get all pools."""
+        return session.query(Pool).all()
+
+    @staticmethod
+    @provide_session
+    def get_pool(pool_name: str, session: Session = NEW_SESSION):
         """
         Get the Pool with specific pool name from the Pools.
 
@@ -69,7 +75,7 @@ def get_pool(pool_name, session: Session = None):
 
     @staticmethod
     @provide_session
-    def get_default_pool(session: Session = None):
+    def get_default_pool(session: Session = NEW_SESSION):
         """
         Get the Pool of the default_pool from the Pools.
 
@@ -78,12 +84,46 @@
         """
         return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)
 
+    @staticmethod
+    @provide_session
+    def create_or_update_pool(name: str, slots: int, description: str, session: Session = NEW_SESSION):
+        """Create a pool with given parameters or update it if it already exists."""
+        if not name:
+            return
+        pool = session.query(Pool).filter_by(pool=name).first()
+        if pool is None:
+            pool = Pool(pool=name, slots=slots, description=description)
+            session.add(pool)
+        else:
+            pool.slots = slots
+            pool.description = description
+
+        session.commit()
+
+        return pool
+
+    @staticmethod
+    @provide_session
+    def delete_pool(name: str, session: Session = NEW_SESSION):
+        """Delete pool by a given name."""
+        if name == Pool.DEFAULT_POOL_NAME:
+            raise AirflowException("default_pool cannot be deleted")
+
+        pool = session.query(Pool).filter_by(pool=name).first()
+        if pool is None:
+            raise PoolNotFound(f"Pool '{name}' doesn't exist")
+
+        session.delete(pool)
+        session.commit()
+
+        return pool
+
     @staticmethod
     @provide_session
     def slots_stats(
         *,
         lock_rows: bool = False,
-        session: Session = None,
+        session: Session = NEW_SESSION,
     ) -> Dict[str, PoolStats]:
         """
         Get Pool stats (Number of Running, Queued, Open & Total tasks)
@@ -210,7 +250,7 @@ def queued_slots(self, session: Session):
         )
 
     @provide_session
-    def open_slots(self, session: Session) -> float:
+    def open_slots(self, session: Session = NEW_SESSION) -> float:
         """
         Get the number of slots open at the moment.
diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py
index 1e6cb7f6ab38f..421c7963d0d68 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -21,7 +21,7 @@
 import time
 from typing import Dict, List, Optional, Union
 
-from airflow.api.common.experimental.trigger_dag import trigger_dag
+from airflow.api.common.trigger_dag import trigger_dag
 from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
 from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun
 from airflow.utils import timezone
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index f35d1659f8cb9..023f482790d00 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -991,3 +991,18 @@ def check(session=None):
     """
     session.execute('select 1 as is_alive;')
     log.info("Connection successful.")
+
+
+def get_sqla_model_classes():
+    """
+    Get all SQLAlchemy class mappers.
+
+    SQLAlchemy < 1.4 does not support registry.mappers so we use
+    try/except to handle it.
+    """
+    from airflow.models.base import Base
+
+    try:
+        return [mapper.class_ for mapper in Base.registry.mappers]
+    except AttributeError:
+        return Base._decl_class_registry.values()
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 2182a1706aeec..f2642a73f1f0e 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1607,7 +1607,7 @@ def run(self):
     @action_logging
     def delete(self):
         """Deletes DAG."""
-        from airflow.api.common.experimental import delete_dag
+        from airflow.api.common import delete_dag
         from airflow.exceptions import DagNotFound
 
         dag_id = request.values.get('dag_id')
diff --git a/setup.cfg b/setup.cfg
index b83ef9be02826..c3cce1c0ac0c5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -95,6 +95,7 @@ install_requires =
     croniter>=0.3.17
     cryptography>=0.9.3
     dataclasses;python_version<"3.7"
+    deprecated>=1.2.13
     dill>=0.2.2, <0.4
     # Sphinx RTD theme 0.5.2. introduced limitation to docutils to account for some docutils markup
     # change:
diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py
index a2af8ca245e6b..9f574e4fc657a 100644
--- a/tests/api/client/test_local_client.py
+++ b/tests/api/client/test_local_client.py
@@ -17,6 +17,8 @@
 # under the License.
 
 import json
+import random
+import string
 import unittest
 from unittest.mock import ANY, patch
 
@@ -25,7 +27,7 @@
 
 from airflow.api.client.local_client import Client
 from airflow.example_dags import example_bash_operator
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowBadRequest, AirflowException, PoolNotFound
 from airflow.models import DAG, DagBag, DagModel, DagRun, Pool
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -133,6 +135,10 @@ def test_get_pool(self):
         pool = self.client.get_pool(name='foo')
         assert pool == ('foo', 1, '')
 
+    def test_get_pool_non_existing_raises(self):
+        with pytest.raises(PoolNotFound):
+            self.client.get_pool(name='foo')
+
     def test_get_pools(self):
         self.client.create_pool(name='foo1', slots=1, description='')
         self.client.create_pool(name='foo2', slots=2, description='')
@@ -145,6 +151,26 @@ def test_create_pool(self):
         with create_session() as session:
             assert session.query(Pool).count() == 2
 
+    def test_create_pool_bad_slots(self):
+        with pytest.raises(AirflowBadRequest, match="^Bad value for `slots`: foo$"):
+            self.client.create_pool(
+                name='foo',
+                slots='foo',
+                description='',
+            )
+
+    def test_create_pool_name_too_long(self):
+        long_name = ''.join(random.choices(string.ascii_lowercase, k=300))
+        pool_name_length = Pool.pool.property.columns[0].type.length
+        with pytest.raises(
+            AirflowBadRequest, match=f"^pool name cannot be more than {pool_name_length} characters"
+        ):
+            self.client.create_pool(
+                name=long_name,
+                slots=5,
+                description='',
+            )
+
     def test_delete_pool(self):
         self.client.create_pool(name='foo', slots=1, description='')
         with create_session() as session:
@@ -152,3 +178,6 @@ def test_delete_pool(self):
         self.client.delete_pool(name='foo')
         with create_session() as session:
             assert session.query(Pool).count() == 1
+        for name in ('', ' '):
+            with pytest.raises(PoolNotFound, match=f"^Pool {name!r} doesn't exist$"):
+                Pool.delete_pool(name=name)
diff --git a/tests/api/common/experimental/test_delete_dag.py b/tests/api/common/test_delete_dag.py
similarity index 99%
rename from tests/api/common/experimental/test_delete_dag.py
rename to tests/api/common/test_delete_dag.py
index 5984cd2b14f0f..0eb058a18337a 100644
--- a/tests/api/common/experimental/test_delete_dag.py
+++ b/tests/api/common/test_delete_dag.py
@@ -20,7 +20,7 @@
 import pytest
 
 from airflow import models
-from airflow.api.common.experimental.delete_dag import delete_dag
+from airflow.api.common.delete_dag import delete_dag
 from airflow.exceptions import AirflowException, DagNotFound
 from airflow.operators.dummy import DummyOperator
 from airflow.utils.dates import days_ago
diff --git a/tests/api/common/experimental/test_trigger_dag.py b/tests/api/common/test_trigger_dag.py
similarity index 93%
rename from tests/api/common/experimental/test_trigger_dag.py
rename to tests/api/common/test_trigger_dag.py
index 2f164468d085f..f79d413ed5eae 100644
--- a/tests/api/common/experimental/test_trigger_dag.py
+++ b/tests/api/common/test_trigger_dag.py
@@ -22,7 +22,7 @@
 import pytest
 from parameterized import parameterized
 
-from airflow.api.common.experimental.trigger_dag import _trigger_dag
+from airflow.api.common.trigger_dag import _trigger_dag
 from airflow.exceptions import AirflowException
 from airflow.models import DAG, DagRun
 from airflow.utils import timezone
@@ -42,7 +42,7 @@ def test_trigger_dag_dag_not_found(self, dag_bag_mock):
         with pytest.raises(AirflowException):
             _trigger_dag('dag_not_found', dag_bag_mock)
 
-    @mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
+    @mock.patch('airflow.api.common.trigger_dag.DagRun', spec=DagRun)
     @mock.patch('airflow.models.DagBag')
     def test_trigger_dag_dag_run_exist(self, dag_bag_mock, dag_run_mock):
         dag_id = "dag_run_exist"
@@ -54,7 +54,7 @@ def test_trigger_dag_dag_run_exist(self, dag_bag_mock, dag_run_mock):
         _trigger_dag(dag_id, dag_bag_mock)
 
     @mock.patch('airflow.models.DAG')
-    @mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
+    @mock.patch('airflow.api.common.trigger_dag.DagRun', spec=DagRun)
     @mock.patch('airflow.models.DagBag')
     def test_trigger_dag_include_subdags(self, dag_bag_mock, dag_run_mock, dag_mock):
         dag_id = "trigger_dag"
@@ -70,7 +70,7 @@ def test_trigger_dag_include_subdags(self, dag_bag_mock, dag_run_mock, dag_mock)
         assert 3 == len(triggers)
 
     @mock.patch('airflow.models.DAG')
-    @mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
+    @mock.patch('airflow.api.common.trigger_dag.DagRun', spec=DagRun)
     @mock.patch('airflow.models.DagBag')
     def test_trigger_dag_include_nested_subdags(self, dag_bag_mock, dag_run_mock, dag_mock):
         dag_id = "trigger_dag"
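`get_sqla_model_classes` is what lets the new `delete_dag` sweep every table carrying a `dag_id` column without hard-coding model names. A small sketch applying the same filter the delete loop uses:

```python
from airflow.utils.db import get_sqla_model_classes

# Names of every mapped class delete_dag would consider for cleanup.
dag_scoped = sorted(m.__name__ for m in get_sqla_model_classes() if hasattr(m, "dag_id"))
print(dag_scoped)  # e.g. includes DagModel, DagRun, TaskInstance, Log, ...
```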
diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py
index 00fe14039d7e3..95e585efa5974 100644
--- a/tests/models/test_pool.py
+++ b/tests/models/test_pool.py
@@ -16,11 +16,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 from airflow import settings
+from airflow.exceptions import AirflowException, PoolNotFound
 from airflow.models.pool import Pool
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.operators.dummy import DummyOperator
 from airflow.utils import timezone
+from airflow.utils.session import create_session
 from airflow.utils.state import State
 from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
 
@@ -28,6 +32,10 @@
 
 
 class TestPool:
+
+    USER_POOL_COUNT = 2
+    TOTAL_POOL_COUNT = USER_POOL_COUNT + 1  # including default_pool
+
     @staticmethod
     def clean_db():
         clear_db_dags()
@@ -36,6 +44,20 @@ def clean_db():
 
     def setup_method(self):
         self.clean_db()
+        self.pools = []
+
+    def add_pools(self):
+        self.pools = [Pool.get_default_pool()]
+        for i in range(self.USER_POOL_COUNT):
+            name = f'experimental_{i + 1}'
+            pool = Pool(
+                pool=name,
+                slots=i,
+                description=name,
+            )
+            self.pools.append(pool)
+        with create_session() as session:
+            session.add_all(self.pools)
 
     def teardown_method(self):
         self.clean_db()
@@ -149,3 +171,52 @@ def test_default_pool_open_slots(self, dag_maker):
                 "running": 1,
             }
         } == Pool.slots_stats()
+
+    def test_get_pool(self):
+        self.add_pools()
+        pool = Pool.get_pool(pool_name=self.pools[0].pool)
+        assert pool.pool == self.pools[0].pool
+
+    def test_get_pool_non_existing(self):
+        self.add_pools()
+        assert not Pool.get_pool(pool_name='test')
+
+    def test_get_pool_bad_name(self):
+        for name in ('', ' '):
+            assert not Pool.get_pool(pool_name=name)
+
+    def test_get_pools(self):
+        self.add_pools()
+        pools = sorted(Pool.get_pools(), key=lambda p: p.pool)
+        assert pools[0].pool == self.pools[0].pool
+        assert pools[1].pool == self.pools[1].pool
+
+    def test_create_pool(self, session):
+        self.add_pools()
+        pool = Pool.create_or_update_pool(name='foo', slots=5, description='')
+        assert pool.pool == 'foo'
+        assert pool.slots == 5
+        assert pool.description == ''
+        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT + 1
+
+    def test_create_pool_existing(self, session):
+        self.add_pools()
+        pool = Pool.create_or_update_pool(name=self.pools[0].pool, slots=5, description='')
+        assert pool.pool == self.pools[0].pool
+        assert pool.slots == 5
+        assert pool.description == ''
+        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT
+
+    def test_delete_pool(self, session):
+        self.add_pools()
+        pool = Pool.delete_pool(name=self.pools[-1].pool)
+        assert pool.pool == self.pools[-1].pool
+        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT - 1
+
+    def test_delete_pool_non_existing(self):
+        with pytest.raises(PoolNotFound, match="^Pool 'test' doesn't exist$"):
+            Pool.delete_pool(name='test')
+
+    def test_delete_default_pool_not_allowed(self):
+        with pytest.raises(AirflowException, match="^default_pool cannot be deleted$"):
+            Pool.delete_pool(Pool.DEFAULT_POOL_NAME)
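Finally, the two guard rails the new `Pool.delete_pool` tests pin down, condensed into a hedged standalone check; it assumes an initialized metadata database, and the missing pool name is hypothetical:

```python
import pytest

from airflow.exceptions import AirflowException, PoolNotFound
from airflow.models.pool import Pool

# The default pool can never be removed.
with pytest.raises(AirflowException, match="default_pool cannot be deleted"):
    Pool.delete_pool(Pool.DEFAULT_POOL_NAME)

# Deleting an unknown pool is an error rather than a silent no-op.
with pytest.raises(PoolNotFound):
    Pool.delete_pool(name="no_such_pool")  # hypothetical missing pool name
```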