Skip to content

Commit

Permalink
Sanitize DagRun.run_id and allow flexibility
Browse files Browse the repository at this point in the history
This commit sanitizes the DagRun.run_id parameter by introducing a configurable option.
Users now have the ability to select a specific run_id pattern for their runs,
ensuring stricter control over the values used. This update does not affect the default run_id
generated by the scheduler for scheduled DAG runs, or for DAG runs triggered without
modifying the run_id parameter on the run configuration page.
The configuration flexibility empowers users to align the run_id pattern with their specific requirements.
  • Loading branch information
ephraimbuddy committed Jul 2, 2023
1 parent 2811ba7 commit 155cfe3
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 36 deletions.
9 changes: 9 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2555,6 +2555,15 @@ scheduler:
type: float
example: ~
default: "120.0"
allowed_run_id_pattern:
description: |
The run_id pattern used to verify the validity of user input to the run_id parameter when
triggering a DAG. This pattern cannot change the pattern used by scheduler to generate run_id
for scheduled DAG runs or DAG runs triggered without changing the run_id parameter.
version_added: 2.6.3
type: string
example: ~
default: "^[A-Za-z0-9_.~:+-]+$"
triggerer:
description: ~
options:
Expand Down
5 changes: 5 additions & 0 deletions airflow/config_templates/default_airflow.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,11 @@ task_queued_timeout = 600.0
# longer than `[scheduler] task_queued_timeout`.
task_queued_timeout_check_interval = 120.0

# The run_id pattern used to verify the validity of user input to the run_id parameter when
# triggering a DAG. This pattern cannot change the pattern used by scheduler to generate run_id
# for scheduled DAG runs or DAG runs triggered without changing the run_id parameter.
allowed_run_id_pattern = ^[A-Za-z0-9_.~:+-]+$

[triggerer]
# How many triggers a single Triggerer will run at once, by default.
default_capacity = 1000
Expand Down
55 changes: 29 additions & 26 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
import airflow.templates
from airflow import settings, utils
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf, secrets_backend_list
from airflow.configuration import conf as airflow_conf, secrets_backend_list
from airflow.exceptions import (
AirflowDagInconsistent,
AirflowException,
Expand All @@ -96,7 +96,7 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import DagRun
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
from airflow.models.operator import Operator
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances
Expand Down Expand Up @@ -422,13 +422,13 @@ def __init__(
user_defined_filters: dict | None = None,
default_args: dict | None = None,
concurrency: int | None = None,
max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"),
max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"),
max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"),
max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"),
dagrun_timeout: timedelta | None = None,
sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None,
default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(),
orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"),
catchup: bool = conf.getboolean("scheduler", "catchup_by_default"),
default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(),
orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"),
catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"),
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
Expand Down Expand Up @@ -2588,7 +2588,7 @@ def run(
mark_success=False,
local=False,
executor=None,
donot_pickle=conf.getboolean("core", "donot_pickle"),
donot_pickle=airflow_conf.getboolean("core", "donot_pickle"),
ignore_task_deps=False,
ignore_first_depends_on_past=True,
pool=None,
Expand Down Expand Up @@ -2826,13 +2826,14 @@ def create_dagrun(
"Creating DagRun needs either `run_id` or both `run_type` and `execution_date`"
)

if run_id and "/" in run_id:
warnings.warn(
"Using forward slash ('/') in a DAG run ID is deprecated. Note that this character "
"also makes the run impossible to retrieve via Airflow's REST API.",
RemovedInAirflow3Warning,
stacklevel=3,
)
regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")

if run_id and not re.match(RUN_ID_REGEX, run_id):
if not regex.strip() or not re.match(regex.strip(), run_id):
raise AirflowException(
f"The provided run ID '{run_id}' is invalid. It does not match either "
f"the configured pattern: '{regex}' or the built-in pattern: '{RUN_ID_REGEX}'"
)

# create a copy of params before validating
copied_params = copy.deepcopy(self.params)
Expand Down Expand Up @@ -3125,7 +3126,7 @@ def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION):
def get_default_view(self):
"""This is only there for backward compatible jinja2 templates."""
if self.default_view is None:
return conf.get("webserver", "dag_default_view").lower()
return airflow_conf.get("webserver", "dag_default_view").lower()
else:
return self.default_view

Expand Down Expand Up @@ -3342,7 +3343,7 @@ class DagModel(Base):
root_dag_id = Column(StringID())
# A DAG can be paused from the UI / DB
# Set this default value of is_paused based on a configuration value!
is_paused_at_creation = conf.getboolean("core", "dags_are_paused_at_creation")
is_paused_at_creation = airflow_conf.getboolean("core", "dags_are_paused_at_creation")
is_paused = Column(Boolean, default=is_paused_at_creation)
# Whether the DAG is a subdag
is_subdag = Column(Boolean, default=False)
Expand Down Expand Up @@ -3416,7 +3417,9 @@ class DagModel(Base):
"TaskOutletDatasetReference",
cascade="all, delete, delete-orphan",
)
NUM_DAGS_PER_DAGRUN_QUERY = conf.getint("scheduler", "max_dagruns_to_create_per_loop", fallback=10)
NUM_DAGS_PER_DAGRUN_QUERY = airflow_conf.getint(
"scheduler", "max_dagruns_to_create_per_loop", fallback=10
)

def __init__(self, concurrency=None, **kwargs):
super().__init__(**kwargs)
Expand All @@ -3429,10 +3432,10 @@ def __init__(self, concurrency=None, **kwargs):
)
self.max_active_tasks = concurrency
else:
self.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag")
self.max_active_tasks = airflow_conf.getint("core", "max_active_tasks_per_dag")

if self.max_active_runs is None:
self.max_active_runs = conf.getint("core", "max_active_runs_per_dag")
self.max_active_runs = airflow_conf.getint("core", "max_active_runs_per_dag")

if self.has_task_concurrency_limits is None:
# Be safe -- this will be updated later once the DAG is parsed
Expand Down Expand Up @@ -3510,7 +3513,7 @@ def get_default_view(self) -> str:
have a value.
"""
# This is for backwards-compatibility with old dags that don't have None as default_view
return self.default_view or conf.get_mandatory_value("webserver", "dag_default_view").lower()
return self.default_view or airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower()

@property
def safe_dag_id(self):
Expand Down Expand Up @@ -3699,13 +3702,13 @@ def dag(
user_defined_filters: dict | None = None,
default_args: dict | None = None,
concurrency: int | None = None,
max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"),
max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"),
max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"),
max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"),
dagrun_timeout: timedelta | None = None,
sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None,
default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(),
orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"),
catchup: bool = conf.getboolean("scheduler", "catchup_by_default"),
default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(),
orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"),
catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"),
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
Expand Down
16 changes: 15 additions & 1 deletion airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload

import re2 as re
from sqlalchemy import (
Boolean,
Column,
Expand All @@ -44,7 +45,7 @@
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, Session, declared_attr, joinedload, relationship, synonym
from sqlalchemy.orm import Query, Session, declared_attr, joinedload, relationship, synonym, validates
from sqlalchemy.sql.expression import false, select, true

from airflow import settings
Expand Down Expand Up @@ -76,6 +77,8 @@
CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
TaskCreator = Callable[[Operator, Iterable[int]], CreatedTasks]

# Built-in run_id pattern: a "manual", "scheduled" or "dataset_triggered" prefix, a double
# underscore, then a second-precision ISO-8601 timestamp carrying a literal "+00:00" offset.
# A run_id matching this pattern is accepted even when it does not match the user-configured
# ``[scheduler] allowed_run_id_pattern``.
RUN_ID_REGEX = r"^(?:manual|scheduled|dataset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$"


class TISchedulingDecision(NamedTuple):
"""Type of return for DagRun.task_instance_scheduling_decisions."""
Expand Down Expand Up @@ -240,6 +243,17 @@ def __repr__(self):
external_trigger=self.external_trigger,
)

@validates("run_id")
def validate_run_id(self, key: str, run_id: str) -> str | None:
    """Validate a ``run_id`` value when it is assigned to the model.

    A run_id is accepted when it matches the built-in ``RUN_ID_REGEX`` or the
    pattern configured under ``[scheduler] allowed_run_id_pattern``.

    :param key: Name of the attribute being validated (``"run_id"``); unused.
    :param run_id: The value being assigned.
    :return: ``run_id`` unchanged when valid, or ``None`` when it is empty.
    :raises ValueError: If ``run_id`` matches neither pattern.
    """
    if not run_id:
        return None
    regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")
    if re.match(RUN_ID_REGEX, run_id):
        return run_id
    # Surrounding whitespace is not part of the pattern, and a blank setting means
    # "no custom pattern". Matching against an empty regex would accept every run_id,
    # which would disagree with the check performed in DAG.create_dagrun.
    custom_pattern = regex.strip()
    if custom_pattern and re.match(custom_pattern, run_id):
        return run_id
    raise ValueError(
        f"The run_id provided '{run_id}' does not match the pattern '{regex}' or '{RUN_ID_REGEX}'"
    )

@property
def stats_tags(self) -> dict[str, str]:
return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})
Expand Down
32 changes: 23 additions & 9 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.dag import DAG, get_dataset_triggered_next_run_info
from airflow.models.dagcode import DagCode
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.dagrun import RUN_ID_REGEX, DagRun, DagRunType
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
Expand Down Expand Up @@ -1975,7 +1975,7 @@ def delete(self):
@provide_session
def trigger(self, dag_id: str, session: Session = NEW_SESSION):
"""Triggers DAG Run."""
run_id = request.values.get("run_id", "")
run_id = request.values.get("run_id", "").replace(" ", "+")
origin = get_safe_url(request.values.get("origin"))
unpause = request.values.get("unpause")
request_conf = request.values.get("conf")
Expand Down Expand Up @@ -2096,13 +2096,27 @@ def trigger(self, dag_id: str, session: Session = NEW_SESSION):
flash(message, "error")
return redirect(origin)

# Flash a warning when slash is used, but still allow it to continue on.
if run_id and "/" in run_id:
flash(
"Using forward slash ('/') in a DAG run ID is deprecated. Note that this character "
"also makes the run impossible to retrieve via Airflow's REST API.",
"warning",
)
regex = conf.get("scheduler", "allowed_run_id_pattern")
if run_id and not re.match(RUN_ID_REGEX, run_id):
if not regex.strip() or not re.match(regex.strip(), run_id):
flash(
f"The provided run ID '{run_id}' is invalid. It does not match either "
f"the configured pattern: '{regex}' or the built-in pattern: '{RUN_ID_REGEX}'",
"error",
)

form = DateTimeForm(data={"execution_date": execution_date})
return self.render_template(
"airflow/trigger.html",
form_fields=form_fields,
dag=dag,
dag_id=dag_id,
origin=origin,
conf=request_conf,
form=form,
is_dag_run_conf_overrides_params=is_dag_run_conf_overrides_params,
recent_confs=recent_confs,
)

run_conf = {}
if request_conf:
Expand Down
32 changes: 32 additions & 0 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.decorators import setup, task, task_group, teardown
from airflow.exceptions import AirflowException
from airflow.models import (
DAG,
DagBag,
Expand All @@ -54,6 +55,7 @@
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE as _DEFAULT_DATE
from tests.test_utils import db
from tests.test_utils.config import conf_vars
from tests.test_utils.mock_operators import MockOperator

DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
Expand Down Expand Up @@ -2541,3 +2543,33 @@ def make_task(task_id, dag):
tis = dr.task_instance_scheduling_decisions(session).tis
tis_for_state = {x.task_id for x in dr._tis_for_dagrun_state(dag=dag, tis=tis)}
assert tis_for_state == expected


@pytest.mark.parametrize(
    "pattern, run_id, result",
    [
        # User-supplied run_ids are checked against the configured pattern.
        ("^[A-Z]", "ABC", True),
        ("^[A-Z]", "abc", False),
        ("^[0-9]", "123", True),
        # Internally generated run_ids stay valid whatever the user configures,
        # including an empty pattern ...
        ("", "scheduled__2023-01-01T00:00:00+00:00", True),
        ("", "manual__2023-01-01T00:00:00+00:00", True),
        ("", "dataset_triggered__2023-01-01T00:00:00+00:00", True),
        # ... while look-alikes that miss the built-in format are rejected.
        ("", "scheduled_2023-01-01T00", False),
        ("", "manual_2023-01-01T00", False),
        ("", "dataset_triggered_2023-01-01T00", False),
        # A configured pattern that would not match them must not block them either.
        ("^[0-9]", "scheduled__2023-01-01T00:00:00+00:00", True),
        ("^[0-9]", "manual__2023-01-01T00:00:00+00:00", True),
        ("^[a-z]", "dataset_triggered__2023-01-01T00:00:00+00:00", True),
    ],
)
def test_dag_run_id_config(session, dag_maker, pattern, run_id, result):
    """Check ``create_dagrun`` against the configured ``allowed_run_id_pattern``.

    ``result`` states whether creating a DAG run with ``run_id`` is expected to
    succeed (``True``) or to raise ``AirflowException`` (``False``).
    """
    config_override = {("scheduler", "allowed_run_id_pattern"): pattern}
    with conf_vars(config_override):
        # An empty DAG is enough; only run creation is under test.
        with dag_maker():
            pass
        if not result:
            with pytest.raises(AirflowException):
                dag_maker.create_dagrun(run_id=run_id)
        else:
            dag_maker.create_dagrun(run_id=run_id)
30 changes: 30 additions & 0 deletions tests/www/views/test_views_trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.utils.session import create_session
from airflow.utils.types import DagRunType
from tests.test_utils.api_connexion_utils import create_test_client
from tests.test_utils.config import conf_vars
from tests.test_utils.www import check_content_in_response


Expand Down Expand Up @@ -287,3 +288,32 @@ def test_trigger_dag_params_array_value_none_render(admin_client, dag_maker, ses
f'<textarea style="display: none;" id="json_start" name="json_start">{expected_dag_conf}</textarea>',
resp,
)


@pytest.mark.parametrize(
    "pattern, run_id, result",
    [
        # User-supplied run_ids are checked against the configured pattern.
        ("^[A-Z]", "ABC", True),
        ("^[A-Z]", "abc", False),
        ("^[0-9]", "123", True),
        # Internally generated run_ids must not be affected by user configuration.
        # Only the manual__ prefix is exercised here because manually triggered DAGs
        # may not use a run_id that starts with scheduled__.
        ("", "manual__2023-01-01T00:00:00+00:00", True),
        ("", "scheduled_2023-01-01T00", False),
        ("", "manual_2023-01-01T00", False),
        ("", "dataset_triggered_2023-01-01T00", False),
        ("^[0-9]", "manual__2023-01-01T00:00:00+00:00", True),
        ("^[a-z]", "manual__2023-01-01T00:00:00+00:00", True),
    ],
)
def test_dag_run_id_pattern(session, admin_client, pattern, run_id, result):
    """Trigger a DAG through the web view and check whether a run was created.

    ``result`` states whether the trigger request with ``run_id`` is expected to
    create a manual DAG run (``True``) or to create nothing (``False``).
    """
    config_override = {("scheduler", "allowed_run_id_pattern"): pattern}
    with conf_vars(config_override):
        test_dag_id = "example_bash_operator"
        admin_client.post(f"dags/{test_dag_id}/trigger?&run_id={run_id}")
        matching_runs = session.query(DagRun).filter(DagRun.dag_id == test_dag_id)
        created_run = matching_runs.first()
        if not result:
            assert created_run is None
        else:
            assert created_run is not None
            assert created_run.run_type == DagRunType.MANUAL

0 comments on commit 155cfe3

Please sign in to comment.