diff --git a/python_modules/dagster/dagster/_core/definitions/asset_key.py b/python_modules/dagster/dagster/_core/definitions/asset_key.py index 0ba4c259d0c17..252caed433f33 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_key.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_key.py @@ -1,9 +1,10 @@ import re -from typing import TYPE_CHECKING, Mapping, NamedTuple, Optional, Sequence, Union +from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union import dagster._check as check import dagster._seven as seven from dagster._annotations import PublicAttr +from dagster._core.utils import StrictModel from dagster._serdes import whitelist_for_serdes ASSET_KEY_SPLIT_REGEX = re.compile("[^a-zA-Z0-9_]") @@ -19,7 +20,7 @@ def parse_asset_key_string(s: str) -> Sequence[str]: @whitelist_for_serdes -class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])): +class AssetKey(StrictModel): """Object representing the structure of an asset key. Takes in a sanitized string, list of strings, or tuple of strings. @@ -39,13 +40,15 @@ class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])): strings represent the hierarchical structure of the asset_key. """ - def __new__(cls, path: Union[str, Sequence[str]]): + path: PublicAttr[Sequence[str]] + + def __init__(self, path: Union[str, Sequence[str]]): if isinstance(path, str): path = [path] else: path = list(check.sequence_param(path, "path", of_type=str)) - return super(AssetKey, cls).__new__(cls, path=path) + super().__init__(path=path) def __str__(self): return f"AssetKey({self.path})" @@ -66,6 +69,9 @@ def __eq__(self, other): return False return True + def __lt__(self, other): + return self.path < other.path + def to_string(self) -> str: """E.g. '["first_component", "second_component"]'.""" return seven.json.dumps(self.path) diff --git a/python_modules/dagster/dagster/_core/utils.py b/python_modules/dagster/dagster/_core/utils.py index ec901add5a75b..f7a4709be1404 100644 --- a/python_modules/dagster/dagster/_core/utils.py +++ b/python_modules/dagster/dagster/_core/utils.py @@ -23,6 +23,7 @@ from weakref import WeakSet import toposort as toposort_ +from pydantic import BaseModel from typing_extensions import Final import dagster._check as check @@ -196,3 +197,12 @@ def submit(self, fn, *args, **kwargs): def is_valid_email(email: str) -> bool: regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,7}\b" return bool(re.fullmatch(regex, email)) + + +class StrictModel(BaseModel): + def __init__(self, **data: Any) -> None: + super().__init__(**data) + + class Config: + extra = "forbid" + frozen = True diff --git a/python_modules/dagster/dagster_tests/core_tests/test_strict_model.py b/python_modules/dagster/dagster_tests/core_tests/test_strict_model.py new file mode 100644 index 0000000000000..2cd12bbea17ca --- /dev/null +++ b/python_modules/dagster/dagster_tests/core_tests/test_strict_model.py @@ -0,0 +1,37 @@ +import pytest +from dagster._core.utils import StrictModel +from pydantic import ValidationError + + +def test_override_constructor_in_subclass(): + class MyClass(StrictModel): + foo: str + bar: int + + def __init__(self, foo: str, bar: int): + super().__init__(foo=foo, bar=bar) + + MyClass(foo="fdsjk", bar=4) + + +def test_override_constructor_in_subclass_different_arg_names(): + class MyClass(StrictModel): + foo: str + bar: int + + def __init__(self, fooarg: str, bararg: int): + super().__init__(foo=fooarg, bar=bararg) + + MyClass(fooarg="fdsjk", bararg=4) + + +def test_override_constructor_in_subclass_wrong_type(): + class MyClass(StrictModel): + foo: str + bar: int + + def __init__(self, foo: str, bar: str): + super().__init__(foo=foo, bar=bar) + + with pytest.raises(ValidationError): + MyClass(foo="fdsjk", bar="fdslk") diff --git a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py index 99c026605c0b4..ee95e8ceb18e3 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py +++ b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py @@ -8,6 +8,7 @@ import pydantic import pytest from dagster._check import ParameterCheckError, inst_param, set_param +from dagster._core.utils import StrictModel from dagster._serdes.errors import DeserializationError, SerdesUsageError, SerializationError from dagster._serdes.serdes import ( EnumSerializer, @@ -822,6 +823,11 @@ class SomeModel(pydantic.BaseModel): id: int name: str + @_whitelist_for_serdes(test_env) + class SomeStrictModel(StrictModel): + id: int + name: str + @_whitelist_for_serdes(test_env) @pydantic.dataclasses.dataclass class DataclassObj: @@ -830,8 +836,16 @@ class DataclassObj: d: InnerDataclass nt: SomeNT m: SomeModel - - o = DataclassObj("woo", 4, InnerDataclass(1.2), SomeNT([1, 2, 3]), SomeModel(id=4, name="zuck")) + st: SomeStrictModel + + o = DataclassObj( + "woo", + 4, + InnerDataclass(1.2), + SomeNT([1, 2, 3]), + SomeModel(id=4, name="zuck"), + SomeStrictModel(id=4, name="zuck"), + ) ser_o = serialize_value(o, whitelist_map=test_env) assert deserialize_value(ser_o, whitelist_map=test_env) == o diff --git a/python_modules/libraries/dagster-wandb/dagster_wandb/io_manager.py b/python_modules/libraries/dagster-wandb/dagster_wandb/io_manager.py index 90fdb98c51697..9477944c808ad 100644 --- a/python_modules/libraries/dagster-wandb/dagster_wandb/io_manager.py +++ b/python_modules/libraries/dagster-wandb/dagster_wandb/io_manager.py @@ -393,7 +393,7 @@ def _download_artifact(self, context: InputContext): artifact_name = parameters.get("name") if artifact_name is None: - artifact_name = context.asset_key[0][0] # name of asset + artifact_name = context.asset_key.path[0] # name of asset partitions = [ (key, f"{artifact_name}.{ str(key).replace('|', '-')}")