diff --git a/python_modules/dagster/dagster/_core/utils.py b/python_modules/dagster/dagster/_core/utils.py index ec901add5a75b..fa54b8ac59cc8 100644 --- a/python_modules/dagster/dagster/_core/utils.py +++ b/python_modules/dagster/dagster/_core/utils.py @@ -7,6 +7,7 @@ from collections import OrderedDict from concurrent.futures import Future, ThreadPoolExecutor from contextvars import copy_context +from functools import wraps from typing import ( AbstractSet, Any, @@ -23,7 +24,8 @@ from weakref import WeakSet import toposort as toposort_ -from typing_extensions import Final +from pydantic.dataclasses import dataclass as pydantic_dataclass +from typing_extensions import Final, dataclass_transform import dagster._check as check from dagster._utils import library_version_from_core_version, parse_package_version @@ -196,3 +198,43 @@ def submit(self, fn, *args, **kwargs): def is_valid_email(email: str) -> bool: regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,7}\b" return bool(re.fullmatch(regex, email)) + + +@dataclass_transform(frozen_default=True) +def strict_dataclass(positional_args: Optional[AbstractSet[str]] = None): + """The decorated class functions exactly like a class decorated with the pydantic.dataclass + decorator, with two exceptions: + - It's frozen. I.e. mutation is not allowed. + - The constructor does not accept keyword arguments that aren't fields of the class. + - By default, the constructor does not accept positional arguments, but this can be overridden + with positional_args. + """ + + def decorator(cls): + cls = pydantic_dataclass(cls, frozen=True, config=dict(extra="forbid")) + original_init = cls.__init__ + + @wraps(original_init) + def new_init(self, *args, **kwargs): + annotation_keys_list = list(self.__pydantic_fields__.keys()) + if len(args) > len(annotation_keys_list): + raise TypeError( + f"{cls.__name__} constructor takes at most {len(annotation_keys_list)} positional " + f"arguments, but {len(args)} were given" + ) + positionally_supplied_args = {annotation_keys_list[i] for i in range(len(args))} + illegal_positionally_supplied_args = positionally_supplied_args - ( + positional_args or set() + ) + if illegal_positionally_supplied_args: + raise TypeError( + "These arguments must be specified as keyword arguments, not positional " + f"arguments: {illegal_positionally_supplied_args}" + ) + + original_init(self, *args, **kwargs) + + cls.__init__ = new_init + return cls + + return decorator diff --git a/python_modules/dagster/dagster_tests/core_tests/partition_tests/test_strict_dataclass.py b/python_modules/dagster/dagster_tests/core_tests/partition_tests/test_strict_dataclass.py new file mode 100644 index 0000000000000..43629b8eb5eb9 --- /dev/null +++ b/python_modules/dagster/dagster_tests/core_tests/partition_tests/test_strict_dataclass.py @@ -0,0 +1,95 @@ +from dataclasses import FrozenInstanceError + +import pytest +from dagster._core.utils import strict_dataclass +from pydantic import ValidationError + + +def test_kwargs(): + @strict_dataclass() + class MyClass: + a: int + b: str + + foo = MyClass(a=5, b="x") + assert foo.a == 5 + assert foo.b == "x" + + +def test_kwargs_extras(): + @strict_dataclass() + class MyClass: + a: int + b: str + + with pytest.raises(ValidationError): + MyClass(a=5, b="x", c=5) + + +def test_frozen(): + @strict_dataclass() + class MyClass: + a: int + b: str + + foo = MyClass(a=5, b="x") + + with pytest.raises(FrozenInstanceError): + foo.a = 6 + + +def test_positional_args_default(): + @strict_dataclass() + class MyClass: + a: int + b: str + + with pytest.raises(TypeError): + MyClass(5, "x") + + +def test_positional_args_override_init(): + @strict_dataclass(positional_args={"a"}) + class MyClass: + a: int + b: str + + foo = MyClass(5, b="x") + assert foo.a == 5 + assert foo.b == "x" + + with pytest.raises(TypeError): + MyClass(5, "x") + + +def test_type_validation(): + @strict_dataclass() + class MyClass: + a: int + b: str + + with pytest.raises(ValidationError): + MyClass(a=5, b=6) + + +def test_too_many_positional_arguments(): + @strict_dataclass(positional_args={"a"}) + class MyClass: + a: int + + with pytest.raises(TypeError): + MyClass(5, 6) + + +def test_inheritance(): + @strict_dataclass() + class MyClass: + a: int + + @strict_dataclass() + class MySubClass(MyClass): + b: str + + foo = MySubClass(a=5, b="x") + assert foo.a == 5 + assert foo.b == "x"