Skip to content

Commit

Permalink
Upgrade to Pydantic 2 (#33956)
Browse files Browse the repository at this point in the history
The only blocking factor for migrating to Pydantic 2 was
aws-sam-translator, which was a transitive dependency of
`moto[cloudformation]` via `cfn-lint`. We do not really need
everything in that extra - it is used only for testing.

While aws-sam-translator is already preparing to release a Pydantic 2
compatible version, we do not want to wait - instead we replace the
cloudformation extra with openapi_spec_validator and jsonschema,
which are needed by the cloudformation tests.
  • Loading branch information
potiuk authored Aug 31, 2023
1 parent c51901a commit 1cda0c3
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion airflow/serialization/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _is_pydantic(cls: Any) -> bool:
Checking is done by attributes as it is significantly faster than
using isinstance.
"""
return hasattr(cls, "__validators__") and hasattr(cls, "__fields__") and hasattr(cls, "dict")
return hasattr(cls, "model_config") and hasattr(cls, "model_fields") and hasattr(cls, "model_fields_set")


def _register():
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ install_requires =
pendulum>=2.0
pluggy>=1.0
psutil>=4.2.0
pydantic>=1.10.0
pydantic>=2.3.0
pygments>=2.0.1
pyjwt>=2.0.0
python-daemon>=3.0.0
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve

_devel_only_amazon = [
"aws_xray_sdk",
"moto[cloudformation, glue]>=4.0",
"moto[glue]>=4.0",
# TODO: Remove the two below after https://github.com/aws/serverless-application-model/pull/3282
# gets released and add back "cloudformation" extra to moto above
"openapi-spec-validator >=0.2.8",
"jsonschema>=3.0",
f"mypy-boto3-rds>={_MIN_BOTO3_VERSION}",
f"mypy-boto3-redshift-data>={_MIN_BOTO3_VERSION}",
f"mypy-boto3-s3>={_MIN_BOTO3_VERSION}",
Expand Down
14 changes: 6 additions & 8 deletions tests/serialization/test_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
# under the License.
from __future__ import annotations

from pydantic import parse_raw_as

from airflow.jobs.job import Job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.models.dataset import (
Expand Down Expand Up @@ -49,7 +47,7 @@ def test_serializing_pydantic_task_instance(session, create_task_instance):
json_string = pydantic_task_instance.json()
print(json_string)

deserialized_model = parse_raw_as(TaskInstancePydantic, json_string)
deserialized_model = TaskInstancePydantic.model_validate_json(json_string)
assert deserialized_model.dag_id == dag_id
assert deserialized_model.state == State.RUNNING
assert deserialized_model.try_number == ti.try_number
Expand All @@ -68,7 +66,7 @@ def test_serializing_pydantic_dagrun(session, create_task_instance):
json_string = pydantic_dag_run.json()
print(json_string)

deserialized_model = parse_raw_as(DagRunPydantic, json_string)
deserialized_model = DagRunPydantic.model_validate_json(json_string)
assert deserialized_model.dag_id == dag_id
assert deserialized_model.state == State.RUNNING

Expand All @@ -85,7 +83,7 @@ def test_serializing_pydantic_local_task_job(session, create_task_instance):
json_string = pydantic_job.json()
print(json_string)

deserialized_model = parse_raw_as(JobPydantic, json_string)
deserialized_model = JobPydantic.model_validate_json(json_string)
assert deserialized_model.dag_id == dag_id
assert deserialized_model.job_type == "LocalTaskJob"
assert deserialized_model.state == State.RUNNING
Expand Down Expand Up @@ -139,17 +137,17 @@ def test_serializing_pydantic_dataset_event(session, create_task_instance, creat
json_string_dr = pydantic_dag_run.json()
print(json_string_dr)

deserialized_model1 = parse_raw_as(DatasetEventPydantic, json_string1)
deserialized_model1 = DatasetEventPydantic.model_validate_json(json_string1)
assert deserialized_model1.dataset.id == 1
assert deserialized_model1.dataset.uri == "one"
assert len(deserialized_model1.dataset.consuming_dags) == 1
assert len(deserialized_model1.dataset.producing_tasks) == 1

deserialized_model2 = parse_raw_as(DatasetEventPydantic, json_string2)
deserialized_model2 = DatasetEventPydantic.model_validate_json(json_string2)
assert deserialized_model2.dataset.id == 2
assert deserialized_model2.dataset.uri == "two"
assert len(deserialized_model2.dataset.consuming_dags) == 0
assert len(deserialized_model2.dataset.producing_tasks) == 0

deserialized_dr = parse_raw_as(DagRunPydantic, json_string_dr)
deserialized_dr = DagRunPydantic.model_validate_json(json_string_dr)
assert len(deserialized_dr.consumed_dataset_events) == 3

0 comments on commit 1cda0c3

Please sign in to comment.