Skip to content

Commit

Permalink
Extend OpenAPI schema with extra models for Task SDK (apache#44076)
Browse files Browse the repository at this point in the history
- Introduced `custom_openapi` to extend OpenAPI schema with additional models.
- Added `TaskInstance` model for inclusion in OpenAPI schema, specifically for Task SDK

Reference: https://fastapi.tiangolo.com/how-to/extending-openapi/#modify-the-openapi-schema
  • Loading branch information
kaxil authored and kandharvishnuu committed Nov 19, 2024
1 parent 88a0fae commit 2f955da
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 25 deletions.
45 changes: 42 additions & 3 deletions airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi


@asynccontextmanager
Expand All @@ -34,11 +35,49 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI:
from airflow.api_fastapi.execution_api.routes import execution_api_router

# TODO: Add versioning to the API
task_exec_api_app = FastAPI(
app = FastAPI(
title="Airflow Task Execution API",
description="The private Airflow Task Execution API.",
lifespan=lifespan,
)

task_exec_api_app.include_router(execution_api_router)
return task_exec_api_app
def custom_openapi() -> dict:
"""
Customize the OpenAPI schema to include additional schemas not tied to specific endpoints.
This is particularly useful for client SDKs that require models for types
not directly exposed in any endpoint's request or response schema.
References:
- https://fastapi.tiangolo.com/how-to/extending-openapi/#modify-the-openapi-schema
"""
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description=app.description,
version=app.version,
routes=app.routes,
)

extra_schemas = get_extra_schemas()
for schema_name, schema in extra_schemas.items():
if schema_name not in openapi_schema["components"]["schemas"]:
openapi_schema["components"]["schemas"][schema_name] = schema

app.openapi_schema = openapi_schema
return app.openapi_schema

app.openapi = custom_openapi # type: ignore[method-assign]

app.include_router(execution_api_router)
return app


def get_extra_schemas() -> dict[str, dict]:
"""Get all the extra schemas that are not part of the main FastAPI app."""
from airflow.api_fastapi.execution_api.datamodels import taskinstance

return {
"TaskInstance": taskinstance.TaskInstance.model_json_schema(),
}
15 changes: 15 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import uuid
from typing import Annotated, Literal, Union

from pydantic import BaseModel, ConfigDict, Discriminator, Tag, WithJsonSchema
Expand Down Expand Up @@ -97,3 +98,17 @@ class TIHeartbeatInfo(BaseModel):

hostname: str
pid: int


# This model is not used in the API, but it is included in generated OpenAPI schema
# for use in the client SDKs.
class TaskInstance(BaseModel):
"""Schema for TaskInstance model with minimal required fields needed for Runtime."""

id: uuid.UUID

task_id: str
dag_id: str
run_id: str
try_number: int
map_index: int | None = None
16 changes: 13 additions & 3 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal
from uuid import UUID

from pydantic import BaseModel, Field

Expand All @@ -45,10 +46,9 @@ class ConnectionResponse(BaseModel):

class IntermediateTIState(str, Enum):
"""
States that a Task Instance can be in that indicate it is not yet in a terminal or running state
States that a Task Instance can be in that indicate it is not yet in a terminal or running state.
"""

REMOVED = "removed"
SCHEDULED = "scheduled"
QUEUED = "queued"
RESTARTING = "restarting"
Expand Down Expand Up @@ -89,12 +89,13 @@ class TITargetStatePayload(BaseModel):

class TerminalTIState(str, Enum):
"""
States that a Task Instance can be in that indicate it has reached a terminal state
States that a Task Instance can be in that indicate it has reached a terminal state.
"""

SUCCESS = "success"
FAILED = "failed"
SKIPPED = "skipped"
REMOVED = "removed"


class ValidationError(BaseModel):
Expand All @@ -121,6 +122,15 @@ class XComResponse(BaseModel):
value: Annotated[Any, Field(title="Value")]


class TaskInstance(BaseModel):
id: Annotated[UUID, Field(title="Id")]
task_id: Annotated[str, Field(title="Task Id")]
dag_id: Annotated[str, Field(title="Dag Id")]
run_id: Annotated[str, Field(title="Run Id")]
try_number: Annotated[int, Field(title="Try Number")]
map_index: Annotated[int | None, Field(title="Map Index")] = None


class HTTPValidationError(BaseModel):
detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None

Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/api/datamodels/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pydantic import BaseModel

from airflow.sdk.api.datamodels.ti import TaskInstance
from airflow.sdk.api.datamodels._generated import TaskInstance


class ExecuteTaskActivity(BaseModel):
Expand Down
3 changes: 1 addition & 2 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@

from pydantic import BaseModel, ConfigDict, Field

from airflow.sdk.api.datamodels._generated import TerminalTIState # noqa: TCH001
from airflow.sdk.api.datamodels.ti import TaskInstance # noqa: TCH001
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState # noqa: TCH001


class StartupDetails(BaseModel):
Expand Down
4 changes: 1 addition & 3 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from pydantic import TypeAdapter

from airflow.sdk.api.client import Client
from airflow.sdk.api.datamodels._generated import TerminalTIState
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
ConnectionResponse,
GetConnection,
Expand All @@ -55,8 +55,6 @@
from structlog.typing import FilteringBoundLogger

from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
from airflow.sdk.api.datamodels.ti import TaskInstance


__all__ = ["WatchedSubprocess", "supervise"]

Expand Down
3 changes: 2 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from pydantic import ConfigDict, TypeAdapter

from airflow.sdk import BaseOperator
from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, ToSupervisor, ToTask
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor, ToTask

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import structlog.testing

from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.datamodels.ti import TaskInstance
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.utils import timezone as tz

Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pytest
from uuid6 import uuid7

from airflow.sdk.api.datamodels.ti import TaskInstance
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import StartupDetails
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import uuid
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance


def test_custom_openapi_includes_extra_schemas(client):
"""Test to ensure that extra schemas are correctly included in the OpenAPI schema."""

from pydantic import BaseModel
response = client.get("/execution/openapi.json")
assert response.status_code == 200

openapi_schema = response.json()

class TaskInstance(BaseModel):
id: uuid.UUID
assert "TaskInstance" in openapi_schema["components"]["schemas"]
schema = openapi_schema["components"]["schemas"]["TaskInstance"]

task_id: str
dag_id: str
run_id: str
try_number: int
map_index: int | None = None
assert schema == TaskInstance.model_json_schema()

0 comments on commit 2f955da

Please sign in to comment.