From 663bb546e782748bdd315483ca2070a77046997a Mon Sep 17 00:00:00 2001
From: Ephraim Anierobi
Date: Thu, 16 Dec 2021 12:30:42 +0100
Subject: [PATCH] Deprecate some functions in the experimental API (#19931)

This PR deprecates several functions in the experimental API. Functions
used only by the experimental REST API are deprecated in place, while
those that are still generally useful are moved out of the experimental
package.

(cherry picked from commit 6239ae91a4c8bfb05f053a61cb8386f2d63b8b3a)
---
 airflow/api/client/local_client.py            |  29 ++--
 airflow/api/common/delete_dag.py              |  83 ++++++++++++
 airflow/api/common/experimental/delete_dag.py |  70 +---------
 airflow/api/common/experimental/get_code.py   |   3 +
 .../common/experimental/get_dag_run_state.py  |   3 +
 airflow/api/common/experimental/get_task.py   |   3 +
 .../common/experimental/get_task_instance.py  |   3 +
 airflow/api/common/experimental/pool.py       |   6 +
 .../api/common/experimental/trigger_dag.py    | 115 +---------------
 airflow/api/common/trigger_dag.py             | 127 ++++++++++++++++++
 .../api_connexion/endpoints/dag_endpoint.py   |   7 +-
 airflow/models/pool.py                        |  52 ++++++-
 airflow/operators/trigger_dagrun.py           |   2 +-
 airflow/utils/db.py                           |  15 +++
 airflow/www/views.py                          |   2 +-
 setup.cfg                                     |   1 +
 tests/api/client/test_local_client.py         |  31 ++++-
 .../{experimental => }/test_delete_dag.py     |   2 +-
 .../{experimental => }/test_trigger_dag.py    |   8 +-
 tests/models/test_pool.py                     |  71 ++++++++++
 20 files changed, 435 insertions(+), 198 deletions(-)
 create mode 100644 airflow/api/common/delete_dag.py
 create mode 100644 airflow/api/common/trigger_dag.py
 rename tests/api/common/{experimental => }/test_delete_dag.py (99%)
 rename tests/api/common/{experimental => }/test_trigger_dag.py (93%)

diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py
index 7ce0d1655da6e..c0050672a8e47 100644
--- a/airflow/api/client/local_client.py
+++ b/airflow/api/client/local_client.py
@@ -18,8 +18,10 @@
 """Local client API"""
 
 from airflow.api.client import api_client
-from airflow.api.common.experimental import delete_dag, pool, trigger_dag
+from airflow.api.common import delete_dag, trigger_dag
 from airflow.api.common.experimental.get_lineage import get_lineage as get_lineage_api
+from airflow.exceptions import AirflowBadRequest, PoolNotFound
+from airflow.models.pool import Pool
 
 
 class Client(api_client.Client):
@@ -36,19 +38,30 @@ def delete_dag(self, dag_id):
         return f"Removed {count} record(s)"
 
     def get_pool(self, name):
-        the_pool = pool.get_pool(name=name)
-        return the_pool.pool, the_pool.slots, the_pool.description
+        pool = Pool.get_pool(pool_name=name)
+        if not pool:
+            raise PoolNotFound(f"Pool {name} not found")
+        return pool.pool, pool.slots, pool.description
 
     def get_pools(self):
-        return [(p.pool, p.slots, p.description) for p in pool.get_pools()]
+        return [(p.pool, p.slots, p.description) for p in Pool.get_pools()]
 
     def create_pool(self, name, slots, description):
-        the_pool = pool.create_pool(name=name, slots=slots, description=description)
-        return the_pool.pool, the_pool.slots, the_pool.description
+        if not (name and name.strip()):
+            raise AirflowBadRequest("Pool name shouldn't be empty")
+        pool_name_length = Pool.pool.property.columns[0].type.length
+        if len(name) > pool_name_length:
+            raise AirflowBadRequest(f"pool name cannot be more than {pool_name_length} characters")
+        try:
+            slots = int(slots)
+        except ValueError:
+            raise AirflowBadRequest(f"Bad value for `slots`: {slots}")
+        pool = Pool.create_or_update_pool(name=name, slots=slots, description=description)
+        return pool.pool, pool.slots, pool.description
 
     def delete_pool(self, name):
-        the_pool = pool.delete_pool(name=name)
-        return the_pool.pool, the_pool.slots, the_pool.description
+        pool = Pool.delete_pool(name=name)
+        return pool.pool, pool.slots, pool.description
 
     def get_lineage(self, dag_id, execution_date):
         lineage = get_lineage_api(dag_id=dag_id, execution_date=execution_date)
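For illustration, the reworked local client now validates its inputs before touching the metadata database. A minimal usage sketch, assuming an initialized Airflow metadata database; the pool name is a placeholder, and the constructor arguments are accepted but unused by the local client:

    from airflow.api.client.local_client import Client
    from airflow.exceptions import AirflowBadRequest

    client = Client(api_base_url=None, auth=None)  # arguments unused by the local client
    client.create_pool(name='etl_pool', slots=4, description='pool for ETL tasks')
    try:
        # slots must be coercible to int, otherwise the client rejects the call
        client.create_pool(name='etl_pool', slots='four', description='')
    except AirflowBadRequest as err:
        print(err)  # Bad value for `slots`: four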
diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py
new file mode 100644
index 0000000000000..c448127f2c484
--- /dev/null
+++ b/airflow/api/common/delete_dag.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Delete DAGs APIs."""
+import logging
+
+from sqlalchemy import or_
+
+from airflow import models
+from airflow.exceptions import AirflowException, DagNotFound
+from airflow.models import DagModel, TaskFail
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.utils.db import get_sqla_model_classes
+from airflow.utils.session import provide_session
+from airflow.utils.state import State
+
+log = logging.getLogger(__name__)
+
+
+@provide_session
+def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int:
+    """
+    :param dag_id: the dag_id of the DAG to delete
+    :param keep_records_in_log: whether keep records of the given dag_id
+        in the Log table in the backend database (for reasons like auditing).
+        The default value is True.
+    :param session: session used
+    :return count of deleted dags
+    """
+    log.info("Deleting DAG: %s", dag_id)
+    running_tis = (
+        session.query(models.TaskInstance.state)
+        .filter(models.TaskInstance.dag_id == dag_id)
+        .filter(models.TaskInstance.state == State.RUNNING)
+        .first()
+    )
+    if running_tis:
+        raise AirflowException("TaskInstances still running")
+    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
+    if dag is None:
+        raise DagNotFound(f"Dag id {dag_id} not found")
+
+    # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
+    # There may be a lag, so explicitly removes serialized DAG here.
+    if SerializedDagModel.has_dag(dag_id=dag_id, session=session):
+        SerializedDagModel.remove_dag(dag_id=dag_id, session=session)
+
+    count = 0
+
+    for model in get_sqla_model_classes():
+        if hasattr(model, "dag_id"):
+            if keep_records_in_log and model.__name__ == 'Log':
+                continue
+            cond = or_(model.dag_id == dag_id, model.dag_id.like(dag_id + ".%"))
+            count += session.query(model).filter(cond).delete(synchronize_session='fetch')
+    if dag.is_subdag:
+        parent_dag_id, task_id = dag_id.rsplit(".", 1)
+        for model in TaskFail, models.TaskInstance:
+            count += (
+                session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete()
+            )
+
+    # Delete entries in Import Errors table for a deleted DAG
+    # This handles the case when the dag_id is changed in the file
+    session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
+        synchronize_session='fetch'
+    )
+
+    return count
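As a usage sketch of the relocated helper, with a placeholder dag id and an initialized metadata database assumed:

    from airflow.api.common.delete_dag import delete_dag
    from airflow.exceptions import DagNotFound

    try:
        # keep_records_in_log=True (the default) preserves Log rows for auditing
        count = delete_dag('my_dag', keep_records_in_log=True)
        print(f"Removed {count} record(s)")
    except DagNotFound:
        print("Nothing to delete")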
diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py
index 44e54e3738349..36bf7dd8c46a7 100644
--- a/airflow/api/common/experimental/delete_dag.py
+++ b/airflow/api/common/experimental/delete_dag.py
@@ -15,68 +15,12 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Delete DAGs APIs."""
-import logging
+import warnings
 
-from sqlalchemy import or_
+from airflow.api.common.delete_dag import *  # noqa
 
-from airflow import models
-from airflow.exceptions import AirflowException, DagNotFound
-from airflow.models import DagModel, TaskFail
-from airflow.models.serialized_dag import SerializedDagModel
-from airflow.utils.session import provide_session
-from airflow.utils.state import State
-
-log = logging.getLogger(__name__)
-
-
-@provide_session
-def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int:
-    """
-    :param dag_id: the dag_id of the DAG to delete
-    :param keep_records_in_log: whether keep records of the given dag_id
-        in the Log table in the backend database (for reasons like auditing).
-        The default value is True.
-    :param session: session used
-    :return count of deleted dags
-    """
-    log.info("Deleting DAG: %s", dag_id)
-    running_tis = (
-        session.query(models.TaskInstance.state)
-        .filter(models.TaskInstance.dag_id == dag_id)
-        .filter(models.TaskInstance.state == State.RUNNING)
-        .first()
-    )
-    if running_tis:
-        raise AirflowException("TaskInstances still running")
-    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
-    if dag is None:
-        raise DagNotFound(f"Dag id {dag_id} not found")
-
-    # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
-    # There may be a lag, so explicitly removes serialized DAG here.
-    if SerializedDagModel.has_dag(dag_id=dag_id, session=session):
-        SerializedDagModel.remove_dag(dag_id=dag_id, session=session)
-
-    count = 0
-
-    for model in models.base.Base._decl_class_registry.values():
-        if hasattr(model, "dag_id"):
-            if keep_records_in_log and model.__name__ == 'Log':
-                continue
-            cond = or_(model.dag_id == dag_id, model.dag_id.like(dag_id + ".%"))
-            count += session.query(model).filter(cond).delete(synchronize_session='fetch')
-    if dag.is_subdag:
-        parent_dag_id, task_id = dag_id.rsplit(".", 1)
-        for model in TaskFail, models.TaskInstance:
-            count += (
-                session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete()
-            )
-
-    # Delete entries in Import Errors table for a deleted DAG
-    # This handles the case when the dag_id is changed in the file
-    session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
-        synchronize_session='fetch'
-    )
-
-    return count
+warnings.warn(
+    "This module is deprecated. Please use `airflow.api.common.delete_dag` instead.",
+    DeprecationWarning,
+    stacklevel=2,
+)
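The shim pattern above re-exports everything via the star import, so old import paths keep resolving while a DeprecationWarning fires on import. Roughly, a caller would observe the following, assuming the shim has not already been imported in the process (Python caches modules, so the warning fires only once):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        from airflow.api.common.experimental import delete_dag  # old path still works

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)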
diff --git a/airflow/api/common/experimental/get_code.py b/airflow/api/common/experimental/get_code.py
index 79b0b9f492654..1a1fb621dbe48 100644
--- a/airflow/api/common/experimental/get_code.py
+++ b/airflow/api/common/experimental/get_code.py
@@ -16,11 +16,14 @@
 # specific language governing permissions and limitations
 # under the License.
 """Get code APIs."""
+from deprecated import deprecated
+
 from airflow.api.common.experimental import check_and_get_dag
 from airflow.exceptions import AirflowException, DagCodeNotFound
 from airflow.models.dagcode import DagCode
 
 
+@deprecated(reason="Use DagCode().get_code_by_fileloc() instead", version="2.2.3")
 def get_code(dag_id: str) -> str:
     """Return python code of a given dag_id.
 
diff --git a/airflow/api/common/experimental/get_dag_run_state.py b/airflow/api/common/experimental/get_dag_run_state.py
index ca71a9afb3853..b2dedd5113ae9 100644
--- a/airflow/api/common/experimental/get_dag_run_state.py
+++ b/airflow/api/common/experimental/get_dag_run_state.py
@@ -19,9 +19,12 @@
 from datetime import datetime
 from typing import Dict
 
+from deprecated import deprecated
+
 from airflow.api.common.experimental import check_and_get_dag, check_and_get_dagrun
 
 
+@deprecated(reason="Use DagRun().get_state() instead", version="2.2.3")
 def get_dag_run_state(dag_id: str, execution_date: datetime) -> Dict[str, str]:
     """Return the Dag Run state identified by the given dag_id and execution_date.
 
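Callers migrating off these two helpers would use the model-level replacements named in the decorators; a hedged sketch in which the dag id and date are placeholders and an initialized metadata database is assumed:

    from airflow.models.dagbag import DagBag
    from airflow.models.dagcode import DagCode
    from airflow.models.dagrun import DagRun
    from airflow.utils import timezone

    # 'my_dag' and the date below are placeholders
    dag = DagBag(read_dags_from_db=True).get_dag('my_dag')
    source = DagCode.get_code_by_fileloc(dag.fileloc)  # replaces get_code('my_dag')

    runs = DagRun.find(dag_id='my_dag', execution_date=timezone.datetime(2021, 12, 16))
    state = runs[0].get_state()  # replaces get_dag_run_state(...)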
"""Task APIs..""" +from deprecated import deprecated + from airflow.api.common.experimental import check_and_get_dag from airflow.models import TaskInstance +@deprecated(reason="Use DAG().get_task", version="2.2.3") def get_task(dag_id: str, task_id: str) -> TaskInstance: """Return the task object identified by the given dag_id and task_id.""" dag = check_and_get_dag(dag_id, task_id) diff --git a/airflow/api/common/experimental/get_task_instance.py b/airflow/api/common/experimental/get_task_instance.py index f3ca1cf2f6380..137f8a3aef9e7 100644 --- a/airflow/api/common/experimental/get_task_instance.py +++ b/airflow/api/common/experimental/get_task_instance.py @@ -18,11 +18,14 @@ """Task Instance APIs.""" from datetime import datetime +from deprecated import deprecated + from airflow.api.common.experimental import check_and_get_dag, check_and_get_dagrun from airflow.exceptions import TaskInstanceNotFound from airflow.models import TaskInstance +@deprecated(version="2.2.3", reason="Use DagRun.get_task_instance instead") def get_task_instance(dag_id: str, task_id: str, execution_date: datetime) -> TaskInstance: """Return the task instance identified by the given dag_id, task_id and execution_date.""" dag = check_and_get_dag(dag_id, task_id) diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py index 30950ea0026ee..0b9c3a5d4903b 100644 --- a/airflow/api/common/experimental/pool.py +++ b/airflow/api/common/experimental/pool.py @@ -16,11 +16,14 @@ # specific language governing permissions and limitations # under the License. """Pool APIs.""" +from deprecated import deprecated + from airflow.exceptions import AirflowBadRequest, PoolNotFound from airflow.models import Pool from airflow.utils.session import provide_session +@deprecated(reason="Use Pool.get_pool() instead", version="2.2.3") @provide_session def get_pool(name, session=None): """Get pool by a given name.""" @@ -34,12 +37,14 @@ def get_pool(name, session=None): return pool +@deprecated(reason="Use Pool.get_pools() instead", version="2.2.3") @provide_session def get_pools(session=None): """Get all pools.""" return session.query(Pool).all() +@deprecated(reason="Use Pool.create_pool() instead", version="2.2.3") @provide_session def create_pool(name, slots, description, session=None): """Create a pool with a given parameters.""" @@ -70,6 +75,7 @@ def create_pool(name, slots, description, session=None): return pool +@deprecated(reason="Use Pool.delete_pool() instead", version="2.2.3") @provide_session def delete_pool(name, session=None): """Delete pool by a given name.""" diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py index 38a873ce2e013..d52631281f534 100644 --- a/airflow/api/common/experimental/trigger_dag.py +++ b/airflow/api/common/experimental/trigger_dag.py @@ -15,114 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Triggering DAG runs APIs.""" -import json -from datetime import datetime -from typing import List, Optional, Union -from airflow.exceptions import DagNotFound, DagRunAlreadyExists -from airflow.models import DagBag, DagModel, DagRun -from airflow.utils import timezone -from airflow.utils.state import State -from airflow.utils.types import DagRunType +import warnings +from airflow.api.common.trigger_dag import * # noqa -def _trigger_dag( - dag_id: str, - dag_bag: DagBag, - run_id: Optional[str] = None, - conf: Optional[Union[dict, str]] = None, - execution_date: Optional[datetime] = None, - replace_microseconds: bool = True, -) -> List[DagRun]: - """Triggers DAG run. - - :param dag_id: DAG ID - :param dag_bag: DAG Bag model - :param run_id: ID of the dag_run - :param conf: configuration - :param execution_date: date of execution - :param replace_microseconds: whether microseconds should be zeroed - :return: list of triggered dags - """ - dag = dag_bag.get_dag(dag_id) # prefetch dag if it is stored serialized - - if dag_id not in dag_bag.dags: - raise DagNotFound(f"Dag id {dag_id} not found") - - execution_date = execution_date if execution_date else timezone.utcnow() - - if not timezone.is_localized(execution_date): - raise ValueError("The execution_date should be localized") - - if replace_microseconds: - execution_date = execution_date.replace(microsecond=0) - - if dag.default_args and 'start_date' in dag.default_args: - min_dag_start_date = dag.default_args["start_date"] - if min_dag_start_date and execution_date < min_dag_start_date: - raise ValueError( - "The execution_date [{}] should be >= start_date [{}] from DAG's default_args".format( - execution_date.isoformat(), min_dag_start_date.isoformat() - ) - ) - - run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date) - dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id) - - if dag_run: - raise DagRunAlreadyExists( - f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}" - ) - - run_conf = None - if conf: - run_conf = conf if isinstance(conf, dict) else json.loads(conf) - - dag_runs = [] - dags_to_run = [dag] + dag.subdags - for _dag in dags_to_run: - dag_run = _dag.create_dagrun( - run_id=run_id, - execution_date=execution_date, - state=State.QUEUED, - conf=run_conf, - external_trigger=True, - dag_hash=dag_bag.dags_hash.get(dag_id), - ) - dag_runs.append(dag_run) - - return dag_runs - - -def trigger_dag( - dag_id: str, - run_id: Optional[str] = None, - conf: Optional[Union[dict, str]] = None, - execution_date: Optional[datetime] = None, - replace_microseconds: bool = True, -) -> Optional[DagRun]: - """Triggers execution of DAG specified by dag_id - - :param dag_id: DAG ID - :param run_id: ID of the dag_run - :param conf: configuration - :param execution_date: date of execution - :param replace_microseconds: whether microseconds should be zeroed - :return: first dag run triggered - even if more than one Dag Runs were triggered or None - """ - dag_model = DagModel.get_current(dag_id) - if dag_model is None: - raise DagNotFound(f"Dag id {dag_id} not found in DagModel") - - dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True) - triggers = _trigger_dag( - dag_id=dag_id, - dag_bag=dagbag, - run_id=run_id, - conf=conf, - execution_date=execution_date, - replace_microseconds=replace_microseconds, - ) - - return triggers[0] if triggers else None +warnings.warn( + "This module is deprecated. 
diff --git a/airflow/api/common/trigger_dag.py b/airflow/api/common/trigger_dag.py
new file mode 100644
index 0000000000000..70bbb78312209
--- /dev/null
+++ b/airflow/api/common/trigger_dag.py
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Triggering DAG runs APIs."""
+import json
+from datetime import datetime
+from typing import List, Optional, Union
+
+from airflow.exceptions import DagNotFound, DagRunAlreadyExists
+from airflow.models import DagBag, DagModel, DagRun
+from airflow.utils import timezone
+from airflow.utils.state import State
+from airflow.utils.types import DagRunType
+
+
+def _trigger_dag(
+    dag_id: str,
+    dag_bag: DagBag,
+    run_id: Optional[str] = None,
+    conf: Optional[Union[dict, str]] = None,
+    execution_date: Optional[datetime] = None,
+    replace_microseconds: bool = True,
+) -> List[DagRun]:
+    """Triggers DAG run.
+
+    :param dag_id: DAG ID
+    :param dag_bag: DAG Bag model
+    :param run_id: ID of the dag_run
+    :param conf: configuration
+    :param execution_date: date of execution
+    :param replace_microseconds: whether microseconds should be zeroed
+    :return: list of triggered dags
+    """
+    dag = dag_bag.get_dag(dag_id)  # prefetch dag if it is stored serialized
+
+    if dag_id not in dag_bag.dags:
+        raise DagNotFound(f"Dag id {dag_id} not found")
+
+    execution_date = execution_date if execution_date else timezone.utcnow()
+
+    if not timezone.is_localized(execution_date):
+        raise ValueError("The execution_date should be localized")
+
+    if replace_microseconds:
+        execution_date = execution_date.replace(microsecond=0)
+
+    if dag.default_args and 'start_date' in dag.default_args:
+        min_dag_start_date = dag.default_args["start_date"]
+        if min_dag_start_date and execution_date < min_dag_start_date:
+            raise ValueError(
+                f"The execution_date [{execution_date.isoformat()}] should be >= start_date "
+                f"[{min_dag_start_date.isoformat()}] from DAG's default_args"
+            )
+
+    run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
+    dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id)
+
+    if dag_run:
+        raise DagRunAlreadyExists(
+            f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}"
+        )
+
+    run_conf = None
+    if conf:
+        run_conf = conf if isinstance(conf, dict) else json.loads(conf)
+
+    dag_runs = []
+    dags_to_run = [dag] + dag.subdags
+    for _dag in dags_to_run:
+        dag_run = _dag.create_dagrun(
+            run_id=run_id,
+            execution_date=execution_date,
+            state=State.QUEUED,
+            conf=run_conf,
+            external_trigger=True,
+            dag_hash=dag_bag.dags_hash.get(dag_id),
+        )
+        dag_runs.append(dag_run)
+
+    return dag_runs
+
+
+def trigger_dag(
+    dag_id: str,
+    run_id: Optional[str] = None,
+    conf: Optional[Union[dict, str]] = None,
+    execution_date: Optional[datetime] = None,
+    replace_microseconds: bool = True,
+) -> Optional[DagRun]:
+    """Triggers execution of DAG specified by dag_id
+
+    :param dag_id: DAG ID
+    :param run_id: ID of the dag_run
+    :param conf: configuration
+    :param execution_date: date of execution
+    :param replace_microseconds: whether microseconds should be zeroed
+    :return: first dag run triggered - even if more than one Dag Runs were triggered or None
+    """
+    dag_model = DagModel.get_current(dag_id)
+    if dag_model is None:
+        raise DagNotFound(f"Dag id {dag_id} not found in DagModel")
+
+    dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
+    triggers = _trigger_dag(
+        dag_id=dag_id,
+        dag_bag=dagbag,
+        run_id=run_id,
+        conf=conf,
+        execution_date=execution_date,
+        replace_microseconds=replace_microseconds,
+    )
+
+    return triggers[0] if triggers else None
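A usage sketch for the relocated `trigger_dag`; the dag id and conf are placeholders, and the DAG must already be registered in the metadata database:

    from airflow.api.common.trigger_dag import trigger_dag
    from airflow.utils import timezone

    dag_run = trigger_dag(
        dag_id='my_dag',                   # placeholder
        run_id=None,                       # generated as manual__<execution_date> when omitted
        conf={'param': 'value'},           # a dict or a JSON string
        execution_date=timezone.utcnow(),  # must be timezone-aware
    )
    print(dag_run.run_id if dag_run else "no run created")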
diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py
index c164fccc37dbf..286b191601caf 100644
--- a/airflow/api_connexion/endpoints/dag_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_endpoint.py
@@ -110,13 +110,10 @@ def patch_dag(session, dag_id, update_mask=None):
 @provide_session
 def delete_dag(dag_id: str, session: Session):
     """Delete the specific DAG."""
-    # TODO: This function is shared with the /delete endpoint used by the web
-    # UI, so we're reusing it to simplify maintenance. Refactor the function to
-    # another place when the experimental/legacy API is removed.
-    from airflow.api.common.experimental import delete_dag
+    from airflow.api.common import delete_dag as delete_dag_module
 
     try:
-        delete_dag.delete_dag(dag_id, session=session)
+        delete_dag_module.delete_dag(dag_id, session=session)
     except DagNotFound:
         raise NotFound(f"Dag with id: '{dag_id}' not found")
     except AirflowException:
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 6f217c4b025a2..8ae88aabcd45f 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -21,11 +21,11 @@
 from sqlalchemy import Column, Integer, String, Text, func
 from sqlalchemy.orm.session import Session
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, PoolNotFound
 from airflow.models.base import Base
 from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.typing_compat import TypedDict
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import nowait, with_row_locks
 from airflow.utils.state import State
 
@@ -57,7 +57,13 @@ def __repr__(self):
 
     @staticmethod
     @provide_session
-    def get_pool(pool_name, session: Session = None):
+    def get_pools(session: Session = NEW_SESSION):
+        """Get all pools."""
+        return session.query(Pool).all()
+
+    @staticmethod
+    @provide_session
+    def get_pool(pool_name: str, session: Session = NEW_SESSION):
         """
         Get the Pool with specific pool name from the Pools.
 
@@ -69,7 +75,7 @@ def get_pool(pool_name, session: Session = None):
 
     @staticmethod
     @provide_session
-    def get_default_pool(session: Session = None):
+    def get_default_pool(session: Session = NEW_SESSION):
         """
         Get the Pool of the default_pool from the Pools.
 
@@ -78,12 +84,46 @@ def get_default_pool(session: Session = None):
         """
         return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)
 
+    @staticmethod
+    @provide_session
+    def create_or_update_pool(name: str, slots: int, description: str, session: Session = NEW_SESSION):
+        """Create a pool with given parameters or update it if it already exists."""
+        if not name:
+            return
+        pool = session.query(Pool).filter_by(pool=name).first()
+        if pool is None:
+            pool = Pool(pool=name, slots=slots, description=description)
+            session.add(pool)
+        else:
+            pool.slots = slots
+            pool.description = description
+
+        session.commit()
+
+        return pool
+
+    @staticmethod
+    @provide_session
+    def delete_pool(name: str, session: Session = NEW_SESSION):
+        """Delete pool by a given name."""
+        if name == Pool.DEFAULT_POOL_NAME:
+            raise AirflowException("default_pool cannot be deleted")
+
+        pool = session.query(Pool).filter_by(pool=name).first()
+        if pool is None:
+            raise PoolNotFound(f"Pool '{name}' doesn't exist")
+
+        session.delete(pool)
+        session.commit()
+
+        return pool
+
     @staticmethod
     @provide_session
     def slots_stats(
         *,
         lock_rows: bool = False,
-        session: Session = None,
+        session: Session = NEW_SESSION,
     ) -> Dict[str, PoolStats]:
         """
         Get Pool stats (Number of Running, Queued, Open & Total tasks)
@@ -210,7 +250,7 @@ def queued_slots(self, session: Session):
         )
 
     @provide_session
-    def open_slots(self, session: Session) -> float:
+    def open_slots(self, session: Session = NEW_SESSION) -> float:
         """
         Get the number of slots open at the moment.
 
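The new model-level pool helpers compose as follows; a hedged sketch in which 'etl_pool' is a placeholder and an initialized metadata database is assumed:

    from airflow.models.pool import Pool

    # create, then resize in place; the pool name is the unique key
    Pool.create_or_update_pool(name='etl_pool', slots=4, description='ETL tasks')
    Pool.create_or_update_pool(name='etl_pool', slots=8, description='ETL tasks, resized')

    print([(p.pool, p.slots) for p in Pool.get_pools()])

    Pool.delete_pool(name='etl_pool')  # raises PoolNotFound for unknown names
    # Pool.delete_pool(Pool.DEFAULT_POOL_NAME) would raise AirflowException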
diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py
index 1e6cb7f6ab38f..421c7963d0d68 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -21,7 +21,7 @@
 import time
 from typing import Dict, List, Optional, Union
 
-from airflow.api.common.experimental.trigger_dag import trigger_dag
+from airflow.api.common.trigger_dag import trigger_dag
 from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
 from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun
 from airflow.utils import timezone
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index f35d1659f8cb9..023f482790d00 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -991,3 +991,18 @@ def check(session=None):
     """
     session.execute('select 1 as is_alive;')
     log.info("Connection successful.")
+
+
+def get_sqla_model_classes():
+    """
+    Get all SQLAlchemy class mappers.
+
+    SQLAlchemy < 1.4 does not support registry.mappers so we use
+    try/except to handle it.
+    """
+    from airflow.models.base import Base
+
+    try:
+        return [mapper.class_ for mapper in Base.registry.mappers]
+    except AttributeError:
+        return Base._decl_class_registry.values()
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 2182a1706aeec..f2642a73f1f0e 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1607,7 +1607,7 @@ def run(self):
     @action_logging
     def delete(self):
         """Deletes DAG."""
-        from airflow.api.common.experimental import delete_dag
+        from airflow.api.common import delete_dag
         from airflow.exceptions import DagNotFound
 
         dag_id = request.values.get('dag_id')
diff --git a/setup.cfg b/setup.cfg
index b83ef9be02826..c3cce1c0ac0c5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -95,6 +95,7 @@ install_requires =
     croniter>=0.3.17
     cryptography>=0.9.3
     dataclasses;python_version<"3.7"
+    deprecated>=1.2.13
     dill>=0.2.2, <0.4
     # Sphinx RTD theme 0.5.2. introduced limitation to docutils to account for some docutils markup
     # change:
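The version-tolerant mapper lookup in `get_sqla_model_classes` can be illustrated standalone with a toy declarative base (not Airflow code; the import path shown assumes SQLAlchemy >= 1.4):

    from sqlalchemy import Column, Integer
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Widget(Base):
        __tablename__ = 'widget'
        id = Column(Integer, primary_key=True)

    def model_classes(base):
        try:
            # SQLAlchemy >= 1.4 exposes a public registry of mappers
            return [mapper.class_ for mapper in base.registry.mappers]
        except AttributeError:
            # older versions only offer the private declarative class registry
            return [c for c in base._decl_class_registry.values() if hasattr(c, '__tablename__')]

    print(model_classes(Base))  # prints a list containing the Widget class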
diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py
index a2af8ca245e6b..9f574e4fc657a 100644
--- a/tests/api/client/test_local_client.py
+++ b/tests/api/client/test_local_client.py
@@ -17,6 +17,8 @@
 # under the License.
 
 import json
+import random
+import string
 import unittest
 from unittest.mock import ANY, patch
 
@@ -25,7 +27,7 @@
 
 from airflow.api.client.local_client import Client
 from airflow.example_dags import example_bash_operator
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowBadRequest, AirflowException, PoolNotFound
 from airflow.models import DAG, DagBag, DagModel, DagRun, Pool
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -133,6 +135,10 @@ def test_get_pool(self):
         pool = self.client.get_pool(name='foo')
         assert pool == ('foo', 1, '')
 
+    def test_get_pool_non_existing_raises(self):
+        with pytest.raises(PoolNotFound):
+            self.client.get_pool(name='foo')
+
     def test_get_pools(self):
         self.client.create_pool(name='foo1', slots=1, description='')
         self.client.create_pool(name='foo2', slots=2, description='')
@@ -145,6 +151,26 @@ def test_create_pool(self):
         with create_session() as session:
             assert session.query(Pool).count() == 2
 
+    def test_create_pool_bad_slots(self):
+        with pytest.raises(AirflowBadRequest, match="^Bad value for `slots`: foo$"):
+            self.client.create_pool(
+                name='foo',
+                slots='foo',
+                description='',
+            )
+
+    def test_create_pool_name_too_long(self):
+        long_name = ''.join(random.choices(string.ascii_lowercase, k=300))
+        pool_name_length = Pool.pool.property.columns[0].type.length
+        with pytest.raises(
+            AirflowBadRequest, match=f"^pool name cannot be more than {pool_name_length} characters"
+        ):
+            self.client.create_pool(
+                name=long_name,
+                slots=5,
+                description='',
+            )
+
     def test_delete_pool(self):
         self.client.create_pool(name='foo', slots=1, description='')
         with create_session() as session:
@@ -152,3 +178,6 @@ def test_delete_pool(self):
         self.client.delete_pool(name='foo')
         with create_session() as session:
             assert session.query(Pool).count() == 1
+        for name in ('', ' '):
+            with pytest.raises(PoolNotFound, match=f"^Pool {name!r} doesn't exist$"):
+                Pool.delete_pool(name=name)
diff --git a/tests/api/common/experimental/test_delete_dag.py b/tests/api/common/test_delete_dag.py
similarity index 99%
rename from tests/api/common/experimental/test_delete_dag.py
rename to tests/api/common/test_delete_dag.py
index 5984cd2b14f0f..0eb058a18337a 100644
--- a/tests/api/common/experimental/test_delete_dag.py
+++ b/tests/api/common/test_delete_dag.py
@@ -20,7 +20,7 @@
 import pytest
 
 from airflow import models
-from airflow.api.common.experimental.delete_dag import delete_dag
+from airflow.api.common.delete_dag import delete_dag
 from airflow.exceptions import AirflowException, DagNotFound
 from airflow.operators.dummy import DummyOperator
 from airflow.utils.dates import days_ago
diff --git a/tests/api/common/experimental/test_trigger_dag.py b/tests/api/common/test_trigger_dag.py
similarity index 93%
rename from tests/api/common/experimental/test_trigger_dag.py
rename to tests/api/common/test_trigger_dag.py
index 2f164468d085f..f79d413ed5eae 100644
--- a/tests/api/common/experimental/test_trigger_dag.py
+++ b/tests/api/common/test_trigger_dag.py
@@ -22,7 +22,7 @@
 import pytest
 from parameterized import parameterized
 
-from airflow.api.common.experimental.trigger_dag import _trigger_dag
+from airflow.api.common.trigger_dag import _trigger_dag
 from airflow.exceptions import AirflowException
 from airflow.models import DAG, DagRun
 from airflow.utils import timezone
@@ -42,7 +42,7 @@ def test_trigger_dag_dag_not_found(self, dag_bag_mock):
         with pytest.raises(AirflowException):
             _trigger_dag('dag_not_found', dag_bag_mock)
 
-    @mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
+    @mock.patch('airflow.api.common.trigger_dag.DagRun', spec=DagRun)
     @mock.patch('airflow.models.DagBag')
     def test_trigger_dag_dag_run_exist(self, dag_bag_mock, dag_run_mock):
         dag_id = "dag_run_exist"
@@ -54,7 +54,7 @@ def test_trigger_dag_dag_run_exist(self, dag_bag_mock, dag_run_mock):
         _trigger_dag(dag_id, dag_bag_mock)
 
     @mock.patch('airflow.models.DAG')
-    @mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
+    @mock.patch('airflow.api.common.trigger_dag.DagRun', spec=DagRun)
     @mock.patch('airflow.models.DagBag')
     def test_trigger_dag_include_subdags(self, dag_bag_mock, dag_run_mock, dag_mock):
         dag_id = "trigger_dag"
@@ -70,7 +70,7 @@ def test_trigger_dag_include_subdags(self, dag_bag_mock, dag_run_mock, dag_mock)
         assert 3 == len(triggers)
 
     @mock.patch('airflow.models.DAG')
-    @mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
+    @mock.patch('airflow.api.common.trigger_dag.DagRun', spec=DagRun)
     @mock.patch('airflow.models.DagBag')
     def test_trigger_dag_include_nested_subdags(self, dag_bag_mock, dag_run_mock, dag_mock):
         dag_id = "trigger_dag"
diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py
index 00fe14039d7e3..95e585efa5974 100644
--- a/tests/models/test_pool.py
+++ b/tests/models/test_pool.py
@@ -16,11 +16,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 from airflow import settings
+from airflow.exceptions import AirflowException, PoolNotFound
 from airflow.models.pool import Pool
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.operators.dummy import DummyOperator
 from airflow.utils import timezone
+from airflow.utils.session import create_session
 from airflow.utils.state import State
 from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
 
@@ -28,6 +32,10 @@
 
 
 class TestPool:
+
+    USER_POOL_COUNT = 2
+    TOTAL_POOL_COUNT = USER_POOL_COUNT + 1  # including default_pool
+
     @staticmethod
     def clean_db():
         clear_db_dags()
@@ -36,6 +44,20 @@ def clean_db(self):
 
     def setup_method(self):
         self.clean_db()
+        self.pools = []
+
+    def add_pools(self):
+        self.pools = [Pool.get_default_pool()]
+        for i in range(self.USER_POOL_COUNT):
+            name = f'experimental_{i + 1}'
+            pool = Pool(
+                pool=name,
+                slots=i,
+                description=name,
+            )
+            self.pools.append(pool)
+        with create_session() as session:
+            session.add_all(self.pools)
 
     def teardown_method(self):
         self.clean_db()
@@ -149,3 +171,52 @@ def test_default_pool_open_slots(self, dag_maker):
                 "running": 1,
             }
         } == Pool.slots_stats()
+
+    def test_get_pool(self):
+        self.add_pools()
+        pool = Pool.get_pool(pool_name=self.pools[0].pool)
+        assert pool.pool == self.pools[0].pool
+
+    def test_get_pool_non_existing(self):
+        self.add_pools()
+        assert not Pool.get_pool(pool_name='test')
+
+    def test_get_pool_bad_name(self):
+        for name in ('', ' '):
+            assert not Pool.get_pool(pool_name=name)
+
+    def test_get_pools(self):
+        self.add_pools()
+        pools = sorted(Pool.get_pools(), key=lambda p: p.pool)
+        assert pools[0].pool == self.pools[0].pool
+        assert pools[1].pool == self.pools[1].pool
+
+    def test_create_pool(self, session):
+        self.add_pools()
+        pool = Pool.create_or_update_pool(name='foo', slots=5, description='')
+        assert pool.pool == 'foo'
+        assert pool.slots == 5
+        assert pool.description == ''
+        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT + 1
+
+    def test_create_pool_existing(self, session):
+        self.add_pools()
+        pool = Pool.create_or_update_pool(name=self.pools[0].pool, slots=5, description='')
+        assert pool.pool == self.pools[0].pool
+        assert pool.slots == 5
+        assert pool.description == ''
+        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT
+
+    def test_delete_pool(self, session):
+        self.add_pools()
+        pool = Pool.delete_pool(name=self.pools[-1].pool)
+        assert pool.pool == self.pools[-1].pool
+        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT - 1
+
+    def test_delete_pool_non_existing(self):
+        with pytest.raises(PoolNotFound, match="^Pool 'test' doesn't exist$"):
+            Pool.delete_pool(name='test')
+
+    def test_delete_default_pool_not_allowed(self):
+        with pytest.raises(AirflowException, match="^default_pool cannot be deleted$"):
+            Pool.delete_pool(Pool.DEFAULT_POOL_NAME)