Skip to content

Commit

Permalink
Use linear time regular expressions
Browse files Browse the repository at this point in the history
The standard regexp library can consume > O(n) in certain circumstances.
The re2 library does not have this issue.
  • Loading branch information
pierrejeambrun committed Jul 4, 2023
1 parent 575bf2f commit 353df22
Show file tree
Hide file tree
Showing 29 changed files with 169 additions and 152 deletions.
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,14 @@ repos:
language: python
pass_filenames: true
files: ^tests/.*\.py$
- id: check-usage-of-re2-over-re
language: pygrep
name: Use re2 over re
description: Use re2 module instead of re
entry: "^\\s*from re\\s|^\\s*import re\\s"
pass_filenames: true
files: \.py$
exclude: ^airflow/providers|^dev/.*\.py$|^scripts/.*\.py$|^tests/|^docker_tests/|^docs/.*\.py$|^airflow/utils/helpers.py$
## ADD MOST PRE-COMMITS ABOVE THAT LINE
# The below pre-commits are those requiring CI image to be built
- id: mypy-dev
Expand Down
2 changes: 2 additions & 0 deletions STATIC_CODE_CHECKS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-urlparse-usage-in-code | Don't use urlparse in code | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-usage-of-re2-over-re | Use re2 over re | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-xml | Check XML files with xmllint | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| codespell | Run codespell to check for common misspellings in files | |
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/provider_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

import re
import re2

from airflow.api_connexion import security
from airflow.api_connexion.schemas.provider_schema import (
Expand All @@ -30,7 +30,7 @@


def _remove_rst_syntax(value: str) -> str:
return re.sub("[`_<>]", "", value.strip(" \n."))
return re2.sub("[`_<>]", "", value.strip(" \n."))


def _provider_mapper(provider: ProviderInfo) -> Provider:
Expand Down
4 changes: 2 additions & 2 deletions airflow/cli/commands/provider_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Providers sub-commands."""
from __future__ import annotations

import re
import re2

from airflow.cli.simple_table import AirflowConsole
from airflow.providers_manager import ProvidersManager
Expand All @@ -27,7 +27,7 @@


def _remove_rst_syntax(value: str) -> str:
return re.sub("[`_<>]", "", value.strip(" \n."))
return re2.sub("[`_<>]", "", value.strip(" \n."))


@suppress_logs_and_warning
Expand Down
4 changes: 2 additions & 2 deletions airflow/cli/commands/user_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import json
import os
import random
import re
import string
from typing import Any

import re2
from marshmallow import Schema, fields, validate
from marshmallow.exceptions import ValidationError

Expand Down Expand Up @@ -164,7 +164,7 @@ def users_export(args):
# In the User model the first and last name fields have underscores,
# but the corresponding parameters in the CLI don't
def remove_underscores(s):
return re.sub("_", "", s)
return re2.sub("_", "", s)

users = [
{
Expand Down
7 changes: 4 additions & 3 deletions airflow/config_templates/default_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from __future__ import annotations

import logging
import re
import ssl

import re2

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException

Expand Down Expand Up @@ -88,7 +89,7 @@ def _broker_supports_visibility_timeout(url):
"ca_certs": conf.get("celery", "SSL_CACERT"),
"cert_reqs": ssl.CERT_REQUIRED,
}
elif broker_url and re.search("rediss?://|sentinel://", broker_url):
elif broker_url and re2.search("rediss?://|sentinel://", broker_url):
broker_use_ssl = {
"ssl_keyfile": conf.get("celery", "SSL_KEY"),
"ssl_certfile": conf.get("celery", "SSL_CERT"),
Expand All @@ -114,7 +115,7 @@ def _broker_supports_visibility_timeout(url):
f"all necessary certs and key ({e})."
)

if re.search("rediss?://|amqp://|rpc://", result_backend):
if re2.search("rediss?://|amqp://|rpc://", result_backend):
log.warning(
"You have configured a result_backend of %s, it is highly recommended "
"to use an alternative result_backend (i.e. a database).",
Expand Down
23 changes: 11 additions & 12 deletions airflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import multiprocessing
import os
import pathlib
import re
import shlex
import stat
import subprocess
Expand All @@ -36,10 +35,10 @@
from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError # type: ignore
from contextlib import contextmanager, suppress
from json.decoder import JSONDecodeError
from re import Pattern
from typing import IO, Any, Dict, Iterable, Set, Tuple, Union
from typing import IO, Any, Dict, Iterable, Pattern, Set, Tuple, Union
from urllib.parse import urlsplit

import re2
from typing_extensions import overload

from airflow.exceptions import AirflowConfigException
Expand All @@ -55,7 +54,7 @@
warnings.filterwarnings(action="default", category=DeprecationWarning, module="airflow")
warnings.filterwarnings(action="default", category=PendingDeprecationWarning, module="airflow")

_SQLITE3_VERSION_PATTERN = re.compile(r"(?P<version>^\d+(?:\.\d+)*)\D?.*$")
_SQLITE3_VERSION_PATTERN = re2.compile(r"(?P<version>^\d+(?:\.\d+)*)\D?.*$")

ConfigType = Union[str, int, float, bool]
ConfigOptionsDictType = Dict[str, ConfigType]
Expand Down Expand Up @@ -269,36 +268,36 @@ def inversed_deprecated_sections(self):
# about. Mapping of section -> setting -> { old, replace, by_version }
deprecated_values: dict[str, dict[str, tuple[Pattern, str, str]]] = {
"core": {
"hostname_callable": (re.compile(r":"), r".", "2.1"),
"hostname_callable": (re2.compile(r":"), r".", "2.1"),
},
"webserver": {
"navbar_color": (re.compile(r"\A#007A87\Z", re.IGNORECASE), "#fff", "2.1"),
"dag_default_view": (re.compile(r"^tree$"), "grid", "3.0"),
"navbar_color": (re2.compile(r"(?i)\A#007A87\z"), "#fff", "2.1"),
"dag_default_view": (re2.compile(r"^tree$"), "grid", "3.0"),
},
"email": {
"email_backend": (
re.compile(r"^airflow\.contrib\.utils\.sendgrid\.send_email$"),
re2.compile(r"^airflow\.contrib\.utils\.sendgrid\.send_email$"),
r"airflow.providers.sendgrid.utils.emailer.send_email",
"2.1",
),
},
"logging": {
"log_filename_template": (
re.compile(re.escape("{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log")),
re2.compile(re2.escape("{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log")),
"XX-set-after-default-config-loaded-XX",
"3.0",
),
},
"api": {
"auth_backends": (
re.compile(r"^airflow\.api\.auth\.backend\.deny_all$|^$"),
re2.compile(r"^airflow\.api\.auth\.backend\.deny_all$|^$"),
"airflow.api.auth.backend.session",
"3.0",
),
},
"elasticsearch": {
"log_id_template": (
re.compile("^" + re.escape("{dag_id}-{task_id}-{execution_date}-{try_number}") + "$"),
re2.compile("^" + re2.escape("{dag_id}-{task_id}-{execution_date}-{try_number}") + "$"),
"{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}",
"3.0",
)
Expand Down Expand Up @@ -425,7 +424,7 @@ def _upgrade_postgres_metastore_conn(self):
FutureWarning,
)
self.upgraded_values[(section, key)] = old_value
new_value = re.sub("^" + re.escape(f"{parsed.scheme}://"), f"{good_scheme}://", old_value)
new_value = re2.sub("^" + re2.escape(f"{parsed.scheme}://"), f"{good_scheme}://", old_value)
self._update_env_var(section=section, name=key, new_value=new_value)

# if the old value is set via env var, we need to wipe it
Expand Down
8 changes: 4 additions & 4 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import inspect
import re
import warnings
from functools import cached_property
from itertools import chain
Expand All @@ -38,6 +37,7 @@
)

import attr
import re2
import typing_extensions
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -144,15 +144,15 @@ def get_unique_task_id(
return task_id

def _find_id_suffixes(dag: DAG) -> Iterator[int]:
prefix = re.split(r"__\d+$", tg_task_id)[0]
prefix = re2.split(r"__\d+$", tg_task_id)[0]
for task_id in dag.task_ids:
match = re.match(rf"^{prefix}__(\d+)$", task_id)
match = re2.match(rf"^{prefix}__(\d+)$", task_id)
if match is None:
continue
yield int(match.group(1))
yield 0 # Default if there's no matching task ID.

core = re.split(r"__\d+$", task_id)[0]
core = re2.split(r"__\d+$", task_id)[0]
return f"{core}__{max(_find_id_suffixes(dag)) + 1}"


Expand Down
4 changes: 2 additions & 2 deletions airflow/kubernetes/pod_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
import datetime
import logging
import os
import re
import warnings
from functools import reduce

import re2
from dateutil import parser
from kubernetes.client import models as k8s
from kubernetes.client.api_client import ApiClient
Expand Down Expand Up @@ -65,7 +65,7 @@ def make_safe_label_value(string: str) -> str:
way from the original value sent to this function, then we need to truncate to
53 chars, and append it with a unique hash.
"""
safe_label = re.sub(r"^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$", "", string)
safe_label = re2.sub(r"^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$", "", string)

if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
safe_hash = md5(string.encode()).hexdigest()[:9]
Expand Down
5 changes: 2 additions & 3 deletions airflow/kubernetes/pod_generator_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from __future__ import annotations

import copy
import re
import uuid

import re2
from kubernetes.client import models as k8s

from airflow.utils.hashlib_wrapper import md5
Expand Down Expand Up @@ -70,7 +70,7 @@ def make_safe_label_value(string):
way from the original value sent to this function, then we need to truncate to
53 chars, and append it with a unique hash.
"""
safe_label = re.sub(r"^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$", "", string)
safe_label = re2.sub(r"^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$", "", string)

if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
safe_hash = md5(string.encode()).hexdigest()[:9]
Expand Down Expand Up @@ -151,7 +151,6 @@ def __init__(
extract_xcom: bool = False,
priority_class_name: str | None = None,
):

self.pod = k8s.V1Pod()
self.pod.api_version = "v1"
self.pod.kind = "Pod"
Expand Down
7 changes: 4 additions & 3 deletions airflow/metrics/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@

import abc
import logging
import re
import string
import warnings
from functools import partial, wraps
from typing import Callable, Iterable, Pattern, cast

import re2

from airflow.configuration import conf
from airflow.exceptions import InvalidStatsNameException

Expand Down Expand Up @@ -78,7 +79,7 @@ class MetricNameLengthExemptionWarning(Warning):
r"^dagrun\.schedule_delay\.(?P<dag_id>.*)$",
r"^dagrun\.(?P<dag_id>.*)\.first_task_scheduling_delay$",
}
BACK_COMPAT_METRIC_NAMES: set[Pattern[str]] = {re.compile(name) for name in BACK_COMPAT_METRIC_NAME_PATTERNS}
BACK_COMPAT_METRIC_NAMES: set[Pattern[str]] = {re2.compile(name) for name in BACK_COMPAT_METRIC_NAME_PATTERNS}

OTEL_NAME_MAX_LENGTH = 63

Expand Down Expand Up @@ -132,7 +133,7 @@ def stat_name_otel_handler(
# If the name is in the exceptions list, do not fail it for being too long.
# It may still be deemed invalid for other reasons below.
for exemption in BACK_COMPAT_METRIC_NAMES:
if re.match(exemption, stat_name):
if re2.match(exemption, stat_name):
# There is a back-compat exception for this name; proceed
name_length_exemption = True
matched_exemption = exemption.pattern
Expand Down
8 changes: 4 additions & 4 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

import jinja2
import pendulum
import re2 as re
import re2
from dateutil.relativedelta import relativedelta
from pendulum.tz.timezone import Timezone
from sqlalchemy import (
Expand Down Expand Up @@ -2361,7 +2361,7 @@ def partial_subset(
dag = copy.deepcopy(self, memo) # type: ignore

if isinstance(task_ids_or_regex, (str, Pattern)):
matched_tasks = [t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)]
matched_tasks = [t for t in self.tasks if re2.findall(task_ids_or_regex, t.task_id)]
else:
matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex]

Expand Down Expand Up @@ -2828,8 +2828,8 @@ def create_dagrun(

regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")

if run_id and not re.match(RUN_ID_REGEX, run_id):
if not regex.strip() or not re.match(regex.strip(), run_id):
if run_id and not re2.match(RUN_ID_REGEX, run_id):
if not regex.strip() or not re2.match(regex.strip(), run_id):
raise AirflowException(
f"The provided run ID '{run_id}' is invalid. It does not match either "
f"the configured pattern: '{regex}' or the built-in pattern: '{RUN_ID_REGEX}'"
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload

import re2 as re
import re2
from sqlalchemy import (
Boolean,
Column,
Expand Down Expand Up @@ -248,7 +248,7 @@ def validate_run_id(self, key: str, run_id: str) -> str | None:
if not run_id:
return None
regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")
if not re.match(regex, run_id) and not re.match(RUN_ID_REGEX, run_id):
if not re2.match(regex, run_id) and not re2.match(RUN_ID_REGEX, run_id):
raise ValueError(
f"The run_id provided '{run_id}' does not match the pattern '{regex}' or '{RUN_ID_REGEX}'"
)
Expand Down
5 changes: 3 additions & 2 deletions airflow/security/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
# limitations under the License.
#
"""Various security-related utils."""
import re
import socket

import re2

from airflow.utils.net import get_hostname


Expand All @@ -49,7 +50,7 @@ def get_components(principal) -> list[str] | None:
"""
if not principal:
return None
return re.split(r"[/@]", str(principal))
return re2.split(r"[/@]", str(principal))


def replace_hostname_pattern(components, host=None):
Expand Down
Loading

0 comments on commit 353df22

Please sign in to comment.