Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove provide_session decorator from TaskInstancePydantic methods #37853

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions airflow/serialization/pydantic/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
PlainValidator,
is_pydantic_2_installed,
)
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
Expand Down Expand Up @@ -144,13 +143,12 @@ def xcom_pull(
"""
return None

@provide_session
def xcom_push(
self,
key: str,
value: Any,
execution_date: datetime | None = None,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> None:
"""
Push an XCom value for this task instance.
Expand All @@ -162,8 +160,7 @@ def xcom_push(
"""
pass

@provide_session
def get_dagrun(self, session: Session = NEW_SESSION) -> DagRunPydantic:
def get_dagrun(self, session: Session | None = None) -> DagRunPydantic:
"""
Return the DagRun for this TaskInstance.

Expand All @@ -186,8 +183,7 @@ def _execute_task(self, context, task_orig):

return _execute_task(task_instance=self, context=context, task_orig=task_orig)

@provide_session
def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
def refresh_from_db(self, session: Session | None = None, lock_for_update: bool = False) -> None:
"""
Refresh the task instance from the database based on the primary key.

Expand Down Expand Up @@ -244,14 +240,13 @@ def is_eligible_to_retry(self):

return _is_eligible_to_retry(task_instance=self)

@provide_session
def handle_failure(
self,
error: None | str | Exception | KeyboardInterrupt,
test_mode: bool | None = None,
context: Context | None = None,
force_fail: bool = False,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> None:
"""
Handle Failure for a task instance.
Expand Down Expand Up @@ -284,7 +279,6 @@ def refresh_from_task(self, task: Operator, pool_override: str | None = None) ->

_refresh_from_task(task_instance=self, task=task, pool_override=pool_override)

@provide_session
def get_previous_dagrun(
self,
state: DagRunState | None = None,
Expand All @@ -300,11 +294,10 @@ def get_previous_dagrun(

return _get_previous_dagrun(task_instance=self, state=state, session=session)

@provide_session
def get_previous_execution_date(
self,
state: DagRunState | None = None,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> pendulum.DateTime | None:
"""
Return the execution date from property previous_ti_success.
Expand Down Expand Up @@ -340,11 +333,10 @@ def get_email_subject_content(

return _get_email_subject_content(task_instance=self, exception=exception, task=task)

@provide_session
def get_previous_ti(
self,
state: DagRunState | None = None,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> TaskInstance | TaskInstancePydantic | None:
"""
Return the task instance for the task that ran before this task instance.
Expand All @@ -356,7 +348,6 @@ def get_previous_ti(

return _get_previous_ti(task_instance=self, state=state, session=session)

@provide_session
def check_and_change_state_before_execution(
self,
verbose: bool = True,
Expand All @@ -370,7 +361,7 @@ def check_and_change_state_before_execution(
job_id: str | None = None,
pool: str | None = None,
external_executor_id: str | None = None,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> bool:
return TaskInstance._check_and_change_state_before_execution(
task_instance=self,
Expand All @@ -389,8 +380,7 @@ def check_and_change_state_before_execution(
session=session,
)

@provide_session
def schedule_downstream_tasks(self, session: Session = NEW_SESSION, max_tis_per_query: int | None = None):
def schedule_downstream_tasks(self, session: Session | None = None, max_tis_per_query: int | None = None):
"""
Schedule downstream tasks of this task instance.

Expand Down