diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py index 006d9d9318b5..23d67e6162fc 100644 --- a/airflow/serialization/serde.py +++ b/airflow/serialization/serde.py @@ -319,7 +319,7 @@ def _is_pydantic(cls: Any) -> bool: Checking is done by attributes as it is significantly faster than using isinstance. """ - return hasattr(cls, "__validators__") and hasattr(cls, "__fields__") and hasattr(cls, "dict") + return hasattr(cls, "model_config") and hasattr(cls, "model_fields") and hasattr(cls, "model_fields_set") def _register(): diff --git a/setup.cfg b/setup.cfg index 9ad223927a27..f0851e9225fe 100644 --- a/setup.cfg +++ b/setup.cfg @@ -122,7 +122,7 @@ install_requires = pendulum>=2.0 pluggy>=1.0 psutil>=4.2.0 - pydantic>=1.10.0 + pydantic>=2.3.0 pygments>=2.0.1 pyjwt>=2.0.0 python-daemon>=3.0.0 diff --git a/setup.py b/setup.py index 725147103d32..9e827b4c059f 100644 --- a/setup.py +++ b/setup.py @@ -409,7 +409,11 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve _devel_only_amazon = [ "aws_xray_sdk", - "moto[cloudformation, glue]>=4.0", + "moto[glue]>=4.0", + # TODO: Remove the two below after https://github.com/aws/serverless-application-model/pull/3282 + # gets released and add back "cloudformation" extra to moto above + "openapi-spec-validator >=0.2.8", + "jsonschema>=3.0", f"mypy-boto3-rds>={_MIN_BOTO3_VERSION}", f"mypy-boto3-redshift-data>={_MIN_BOTO3_VERSION}", f"mypy-boto3-s3>={_MIN_BOTO3_VERSION}", diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py index b5e471b044c4..a9a4a6953a02 100644 --- a/tests/serialization/test_pydantic_models.py +++ b/tests/serialization/test_pydantic_models.py @@ -17,8 +17,6 @@ # under the License. from __future__ import annotations -from pydantic import parse_raw_as - from airflow.jobs.job import Job from airflow.jobs.local_task_job_runner import LocalTaskJobRunner from airflow.models.dataset import ( @@ -49,7 +47,7 @@ def test_serializing_pydantic_task_instance(session, create_task_instance): json_string = pydantic_task_instance.json() print(json_string) - deserialized_model = parse_raw_as(TaskInstancePydantic, json_string) + deserialized_model = TaskInstancePydantic.model_validate_json(json_string) assert deserialized_model.dag_id == dag_id assert deserialized_model.state == State.RUNNING assert deserialized_model.try_number == ti.try_number @@ -68,7 +66,7 @@ def test_serializing_pydantic_dagrun(session, create_task_instance): json_string = pydantic_dag_run.json() print(json_string) - deserialized_model = parse_raw_as(DagRunPydantic, json_string) + deserialized_model = DagRunPydantic.model_validate_json(json_string) assert deserialized_model.dag_id == dag_id assert deserialized_model.state == State.RUNNING @@ -85,7 +83,7 @@ def test_serializing_pydantic_local_task_job(session, create_task_instance): json_string = pydantic_job.json() print(json_string) - deserialized_model = parse_raw_as(JobPydantic, json_string) + deserialized_model = JobPydantic.model_validate_json(json_string) assert deserialized_model.dag_id == dag_id assert deserialized_model.job_type == "LocalTaskJob" assert deserialized_model.state == State.RUNNING @@ -139,17 +137,17 @@ def test_serializing_pydantic_dataset_event(session, create_task_instance, creat json_string_dr = pydantic_dag_run.json() print(json_string_dr) - deserialized_model1 = parse_raw_as(DatasetEventPydantic, json_string1) + deserialized_model1 = DatasetEventPydantic.model_validate_json(json_string1) assert deserialized_model1.dataset.id == 1 assert deserialized_model1.dataset.uri == "one" assert len(deserialized_model1.dataset.consuming_dags) == 1 assert len(deserialized_model1.dataset.producing_tasks) == 1 - deserialized_model2 = parse_raw_as(DatasetEventPydantic, json_string2) + deserialized_model2 = DatasetEventPydantic.model_validate_json(json_string2) assert deserialized_model2.dataset.id == 2 assert deserialized_model2.dataset.uri == "two" assert len(deserialized_model2.dataset.consuming_dags) == 0 assert len(deserialized_model2.dataset.producing_tasks) == 0 - deserialized_dr = parse_raw_as(DagRunPydantic, json_string_dr) + deserialized_dr = DagRunPydantic.model_validate_json(json_string_dr) assert len(deserialized_dr.consumed_dataset_events) == 3