Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert AssetKey to StrictModel #20643

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions python_modules/dagster/dagster/_core/definitions/asset_key.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import re
from typing import TYPE_CHECKING, Mapping, NamedTuple, Optional, Sequence, Union
from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union

import dagster._check as check
import dagster._seven as seven
from dagster._annotations import PublicAttr
from dagster._core.utils import StrictModel
from dagster._serdes import whitelist_for_serdes

ASSET_KEY_SPLIT_REGEX = re.compile("[^a-zA-Z0-9_]")
Expand All @@ -19,7 +20,7 @@ def parse_asset_key_string(s: str) -> Sequence[str]:


@whitelist_for_serdes
class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])):
class AssetKey(StrictModel):
"""Object representing the structure of an asset key. Takes in a sanitized string, list of
strings, or tuple of strings.

Expand All @@ -39,13 +40,15 @@ class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])):
strings represent the hierarchical structure of the asset_key.
"""

def __new__(cls, path: Union[str, Sequence[str]]):
path: PublicAttr[Sequence[str]]

def __init__(self, path: Union[str, Sequence[str]]):
if isinstance(path, str):
path = [path]
else:
path = list(check.sequence_param(path, "path", of_type=str))

return super(AssetKey, cls).__new__(cls, path=path)
super().__init__(path=path)

def __str__(self):
return f"AssetKey({self.path})"
Expand All @@ -66,6 +69,9 @@ def __eq__(self, other):
return False
return True

def __lt__(self, other):
return self.path < other.path

def to_string(self) -> str:
"""E.g. '["first_component", "second_component"]'."""
return seven.json.dumps(self.path)
Expand Down
10 changes: 10 additions & 0 deletions python_modules/dagster/dagster/_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from weakref import WeakSet

import toposort as toposort_
from pydantic import BaseModel
from typing_extensions import Final

import dagster._check as check
Expand Down Expand Up @@ -196,3 +197,12 @@ def submit(self, fn, *args, **kwargs):
def is_valid_email(email: str) -> bool:
    """Return True when the whole of ``email`` looks like an email address.

    The check is purely syntactic: a local part made of letters, digits and
    ``._%+-``, an ``@``, a domain made of letters, digits, dots and hyphens,
    and a final alphabetic label of 2 to 7 characters. No DNS or deliverability
    check is performed.

    Args:
        email: The candidate address.

    Returns:
        True if the entire string matches the pattern, False otherwise.
    """
    # Inside a character class "|" is a literal pipe, not alternation, so the
    # ranges for the final label are written back to back as [A-Za-z];
    # otherwise strings such as "user@example.c|m" would be accepted.
    regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,7}\b"
    return bool(re.fullmatch(regex, email))


# Pydantic base model configured to reject undeclared fields and to be immutable.
# NOTE: documented with comments rather than a docstring on purpose, since pydantic
# uses a model's docstring as its schema description.
class StrictModel(BaseModel):
    class Config:
        # Instances cannot have their fields reassigned after construction.
        frozen = True
        # Keyword arguments that do not match a declared field are a validation
        # error instead of being silently ignored.
        extra = "forbid"

    def __init__(self, **data: Any) -> None:
        # Plain pass-through to BaseModel.__init__.
        # NOTE(review): presumably declared explicitly so that subclasses can
        # override __init__ with their own signatures -- confirm the intent.
        super().__init__(**data)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
from dagster._core.utils import StrictModel
from pydantic import ValidationError


def test_override_constructor_in_subclass():
    """A subclass may define an explicit ``__init__`` whose parameters mirror its fields."""

    class MyClass(StrictModel):
        foo: str
        bar: int

        def __init__(self, foo: str, bar: int):
            super().__init__(foo=foo, bar=bar)

    obj = MyClass(foo="fdsjk", bar=4)

    # Constructing without an exception is not enough: verify that the values
    # passed to the custom constructor actually ended up on the fields.
    assert obj.foo == "fdsjk"
    assert obj.bar == 4


def test_override_constructor_in_subclass_different_arg_names():
    """A subclass ``__init__`` may use parameter names that differ from the field names."""

    class MyClass(StrictModel):
        foo: str
        bar: int

        def __init__(self, fooarg: str, bararg: int):
            super().__init__(foo=fooarg, bar=bararg)

    obj = MyClass(fooarg="fdsjk", bararg=4)

    # The renamed constructor arguments must be mapped onto the declared fields.
    assert obj.foo == "fdsjk"
    assert obj.bar == 4


def test_override_constructor_in_subclass_wrong_type():
    """A wrongly typed value passed through a custom ``__init__`` still fails validation."""

    class MistypedBarModel(StrictModel):
        foo: str
        bar: int

        # The constructor deliberately annotates ``bar`` as str although the
        # field itself is declared as int.
        def __init__(self, foo: str, bar: str):
            super().__init__(foo=foo, bar=bar)

    # "fdslk" cannot be interpreted as an int, so construction must be rejected.
    bad_kwargs = {"foo": "fdsjk", "bar": "fdslk"}
    with pytest.raises(ValidationError):
        MistypedBarModel(**bad_kwargs)
18 changes: 16 additions & 2 deletions python_modules/dagster/dagster_tests/general_tests/test_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pydantic
import pytest
from dagster._check import ParameterCheckError, inst_param, set_param
from dagster._core.utils import StrictModel
from dagster._serdes.errors import DeserializationError, SerdesUsageError, SerializationError
from dagster._serdes.serdes import (
EnumSerializer,
Expand Down Expand Up @@ -822,6 +823,11 @@ class SomeModel(pydantic.BaseModel):
id: int
name: str

@_whitelist_for_serdes(test_env)
class SomeStrictModel(StrictModel):
id: int
name: str

@_whitelist_for_serdes(test_env)
@pydantic.dataclasses.dataclass
class DataclassObj:
Expand All @@ -830,8 +836,16 @@ class DataclassObj:
d: InnerDataclass
nt: SomeNT
m: SomeModel

o = DataclassObj("woo", 4, InnerDataclass(1.2), SomeNT([1, 2, 3]), SomeModel(id=4, name="zuck"))
st: SomeStrictModel

o = DataclassObj(
"woo",
4,
InnerDataclass(1.2),
SomeNT([1, 2, 3]),
SomeModel(id=4, name="zuck"),
SomeStrictModel(id=4, name="zuck"),
)
ser_o = serialize_value(o, whitelist_map=test_env)
assert deserialize_value(ser_o, whitelist_map=test_env) == o

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def _download_artifact(self, context: InputContext):

artifact_name = parameters.get("name")
if artifact_name is None:
artifact_name = context.asset_key[0][0] # name of asset
artifact_name = context.asset_key.path[0] # name of asset

partitions = [
(key, f"{artifact_name}.{ str(key).replace('|', '-')}")
Expand Down