D401 Support - Secrets to Triggers (Inclusive) (#33338)
(cherry picked from commit 44a752a)
ferruzzi authored and ephraimbuddy committed Aug 28, 2023
1 parent a33cbf0 commit 601d0e4
Showing 22 changed files with 63 additions and 57 deletions.
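The whole diff applies pydocstyle rule D401 ("First line of docstring should be in imperative mood"): summary lines like "Returns X" become "Return X". A minimal before/after sketch of the pattern, using a hypothetical function rather than code from this commit:

# Fails D401: the first line is indicative ("Returns ...").
def get_config():
    """Returns the current configuration."""

# Passes D401: the first line is imperative ("Return ...").
def get_config():
    """Return the current configuration."""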
18 changes: 11 additions & 7 deletions airflow/secrets/cache.py
@@ -46,7 +46,11 @@ def is_expired(self, ttl: datetime.timedelta) -> bool:

@classmethod
def init(cls):
"""Initializes the cache, provided the configuration allows it. Safe to call several times."""
"""
Initialize the cache, provided the configuration allows it.
Safe to call several times.
"""
if cls._cache is not None:
return
use_cache = conf.getboolean(section="secrets", key="use_cache", fallback=False)
@@ -62,13 +66,13 @@ def init(cls):

@classmethod
def reset(cls):
"""For test purposes only."""
"""Use for test purposes only."""
cls._cache = None

@classmethod
def get_variable(cls, key: str) -> str | None:
"""
-Tries to get the value associated with the key from the cache.
+Try to get the value associated with the key from the cache.
:return: The saved value (which can be None) if present in cache and not expired,
a NotPresent exception otherwise.
@@ -78,7 +82,7 @@
@classmethod
def get_connection_uri(cls, conn_id: str) -> str:
"""
-Tries to get the uri associated with the conn_id from the cache.
+Try to get the uri associated with the conn_id from the cache.
:return: The saved uri if present in cache and not expired,
a NotPresent exception otherwise.
@@ -101,12 +105,12 @@ def _get(cls, key: str, prefix: str) -> str | None:

@classmethod
def save_variable(cls, key: str, value: str | None):
"""Saves the value for that key in the cache, if initialized."""
"""Save the value for that key in the cache, if initialized."""
cls._save(key, value, cls._VARIABLE_PREFIX)

@classmethod
def save_connection_uri(cls, conn_id: str, uri: str):
"""Saves the uri representation for that connection in the cache, if initialized."""
"""Save the uri representation for that connection in the cache, if initialized."""
if uri is None:
# connections raise exceptions if not present, so we shouldn't have any None value to save.
return
@@ -119,7 +123,7 @@ def _save(cls, key: str, value: str | None, prefix: str):

@classmethod
def invalidate_variable(cls, key: str):
"""Invalidates (actually removes) the value stored in the cache for that Variable."""
"""Invalidate (actually removes) the value stored in the cache for that Variable."""
if cls._cache is not None:
# second arg ensures no exception if key is absent
cls._cache.pop(f"{cls._VARIABLE_PREFIX}{key}", None)
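Taken together, the cache.py docstrings above cover the full variable round trip. A minimal usage sketch built only from the methods shown in this file; the exception name NotPresentException and the [secrets] use_cache option are assumptions about the surrounding module, not visible in this diff:

from airflow.secrets.cache import SecretCache

SecretCache.init()  # no-op unless the secrets cache is enabled in config; safe to call repeatedly
SecretCache.save_variable("my_key", "my_value")  # stored only if the cache is initialized
try:
    value = SecretCache.get_variable("my_key")  # may legitimately return None if None was cached
except SecretCache.NotPresentException:  # assumed name for the "NotPresent exception" above
    value = None  # key was never cached, or the entry expired
SecretCache.invalidate_variable("my_key")  # removes the entry; absent keys are ignored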
6 changes: 3 additions & 3 deletions airflow/secrets/local_filesystem.py
@@ -46,7 +46,7 @@


def get_connection_parameter_names() -> set[str]:
"""Returns :class:`airflow.models.connection.Connection` constructor parameters."""
"""Return :class:`airflow.models.connection.Connection` constructor parameters."""
from airflow.models.connection import Connection

return {k for k in signature(Connection.__init__).parameters.keys() if k != "self"}
@@ -186,7 +186,7 @@ def _parse_secret_file(file_path: str) -> dict[str, Any]:


def _create_connection(conn_id: str, value: Any):
"""Creates a connection based on a URL or JSON object."""
"""Create a connection based on a URL or JSON object."""
from airflow.models.connection import Connection

if isinstance(value, str):
@@ -243,7 +243,7 @@ def load_variables(file_path: str) -> dict[str, str]:


def load_connections(file_path) -> dict[str, list[Any]]:
"""Deprecated: Please use `airflow.secrets.local_filesystem.load_connections_dict`."""
"""Use `airflow.secrets.local_filesystem.load_connections_dict`, this is deprecated."""
warnings.warn(
"This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.",
RemovedInAirflow3Warning,
2 changes: 1 addition & 1 deletion airflow/security/permissions.py
@@ -70,7 +70,7 @@


def resource_name_for_dag(root_dag_id: str) -> str:
"""Returns the resource name for a DAG id.
"""Return the resource name for a DAG id.
Note that since a sub-DAG should follow the permission of its
parent DAG, you should pass ``DagModel.root_dag_id`` to this function,
6 changes: 3 additions & 3 deletions airflow/security/utils.py
@@ -54,15 +54,15 @@ def get_components(principal) -> list[str] | None:


def replace_hostname_pattern(components, host=None):
"""Replaces hostname with the right pattern including lowercase of the name."""
"""Replace hostname with the right pattern including lowercase of the name."""
fqdn = host
if not fqdn or fqdn == "0.0.0.0":
fqdn = get_hostname()
return f"{components[0]}/{fqdn.lower()}@{components[2]}"


def get_fqdn(hostname_or_ip=None):
"""Retrieves FQDN - hostname for the IP or hostname."""
"""Retrieve FQDN - hostname for the IP or hostname."""
try:
if hostname_or_ip:
fqdn = socket.gethostbyaddr(hostname_or_ip)[0]
@@ -77,7 +77,7 @@ def get_fqdn(hostname_or_ip=None):


def principal_from_username(username, realm):
"""Retrieves principal from the user name and realm."""
"""Retrieve principal from the username and realm."""
if ("@" not in username) and realm:
username = f"{username}@{realm}"

4 changes: 2 additions & 2 deletions airflow/sensors/base.py
@@ -199,7 +199,7 @@ def _validate_input_values(self) -> None:
)

def poke(self, context: Context) -> bool | PokeReturnValue:
"""Function defined by the sensors while deriving this class should override."""
"""Override when deriving this class."""
raise AirflowException("Override me.")

def execute(self, context: Context) -> Any:
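poke is the hook that concrete sensors implement; the base class only raises. A hypothetical minimal sensor overriding it (class and field names are illustrative, not from Airflow):

import os

from airflow.sensors.base import BaseSensorOperator


class FileExistsSensor(BaseSensorOperator):
    """Succeed once the given path exists (illustrative example)."""

    def __init__(self, *, path: str, **kwargs):
        super().__init__(**kwargs)
        self.path = path

    def poke(self, context) -> bool:
        # True completes the sensor successfully; False schedules another poke
        return os.path.exists(self.path)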
@@ -287,7 +287,7 @@ def _get_next_poke_interval(
run_duration: Callable[[], float],
try_number: int,
) -> float:
"""Using the similar logic which is used for exponential backoff retry delay for operators."""
"""Use similar logic which is used for exponential backoff retry delay for operators."""
if not self.exponential_backoff:
return self.poke_interval

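_get_next_poke_interval stretches the wait between pokes when exponential_backoff is enabled. Airflow's actual formula adds jitter and caps the delay; the sketch below shows only the general backoff technique, not the literal implementation:

def next_poke_interval(poke_interval: float, try_number: int, cap: float = 3600.0) -> float:
    # generic exponential backoff: double the base interval each try, up to a cap
    return min(poke_interval * (2 ** max(try_number - 1, 0)), cap)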
4 changes: 2 additions & 2 deletions airflow/sensors/date_time.py
@@ -77,7 +77,7 @@ def poke(self, context: Context) -> bool:

class DateTimeSensorAsync(DateTimeSensor):
"""
-Waits until the specified datetime occurs.
+Wait until the specified datetime occurs.
Deferring itself to avoid taking up a worker slot while it is waiting.
It is a drop-in replacement for DateTimeSensor.
@@ -92,5 +92,5 @@ def execute(self, context: Context):
)

def execute_complete(self, context, event=None):
"""Callback for when the trigger fires - returns immediately."""
"""Execute when the trigger fires - returns immediately."""
return None
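Because DateTimeSensorAsync is a drop-in replacement, switching from DateTimeSensor is a one-word change. A sketch with illustrative task id and templated target time:

from airflow.sensors.date_time import DateTimeSensorAsync

wait = DateTimeSensorAsync(
    task_id="wait_until_window_closes",
    target_time="{{ data_interval_end }}",  # templated, exactly as with DateTimeSensor
)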
4 changes: 2 additions & 2 deletions airflow/sensors/external_task.py
@@ -347,7 +347,7 @@ def execute(self, context: Context) -> None:
)

def execute_complete(self, context, event=None):
"""Callback for when the trigger fires - returns immediately."""
"""Execute when the trigger fires - return immediately."""
if event["status"] == "success":
self.log.info("External task %s has executed successfully.", self.external_task_id)
return None
@@ -528,7 +528,7 @@ def __init__(

@classmethod
def get_serialized_fields(cls):
"""Serialized ExternalTaskMarker contain exactly these fields + templated_fields ."""
"""Serialize ExternalTaskMarker to contain exactly these fields + templated_fields ."""
if not cls.__serialized_fields:
cls.__serialized_fields = frozenset(super().get_serialized_fields() | {"recursion_depth"})
return cls.__serialized_fields
2 changes: 1 addition & 1 deletion airflow/sensors/time_delta.py
@@ -67,5 +67,5 @@ def execute(self, context: Context):
self.defer(trigger=DateTimeTrigger(moment=target_dttm), method_name="execute_complete")

def execute_complete(self, context, event=None):
"""Callback for when the trigger fires - returns immediately."""
"""Execute for when the trigger fires - return immediately."""
return None
2 changes: 1 addition & 1 deletion airflow/sensors/time_sensor.py
@@ -76,5 +76,5 @@ def execute(self, context: Context):
)

def execute_complete(self, context, event=None):
"""Callback for when the trigger fires - returns immediately."""
"""Execute when the trigger fires - returns immediately."""
return None
4 changes: 2 additions & 2 deletions airflow/serialization/serde.py
@@ -65,7 +65,7 @@


def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
"""Encodes o so it can be understood by the deserializer."""
"""Encode an object so it can be understood by the deserializer."""
return {CLASSNAME: cls, VERSION: version, DATA: data}


@@ -274,7 +274,7 @@ def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:


def _convert(old: dict) -> dict:
"""Converts an old style serialization to new style."""
"""Convert an old style serialization to new style."""
if OLD_TYPE in old and OLD_DATA in old:
# Return old style dicts directly as they do not need wrapping
if old[OLD_TYPE] == OLD_DICT:
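encode builds the envelope that deserialize later unwraps, keyed by class name and version. A sketch of its output, assuming the module constants CLASSNAME, VERSION, and DATA are the strings "__classname__", "__version__", and "__data__" (an assumption; the constants are not shown in this diff):

from airflow.serialization.serde import encode

envelope = encode("builtins.tuple", 1, [1, 2, 3])
# expected shape under the assumption above:
# {"__classname__": "builtins.tuple", "__version__": 1, "__data__": [1, 2, 3]}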
20 changes: 11 additions & 9 deletions airflow/serialization/serialized_objects.py
@@ -365,7 +365,7 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool:
def serialize_to_json(
cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set
) -> dict[str, Any]:
"""Serializes an object to JSON."""
"""Serialize an object to JSON."""
serialized_object: dict[str, Any] = {}
keys_to_serialize = object_to_serialize.get_serialized_fields()
for key in keys_to_serialize:
@@ -395,7 +395,8 @@ def serialize_to_json(
def serialize(
cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False
) -> Any: # Unfortunately there is no support for recursive types in mypy
"""Helper function of depth first search for serialization.
"""
Serialize an object; helper function of depth first search for serialization.
The serialization protocol is:
@@ -513,7 +514,8 @@ def default_serialization(cls, strict, var) -> str:

@classmethod
def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
"""Helper function of depth first search for deserialization.
"""
Deserialize an object; helper function of depth first search for deserialization.
:meta private:
"""
@@ -695,7 +697,7 @@ class DependencyDetector:

@staticmethod
def detect_task_dependencies(task: Operator) -> list[DagDependency]:
"""Detects dependencies caused by tasks."""
"""Detect dependencies caused by tasks."""
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.sensors.external_task import ExternalTaskSensor

@@ -732,7 +734,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]:

@staticmethod
def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
"""Detects dependencies set directly on the DAG object."""
"""Detect dependencies set directly on the DAG object."""
if not dag:
return
for x in dag.dataset_triggers:
@@ -831,7 +833,7 @@ def serialize_operator(cls, op: BaseOperator) -> dict[str, Any]:

@classmethod
def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) -> dict[str, Any]:
"""Serializes operator into a JSON object."""
"""Serialize operator into a JSON object."""
serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__)
serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__)
@@ -1079,7 +1081,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:

@classmethod
def detect_dependencies(cls, op: Operator) -> set[DagDependency]:
"""Detects between DAG dependencies for the operator."""
"""Detect between DAG dependencies for the operator."""

def get_custom_dep() -> list[DagDependency]:
"""
@@ -1275,7 +1277,7 @@ def __get_constructor_defaults():

@classmethod
def serialize_dag(cls, dag: DAG) -> dict:
"""Serializes a DAG into a JSON object."""
"""Serialize a DAG into a JSON object."""
try:
serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)

@@ -1409,7 +1411,7 @@ class TaskGroupSerialization(BaseSerialization):

@classmethod
def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
"""Serializes TaskGroup into a JSON object."""
"""Serialize TaskGroup into a JSON object."""
if not task_group:
return None

2 changes: 1 addition & 1 deletion airflow/task/task_runner/base_task_runner.py
@@ -178,7 +178,7 @@ def terminate(self) -> None:
raise NotImplementedError()

def on_finish(self) -> None:
"""A callback that should be called when this is done running."""
"""Execute when this is done running."""
if self._cfg_path and os.path.isfile(self._cfg_path):
if self.run_as_user:
subprocess.call(["sudo", "rm", self._cfg_path], close_fds=True)
5 changes: 3 additions & 2 deletions airflow/template/templater.py
@@ -56,14 +56,15 @@ def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
return SandboxedEnvironment(cache_size=0)

def prepare_template(self) -> None:
"""Hook triggered after the templated fields get replaced by their content.
"""
Execute after the templated fields get replaced by their content.
If you need your object to alter the content of the file before the
template is rendered, it should override this method to do so.
"""

def resolve_template_files(self) -> None:
"""Getting the content of files for template_field / template_ext."""
"""Get the content of files for template_field / template_ext."""
if self.template_ext:
for field in self.template_fields:
content = getattr(self, field, None)
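As the prepare_template docstring above says, an operator overrides it to alter file content after templated fields are resolved but before Jinja renders them. A hypothetical override (operator and field names are illustrative):

from airflow.models.baseoperator import BaseOperator


class MyFileOperator(BaseOperator):
    template_fields = ("script",)
    template_ext = (".sql",)

    def __init__(self, *, script: str, **kwargs):
        super().__init__(**kwargs)
        self.script = script  # may be a path ending in .sql, loaded by resolve_template_files

    def prepare_template(self) -> None:
        # hypothetical tweak: normalize line endings before the template is rendered
        self.script = self.script.replace("\r\n", "\n")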
2 changes: 1 addition & 1 deletion airflow/ti_deps/dep_context.py
@@ -85,7 +85,7 @@ class DepContext:

def ensure_finished_tis(self, dag_run: DagRun, session: Session) -> list[TaskInstance]:
"""
-Ensures finished_tis is populated if it's currently None, which allows running tasks without dag_run.
+Ensure finished_tis is populated if it's currently None, which allows running tasks without dag_run.
:param dag_run: The DagRun for which to find finished tasks
:return: A list of all the finished tasks of this DAG and execution_date
6 changes: 3 additions & 3 deletions airflow/ti_deps/deps/base_ti_dep.py
@@ -90,7 +90,7 @@ def get_dep_statuses(
dep_context: DepContext | None = None,
) -> Iterator[TIDepStatus]:
"""
-Wrapper around the private _get_dep_statuses method.
+Wrap around the private _get_dep_statuses method.
Contains some global checks for all dependencies.
@@ -113,7 +113,7 @@ def get_dep_statuses(
@provide_session
def is_met(self, ti: TaskInstance, session: Session, dep_context: DepContext | None = None) -> bool:
"""
-Returns whether a dependency is met for a given task instance.
+Return whether a dependency is met for a given task instance.
A dependency is considered met if all the dependency statuses it reports are passing.
@@ -132,7 +132,7 @@ def get_failure_reasons(
dep_context: DepContext | None = None,
) -> Iterator[str]:
"""
-Returns an iterable of strings that explain why this dependency wasn't met.
+Return an iterable of strings that explain why this dependency wasn't met.
:param ti: the task instance to see if this dependency is met for
:param session: database session
2 changes: 1 addition & 1 deletion airflow/ti_deps/deps/dagrun_backfill_dep.py
@@ -32,7 +32,7 @@ class DagRunNotBackfillDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context=None):
"""
-Determines if the DagRun is valid for scheduling from scheduler.
+Determine if the DagRun is valid for scheduling from scheduler.
:param ti: the task instance to get the dependency status for
:param session: database session
2 changes: 1 addition & 1 deletion airflow/ti_deps/deps/pool_slots_available_dep.py
@@ -31,7 +31,7 @@ class PoolSlotsAvailableDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context=None):
"""
-Determines if the pool the task instance is in has available slots.
+Determine if the pool the task instance is in has available slots.
:param ti: the task instance to get the dependency status for
:param session: database session
2 changes: 1 addition & 1 deletion airflow/ti_deps/deps/ready_to_reschedule.py
@@ -36,7 +36,7 @@ class ReadyToRescheduleDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
"""
-Determines whether a task is ready to be rescheduled.
+Determine whether a task is ready to be rescheduled.
Only tasks in NONE state with at least one row in task_reschedule table are
handled by this dependency class, otherwise this dependency is considered as passed.
3 changes: 2 additions & 1 deletion airflow/timetables/interval.py
@@ -185,7 +185,8 @@ def deserialize(cls, data: dict[str, Any]) -> Timetable:
return cls(datetime.timedelta(seconds=delta))

def __eq__(self, other: Any) -> bool:
"""The offset should match.
"""
Return if the offsets match.
This is only for testing purposes and should not be relied on otherwise.
"""
Expand Down