From 8fe306b435cb3d58851b41deb703ea86470c071c Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 21 Mar 2024 14:24:36 -0700 Subject: [PATCH] introduce StrictModel branch-name: strict-model --- python_modules/dagster/dagster/_core/utils.py | 10 +++++ .../core_tests/test_strict_model.py | 37 +++++++++++++++++++ .../general_tests/test_serdes.py | 18 ++++++++- 3 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 python_modules/dagster/dagster_tests/core_tests/test_strict_model.py diff --git a/python_modules/dagster/dagster/_core/utils.py b/python_modules/dagster/dagster/_core/utils.py index ec901add5a75b..f7a4709be1404 100644 --- a/python_modules/dagster/dagster/_core/utils.py +++ b/python_modules/dagster/dagster/_core/utils.py @@ -23,6 +23,7 @@ from weakref import WeakSet import toposort as toposort_ +from pydantic import BaseModel from typing_extensions import Final import dagster._check as check @@ -196,3 +197,12 @@ def submit(self, fn, *args, **kwargs): def is_valid_email(email: str) -> bool: regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,7}\b" return bool(re.fullmatch(regex, email)) + + +class StrictModel(BaseModel): + def __init__(self, **data: Any) -> None: + super().__init__(**data) + + class Config: + extra = "forbid" + frozen = True diff --git a/python_modules/dagster/dagster_tests/core_tests/test_strict_model.py b/python_modules/dagster/dagster_tests/core_tests/test_strict_model.py new file mode 100644 index 0000000000000..2cd12bbea17ca --- /dev/null +++ b/python_modules/dagster/dagster_tests/core_tests/test_strict_model.py @@ -0,0 +1,37 @@ +import pytest +from dagster._core.utils import StrictModel +from pydantic import ValidationError + + +def test_override_constructor_in_subclass(): + class MyClass(StrictModel): + foo: str + bar: int + + def __init__(self, foo: str, bar: int): + super().__init__(foo=foo, bar=bar) + + MyClass(foo="fdsjk", bar=4) + + +def test_override_constructor_in_subclass_different_arg_names(): + class MyClass(StrictModel): + foo: str + bar: int + + def __init__(self, fooarg: str, bararg: int): + super().__init__(foo=fooarg, bar=bararg) + + MyClass(fooarg="fdsjk", bararg=4) + + +def test_override_constructor_in_subclass_wrong_type(): + class MyClass(StrictModel): + foo: str + bar: int + + def __init__(self, foo: str, bar: str): + super().__init__(foo=foo, bar=bar) + + with pytest.raises(ValidationError): + MyClass(foo="fdsjk", bar="fdslk") diff --git a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py index 99c026605c0b4..ee95e8ceb18e3 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py +++ b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py @@ -8,6 +8,7 @@ import pydantic import pytest from dagster._check import ParameterCheckError, inst_param, set_param +from dagster._core.utils import StrictModel from dagster._serdes.errors import DeserializationError, SerdesUsageError, SerializationError from dagster._serdes.serdes import ( EnumSerializer, @@ -822,6 +823,11 @@ class SomeModel(pydantic.BaseModel): id: int name: str + @_whitelist_for_serdes(test_env) + class SomeStrictModel(StrictModel): + id: int + name: str + @_whitelist_for_serdes(test_env) @pydantic.dataclasses.dataclass class DataclassObj: @@ -830,8 +836,16 @@ class DataclassObj: d: InnerDataclass nt: SomeNT m: SomeModel - - o = DataclassObj("woo", 4, InnerDataclass(1.2), SomeNT([1, 2, 3]), SomeModel(id=4, name="zuck")) + st: SomeStrictModel + + o = DataclassObj( + "woo", + 4, + InnerDataclass(1.2), + SomeNT([1, 2, 3]), + SomeModel(id=4, name="zuck"), + SomeStrictModel(id=4, name="zuck"), + ) ser_o = serialize_value(o, whitelist_map=test_env) assert deserialize_value(ser_o, whitelist_map=test_env) == o