diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py
index 1751b61bcd54b..e019e8f14f3d2 100644
--- a/airflow/api_fastapi/execution_api/app.py
+++ b/airflow/api_fastapi/execution_api/app.py
@@ -20,6 +20,7 @@
 from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
+from fastapi.openapi.utils import get_openapi
 
 
 @asynccontextmanager
@@ -34,11 +35,49 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI:
     from airflow.api_fastapi.execution_api.routes import execution_api_router
 
     # TODO: Add versioning to the API
-    task_exec_api_app = FastAPI(
+    app = FastAPI(
         title="Airflow Task Execution API",
         description="The private Airflow Task Execution API.",
         lifespan=lifespan,
     )
 
-    task_exec_api_app.include_router(execution_api_router)
-    return task_exec_api_app
+    def custom_openapi() -> dict:
+        """
+        Customize the OpenAPI schema to include additional schemas not tied to specific endpoints.
+
+        This is particularly useful for client SDKs that require models for types
+        not directly exposed in any endpoint's request or response schema.
+
+        References:
+            - https://fastapi.tiangolo.com/how-to/extending-openapi/#modify-the-openapi-schema
+        """
+        if app.openapi_schema:
+            return app.openapi_schema
+        openapi_schema = get_openapi(
+            title=app.title,
+            description=app.description,
+            version=app.version,
+            routes=app.routes,
+        )
+
+        extra_schemas = get_extra_schemas()
+        for schema_name, schema in extra_schemas.items():
+            if schema_name not in openapi_schema["components"]["schemas"]:
+                openapi_schema["components"]["schemas"][schema_name] = schema
+
+        app.openapi_schema = openapi_schema
+        return app.openapi_schema
+
+    app.openapi = custom_openapi  # type: ignore[method-assign]
+
+    app.include_router(execution_api_router)
+    return app
+
+
+def get_extra_schemas() -> dict[str, dict]:
+    """Get all the extra schemas that are not part of the main FastAPI app."""
+    from airflow.api_fastapi.execution_api.datamodels import taskinstance
+
+    return {
+        "TaskInstance": taskinstance.TaskInstance.model_json_schema(),
+    }
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index db63dc3a8dbb1..07066eb5a5cc3 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import uuid
 from typing import Annotated, Literal, Union
 
 from pydantic import BaseModel, ConfigDict, Discriminator, Tag, WithJsonSchema
@@ -97,3 +98,17 @@ class TIHeartbeatInfo(BaseModel):
 
     hostname: str
     pid: int
+
+
+# This model is not used in the API, but it is included in the generated OpenAPI schema
+# for use in the client SDKs.
+class TaskInstance(BaseModel):
+    """Schema for TaskInstance model with minimal required fields needed for Runtime."""
+
+    id: uuid.UUID
+
+    task_id: str
+    dag_id: str
+    run_id: str
+    try_number: int
+    map_index: int | None = None
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index e921bee4bc2bd..c1d10f74d4a84 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -24,6 +24,7 @@
 from datetime import datetime
 from enum import Enum
 from typing import Annotated, Any, Literal
+from uuid import UUID
 
 from pydantic import BaseModel, Field
 
@@ -45,10 +46,9 @@ class ConnectionResponse(BaseModel):
 
 class IntermediateTIState(str, Enum):
     """
-    States that a Task Instance can be in that indicate it is not yet in a terminal or running state
+    States that a Task Instance can be in that indicate it is not yet in a terminal or running state.
     """
 
-    REMOVED = "removed"
     SCHEDULED = "scheduled"
     QUEUED = "queued"
     RESTARTING = "restarting"
@@ -89,12 +89,13 @@ class TITargetStatePayload(BaseModel):
 
 class TerminalTIState(str, Enum):
     """
-    States that a Task Instance can be in that indicate it has reached a terminal state
+    States that a Task Instance can be in that indicate it has reached a terminal state.
    """
 
     SUCCESS = "success"
     FAILED = "failed"
     SKIPPED = "skipped"
+    REMOVED = "removed"
 
 
 class ValidationError(BaseModel):
@@ -121,6 +122,15 @@ class XComResponse(BaseModel):
     value: Annotated[Any, Field(title="Value")]
 
 
+class TaskInstance(BaseModel):
+    id: Annotated[UUID, Field(title="Id")]
+    task_id: Annotated[str, Field(title="Task Id")]
+    dag_id: Annotated[str, Field(title="Dag Id")]
+    run_id: Annotated[str, Field(title="Run Id")]
+    try_number: Annotated[int, Field(title="Try Number")]
+    map_index: Annotated[int | None, Field(title="Map Index")] = None
+
+
 class HTTPValidationError(BaseModel):
     detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None
 
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/activities.py b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
index 04f2b389d5dbe..30bf41f6a28af 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/activities.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
@@ -21,7 +21,7 @@
 
 from pydantic import BaseModel
 
-from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance
 
 
 class ExecuteTaskActivity(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py
index a78fbb3e33b9e..07b260a417d50 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -47,8 +47,7 @@
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from airflow.sdk.api.datamodels._generated import TerminalTIState  # noqa: TCH001
-from airflow.sdk.api.datamodels.ti import TaskInstance  # noqa: TCH001
+from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState  # noqa: TCH001
 
 
 class StartupDetails(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index c05c6138f9642..6ecd8ff569800 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -43,7 +43,7 @@
 from pydantic import TypeAdapter
 
 from airflow.sdk.api.client import Client
-from airflow.sdk.api.datamodels._generated import TerminalTIState
+from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.execution_time.comms import (
     ConnectionResponse,
     GetConnection,
@@ -55,8 +55,6 @@
     from structlog.typing import FilteringBoundLogger
 
     from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
-    from airflow.sdk.api.datamodels.ti import TaskInstance
-
 
 __all__ = ["WatchedSubprocess", "supervise"]
 
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index c952207bca533..a6d7569382b31 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -29,7 +29,8 @@
 from pydantic import ConfigDict, TypeAdapter
 
 from airflow.sdk import BaseOperator
-from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, ToSupervisor, ToTask
+from airflow.sdk.api.datamodels._generated import TaskInstance
+from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor, ToTask
 
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger as Logger
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 7a712d1cc0afe..428ade1c35aea 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -29,7 +29,7 @@
 import structlog.testing
 
 from airflow.sdk.api import client as sdk_client
-from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
 from airflow.utils import timezone as tz
 
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index c634ba1255fe4..40c112170c6cd 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -24,7 +24,7 @@
 import pytest
 from uuid6 import uuid7
 
-from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance
 from airflow.sdk.execution_time.comms import StartupDetails
 from airflow.sdk.execution_time.task_runner import CommsDecoder, parse
 
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/ti.py b/tests/api_fastapi/execution_api/test_app.py
similarity index 59%
rename from task_sdk/src/airflow/sdk/api/datamodels/ti.py
rename to tests/api_fastapi/execution_api/test_app.py
index ce9e1e870ae29..ccd8b4c8db9ff 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/ti.py
+++ b/tests/api_fastapi/execution_api/test_app.py
@@ -14,19 +14,20 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 from __future__ import annotations
 
-import uuid
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
+
+
+def test_custom_openapi_includes_extra_schemas(client):
+    """Test to ensure that extra schemas are correctly included in the OpenAPI schema."""
 
-from pydantic import BaseModel
+    response = client.get("/execution/openapi.json")
+    assert response.status_code == 200
+    openapi_schema = response.json()
 
 
-class TaskInstance(BaseModel):
-    id: uuid.UUID
+    assert "TaskInstance" in openapi_schema["components"]["schemas"]
+    schema = openapi_schema["components"]["schemas"]["TaskInstance"]
 
-    task_id: str
-    dag_id: str
-    run_id: str
-    try_number: int
-    map_index: int | None = None
+    assert schema == TaskInstance.model_json_schema()
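
For reference, the custom_openapi hook added in app.py follows FastAPI's documented extend-OpenAPI pattern: build the schema once with get_openapi, inject extra component schemas, then cache the result on the app. Below is a minimal, self-contained sketch of that pattern outside Airflow; the ExampleModel class and the bare FastAPI() app are placeholders for illustration only and are not part of this change.

from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from fastapi.testclient import TestClient
from pydantic import BaseModel

app = FastAPI(title="Example API")


class ExampleModel(BaseModel):
    """A model no endpoint references; client generators would normally skip it."""

    id: int
    name: str


def custom_openapi() -> dict:
    # Reuse the cached schema if it has already been built.
    if app.openapi_schema:
        return app.openapi_schema
    openapi_schema = get_openapi(title=app.title, version=app.version, routes=app.routes)
    # Register the extra schema under components/schemas so codegen still emits a model for it.
    schemas = openapi_schema.setdefault("components", {}).setdefault("schemas", {})
    schemas.setdefault("ExampleModel", ExampleModel.model_json_schema())
    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi  # type: ignore[method-assign]

# The extra schema is now visible to anything that reads /openapi.json,
# which is what test_custom_openapi_includes_extra_schemas verifies for the real app.
assert "ExampleModel" in TestClient(app).get("/openapi.json").json()["components"]["schemas"]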
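
The TaskInstance model added to _generated.py is a plain Pydantic model, so it JSON round-trips like the other SDK datamodels that cross the supervisor boundary. A quick sketch, assuming the Task SDK package is importable; the field values are arbitrary examples and a random UUID stands in for the server-assigned id.

import uuid

from airflow.sdk.api.datamodels._generated import TaskInstance

# Field names and types mirror the model generated from the Execution API schema above.
ti = TaskInstance(
    id=uuid.uuid4(),
    task_id="hello",
    dag_id="example_dag",
    run_id="manual__2024-12-01",
    try_number=1,
)

# Serialize and validate back; map_index stays at its default of None when not supplied.
payload = ti.model_dump_json()
assert TaskInstance.model_validate_json(payload) == ti
assert ti.map_index is None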