Skip to content

Commit

Permalink
Un-ignore DeprecationWarning (#20322)
Browse files Browse the repository at this point in the history
(cherry picked from commit 9876e19)
  • Loading branch information
uranusjr authored and jedcunningham committed Jan 20, 2022
1 parent c836e71 commit a25d7ce
Show file tree
Hide file tree
Showing 23 changed files with 195 additions and 111 deletions.
41 changes: 16 additions & 25 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from airflow.utils import timezone
from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor
from airflow.utils.email import send_email
from airflow.utils.helpers import is_container
from airflow.utils.helpers import is_container, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import context_to_airflow_vars
Expand Down Expand Up @@ -2016,7 +2016,7 @@ def render_k8s_pod_yaml(self) -> Optional[dict]:
sanitized_pod = ApiClient().sanitize_for_serialization(pod)
return sanitized_pod

def get_email_subject_content(self, exception):
def get_email_subject_content(self, exception: BaseException) -> Tuple[str, str, str]:
"""Get the email subject content for exceptions."""
# For a ti from DB (without ti.task), return the default value
# Reuse it for smart sensor to send default email alert
Expand All @@ -2043,18 +2043,18 @@ def get_email_subject_content(self, exception):
'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
)

# This function is called after changing the state from State.RUNNING,
# so we need to subtract 1 from self.try_number here.
current_try_number = self.try_number - 1
additional_context = {
"exception": exception,
"exception_html": exception_html,
"try_number": current_try_number,
"max_tries": self.max_tries,
}

if use_default:
jinja_context = {'ti': self}
# This function is called after changing the state
# from State.RUNNING so need to subtract 1 from self.try_number.
jinja_context.update(
dict(
exception=exception,
exception_html=exception_html,
try_number=self.try_number - 1,
max_tries=self.max_tries,
)
)
jinja_context = {"ti": self, **additional_context}
jinja_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
)
Expand All @@ -2064,24 +2064,15 @@ def get_email_subject_content(self, exception):

else:
jinja_context = self.get_template_context()

jinja_context.update(
dict(
exception=exception,
exception_html=exception_html,
try_number=self.try_number - 1,
max_tries=self.max_tries,
)
)

jinja_context.update(additional_context)
jinja_env = self.task.get_template_env()

def render(key, content):
def render(key: str, content: str) -> str:
if conf.has_option('email', key):
path = conf.get('email', key)
with open(path) as f:
content = f.read()
return jinja_env.from_string(content).render(**jinja_context)
return render_template_to_string(jinja_env.from_string(content), jinja_context)

subject = render('subject_template', default_subject)
html_content = render('html_content_template', default_html_content)
Expand Down
2 changes: 1 addition & 1 deletion airflow/operators/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(

def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]:
if self.use_task_execution_date is True:
now = timezone.make_naive(context["execution_date"], self.dag.timezone)
now = timezone.make_naive(context["logical_date"], self.dag.timezone)
else:
now = timezone.make_naive(timezone.utcnow(), self.dag.timezone)

Expand Down
26 changes: 16 additions & 10 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import warnings
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Union

import dill

Expand All @@ -33,7 +33,7 @@
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import _CURRENT_CONTEXT
from airflow.utils.context import Context
from airflow.utils.operator_helpers import determine_kwargs
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script

Expand Down Expand Up @@ -142,8 +142,8 @@ def __init__(
self,
*,
python_callable: Callable,
op_args: Optional[List] = None,
op_kwargs: Optional[Dict] = None,
op_args: Optional[Collection[Any]] = None,
op_kwargs: Optional[Mapping[str, Any]] = None,
templates_dict: Optional[Dict] = None,
templates_exts: Optional[List[str]] = None,
**kwargs,
Expand All @@ -159,7 +159,7 @@ def __init__(
if not callable(python_callable):
raise AirflowException('`python_callable` param must be callable')
self.python_callable = python_callable
self.op_args = op_args or []
self.op_args = op_args or ()
self.op_kwargs = op_kwargs or {}
self.templates_dict = templates_dict
if templates_exts:
Expand All @@ -169,12 +169,15 @@ def execute(self, context: Dict):
context.update(self.op_kwargs)
context['templates_dict'] = self.templates_dict

self.op_kwargs = determine_kwargs(self.python_callable, self.op_args, context)
self.op_kwargs = self.determine_kwargs(context)

return_value = self.execute_callable()
self.log.info("Done. Returned value was: %s", return_value)
return return_value

def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking()

def execute_callable(self):
"""
Calls the python callable with the given arguments.
Expand Down Expand Up @@ -241,11 +244,11 @@ def execute(self, context: Dict):

self.log.info('Skipping downstream tasks...')

downstream_tasks = context['task'].get_flat_relatives(upstream=False)
downstream_tasks = context["task"].get_flat_relatives(upstream=False)
self.log.debug("Downstream task_ids %s", downstream_tasks)

if downstream_tasks:
self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks)
self.skip(context["dag_run"], context["logical_date"], downstream_tasks)

self.log.info("Done.")

Expand Down Expand Up @@ -345,8 +348,8 @@ def __init__(
python_version: Optional[Union[str, int, float]] = None,
use_dill: bool = False,
system_site_packages: bool = True,
op_args: Optional[List] = None,
op_kwargs: Optional[Dict] = None,
op_args: Optional[Collection[Any]] = None,
op_kwargs: Optional[Mapping[str, Any]] = None,
string_args: Optional[Iterable[str]] = None,
templates_dict: Optional[Dict] = None,
templates_exts: Optional[List[str]] = None,
Expand Down Expand Up @@ -392,6 +395,9 @@ def execute(self, context: Context):
serializable_context = context.copy_only(serializable_keys)
return super().execute(context=serializable_context)

def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
return KeywordParameters.determine(self.python_callable, self.op_args, context).serializing()

def execute_callable(self):
with TemporaryDirectory(prefix='venv') as tmp_dir:
if self.templates_dict:
Expand Down
2 changes: 1 addition & 1 deletion airflow/operators/weekday.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(

def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]:
if self.use_task_execution_day:
now = context["execution_date"]
now = context["logical_date"]
else:
now = timezone.make_naive(timezone.utcnow(), self.dag.timezone)

Expand Down
10 changes: 5 additions & 5 deletions airflow/providers/http/operators/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(
raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead")

def execute(self, context: Dict[str, Any]) -> Any:
from airflow.utils.operator_helpers import make_kwargs_callable
from airflow.utils.operator_helpers import determine_kwargs

http = HttpHook(self.method, http_conn_id=self.http_conn_id, auth_type=self.auth_type)

Expand All @@ -114,10 +114,10 @@ def execute(self, context: Dict[str, Any]) -> Any:
if self.log_response:
self.log.info(response.text)
if self.response_check:
kwargs_callable = make_kwargs_callable(self.response_check)
if not kwargs_callable(response, **context):
kwargs = determine_kwargs(self.response_check, [response], context)
if not self.response_check(response, **kwargs):
raise AirflowException("Response check returned False.")
if self.response_filter:
kwargs_callable = make_kwargs_callable(self.response_filter)
return kwargs_callable(response, **context)
kwargs = determine_kwargs(self.response_filter, [response], context)
return self.response_filter(response, **kwargs)
return response.text
7 changes: 3 additions & 4 deletions airflow/providers/http/sensors/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
self.hook = HttpHook(method=method, http_conn_id=http_conn_id)

def poke(self, context: Dict[Any, Any]) -> bool:
from airflow.utils.operator_helpers import make_kwargs_callable
from airflow.utils.operator_helpers import determine_kwargs

self.log.info('Poking: %s', self.endpoint)
try:
Expand All @@ -107,9 +107,8 @@ def poke(self, context: Dict[Any, Any]) -> bool:
extra_options=self.extra_options,
)
if self.response_check:
kwargs_callable = make_kwargs_callable(self.response_check)
return kwargs_callable(response, **context)

kwargs = determine_kwargs(self.response_check, [response], context)
return self.response_check(response, **kwargs)
except AirflowException as exc:
if str(exc).startswith("404"):
return False
Expand Down
24 changes: 12 additions & 12 deletions airflow/sensors/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_link(self, operator, dttm):
class ExternalTaskSensor(BaseSensorOperator):
"""
Waits for a different DAG or a task in a different DAG to complete for a
specific execution_date
specific logical date.
:param external_dag_id: The dag_id that contains the task you want to
wait for
Expand All @@ -65,14 +65,14 @@ class ExternalTaskSensor(BaseSensorOperator):
:param failed_states: Iterable of failed or dis-allowed states, default is ``None``
:type failed_states: Iterable
:param execution_delta: time difference with the previous execution to
look at, the default is the same execution_date as the current task or DAG.
look at, the default is the same logical date as the current task or DAG.
For yesterday, use [positive!] datetime.timedelta(days=1). Either
execution_delta or execution_date_fn can be passed to
ExternalTaskSensor, but not both.
:type execution_delta: Optional[datetime.timedelta]
:param execution_date_fn: function that receives the current execution date as the first
:param execution_date_fn: function that receives the current execution's logical date as the first
positional argument and optionally any number of keyword arguments available in the
context dictionary, and returns the desired execution dates to query.
context dictionary, and returns the desired logical dates to query.
Either execution_delta or execution_date_fn can be passed to ExternalTaskSensor,
but not both.
:type execution_date_fn: Optional[Callable]
Expand Down Expand Up @@ -157,11 +157,11 @@ def __init__(
@provide_session
def poke(self, context, session=None):
if self.execution_delta:
dttm = context['execution_date'] - self.execution_delta
dttm = context['logical_date'] - self.execution_delta
elif self.execution_date_fn:
dttm = self._handle_execution_date_fn(context=context)
else:
dttm = context['execution_date']
dttm = context['logical_date']

dttm_filter = dttm if isinstance(dttm, list) else [dttm]
serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter)
Expand Down Expand Up @@ -260,14 +260,14 @@ def _handle_execution_date_fn(self, context) -> Any:
"""
from airflow.utils.operator_helpers import make_kwargs_callable

# Remove "execution_date" because it is already a mandatory positional argument
execution_date = context["execution_date"]
kwargs = {k: v for k, v in context.items() if k != "execution_date"}
# Remove "logical_date" because it is already a mandatory positional argument
logical_date = context["logical_date"]
kwargs = {k: v for k, v in context.items() if k not in {"execution_date", "logical_date"}}
# Add "context" in the kwargs for backward compatibility (because context used to be
# an acceptable argument of execution_date_fn)
kwargs["context"] = context
kwargs_callable = make_kwargs_callable(self.execution_date_fn)
return kwargs_callable(execution_date, **kwargs)
return kwargs_callable(logical_date, **kwargs)


class ExternalTaskMarker(DummyOperator):
Expand All @@ -281,7 +281,7 @@ class ExternalTaskMarker(DummyOperator):
:type external_dag_id: str
:param external_task_id: The task_id of the dependent task that needs to be cleared.
:type external_task_id: str
:param execution_date: The execution_date of the dependent task that needs to be cleared.
:param execution_date: The logical date of the dependent task execution that needs to be cleared.
:type execution_date: str or datetime.datetime
:param recursion_depth: The maximum level of transitive dependencies allowed. Default is 10.
This is mostly used for preventing cyclic dependencies. It is fine to increase
Expand All @@ -300,7 +300,7 @@ def __init__(
*,
external_dag_id: str,
external_task_id: str,
execution_date: Optional[Union[str, datetime.datetime]] = "{{ execution_date.isoformat() }}",
execution_date: Optional[Union[str, datetime.datetime]] = "{{ logical_date.isoformat() }}",
recursion_depth: int = 10,
**kwargs,
):
Expand Down
2 changes: 1 addition & 1 deletion airflow/sensors/weekday.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,6 @@ def poke(self, context):
WeekDay(timezone.utcnow().isoweekday()).name,
)
if self.use_task_execution_day:
return context['execution_date'].isoweekday() in self._week_day_num
return context['logical_date'].isoweekday() in self._week_day_num
else:
return timezone.utcnow().isoweekday() in self._week_day_num
33 changes: 33 additions & 0 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import contextlib
import copy
import functools
import warnings
from typing import (
AbstractSet,
Expand All @@ -28,12 +29,15 @@
Dict,
Iterator,
List,
Mapping,
MutableMapping,
Optional,
Tuple,
ValuesView,
)

import lazy_object_proxy

_NOT_SET: Any = object()


Expand Down Expand Up @@ -194,3 +198,32 @@ def copy_only(self, keys: Container[str]) -> "Context":
new = type(self)({k: v for k, v in self._context.items() if k in keys})
new._deprecation_replacements = self._deprecation_replacements.copy()
return new


def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
"""Create a mapping that wraps deprecated entries in a lazy object proxy.
This further delays deprecation warning to until when the entry is actually
used, instead of when it's accessed in the context. The result is useful for
passing into a callable with ``**kwargs``, which would unpack the mapping
too eagerly otherwise.
This is implemented as a free function because the ``Context`` type is
"faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
functions.
:meta private:
"""

def _deprecated_proxy_factory(k: str, v: Any) -> Any:
replacements = source._deprecation_replacements[k]
warnings.warn(_create_deprecation_warning(k, replacements))
return v

def _create_value(k: str, v: Any) -> Any:
if k not in source._deprecation_replacements:
return v
factory = functools.partial(_deprecated_proxy_factory, k, v)
return lazy_object_proxy.Proxy(factory)

return {k: _create_value(k, v) for k, v in source._context.items()}
6 changes: 5 additions & 1 deletion airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# undefined attribute errors from Mypy. Hopefully there will be a mechanism to
# declare "these are defined, but don't error if others are accessed" someday.

from typing import Any, Optional
from typing import Any, Mapping, Optional

from pendulum import DateTime

Expand Down Expand Up @@ -80,3 +80,7 @@ class Context(TypedDict, total=False):
var: _VariableAccessors
yesterday_ds: str
yesterday_ds_nodash: str

class AirflowContextDeprecationWarning(DeprecationWarning): ...

def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ...
2 changes: 1 addition & 1 deletion airflow/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def render_log_filename(ti: "TaskInstance", try_number, filename_template) -> st
if filename_jinja_template:
jinja_context = ti.get_template_context()
jinja_context['try_number'] = try_number
return filename_jinja_template.render(**jinja_context)
return render_template_to_string(filename_jinja_template, jinja_context)

return filename_template.format(
dag_id=ti.dag_id,
Expand Down
Loading

0 comments on commit a25d7ce

Please sign in to comment.