diff --git a/python_modules/dagster/dagster/_core/definitions/asset_selection.py b/python_modules/dagster/dagster/_core/definitions/asset_selection.py index a75cbba8a4803..5822e5db7daba 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_selection.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_selection.py @@ -2,8 +2,10 @@ import operator from abc import ABC, abstractmethod from functools import reduce -from typing import AbstractSet, Iterable, NamedTuple, Optional, Sequence, Union, cast +from typing import AbstractSet, Iterable, Optional, Sequence, Union, cast +import pydantic +from pydantic import BaseModel from typing_extensions import TypeAlias import dagster._check as check @@ -38,7 +40,7 @@ ] -class AssetSelection(ABC): +class AssetSelection(ABC, BaseModel, frozen=True): """An AssetSelection defines a query over a set of assets and asset checks, normally all that are defined in a code location. You can use the "|", "&", and "-" operators to create unions, intersections, and differences of selections, respectively. @@ -95,7 +97,9 @@ def all_asset_checks() -> "AllAssetCheckSelection": @staticmethod def assets(*assets_defs: AssetsDefinition) -> "KeysAssetSelection": """Returns a selection that includes all of the provided assets and asset checks that target them.""" - return KeysAssetSelection([key for assets_def in assets_defs for key in assets_def.keys]) + return KeysAssetSelection( + selected_keys=[key for assets_def in assets_defs for key in assets_def.keys] + ) @public @staticmethod @@ -120,7 +124,7 @@ def keys(*asset_keys: CoercibleToAssetKey) -> "KeysAssetSelection": AssetKey.from_user_string(key) if isinstance(key, str) else AssetKey.from_coercible(key) for key in asset_keys ] - return KeysAssetSelection(_asset_keys) + return KeysAssetSelection(selected_keys=_asset_keys) @public @staticmethod @@ -144,7 +148,9 @@ def key_prefixes( AssetSelection.key_prefixes(["a", "b"], ["a", "c"]) """ _asset_key_prefixes = [key_prefix_from_coercible(key_prefix) for key_prefix in key_prefixes] - return KeyPrefixesAssetSelection(_asset_key_prefixes, include_sources=include_sources) + return KeyPrefixesAssetSelection( + selected_key_prefixes=_asset_key_prefixes, include_sources=include_sources + ) @public @staticmethod @@ -157,14 +163,14 @@ def groups(*group_strs, include_sources: bool = False) -> "GroupsAssetSelection" selection. """ check.tuple_param(group_strs, "group_strs", of_type=str) - return GroupsAssetSelection(group_strs, include_sources=include_sources) + return GroupsAssetSelection(selected_groups=group_strs, include_sources=include_sources) @public @staticmethod def checks_for_assets(*assets_defs: AssetsDefinition) -> "AssetChecksForAssetKeysSelection": """Returns a selection with the asset checks that target the provided assets.""" return AssetChecksForAssetKeysSelection( - [key for assets_def in assets_defs for key in assets_def.keys] + selected_asset_keys=[key for assets_def in assets_defs for key in assets_def.keys] ) @public @@ -172,7 +178,7 @@ def checks_for_assets(*assets_defs: AssetsDefinition) -> "AssetChecksForAssetKey def checks(*asset_checks: AssetChecksDefinition) -> "AssetCheckKeysSelection": """Returns a selection that includes all of the provided asset checks.""" return AssetCheckKeysSelection( - [ + selected_asset_check_keys=[ AssetCheckKey(asset_key=AssetKey.from_coercible(spec.asset_key), name=spec.name) for checks_def in asset_checks for spec in checks_def.specs @@ -196,7 +202,7 @@ def downstream( """ check.opt_int_param(depth, "depth") check.opt_bool_param(include_self, "include_self") - return DownstreamAssetSelection(self, depth=depth, include_self=include_self) + return DownstreamAssetSelection(child=self, depth=depth, include_self=include_self) @public def upstream( @@ -219,7 +225,7 @@ def upstream( """ check.opt_int_param(depth, "depth") check.opt_bool_param(include_self, "include_self") - return UpstreamAssetSelection(self, depth=depth, include_self=include_self) + return UpstreamAssetSelection(child=self, depth=depth, include_self=include_self) @public def sinks(self) -> "SinksAssetSelection": @@ -229,7 +235,7 @@ def sinks(self) -> "SinksAssetSelection": A sink asset is an asset that has no downstream dependencies within the asset selection. The sink asset can have downstream dependencies outside of the asset selection. """ - return SinksAssetSelection(self) + return SinksAssetSelection(child=self) @public def required_multi_asset_neighbors(self) -> "RequiredNeighborsAssetSelection": @@ -237,7 +243,7 @@ def required_multi_asset_neighbors(self) -> "RequiredNeighborsAssetSelection": which cannot be subset, returns a new asset selection that contains all of the assets required to execute the original asset selection. Includes the asset checks targeting the returned assets. """ - return RequiredNeighborsAssetSelection(self) + return RequiredNeighborsAssetSelection(child=self) @public def roots(self) -> "RootsAssetSelection": @@ -251,7 +257,7 @@ def roots(self) -> "RootsAssetSelection": keys corresponding to `SourceAssets` will not be included as roots. To select source assets, use the `upstream_source_assets` method. """ - return RootsAssetSelection(self) + return RootsAssetSelection(child=self) @public @deprecated(breaking_version="2.0", additional_warn_text="Use AssetSelection.roots instead.") @@ -274,7 +280,7 @@ def upstream_source_assets(self) -> "ParentSourcesAssetSelection": assets that are parents of assets in the original selection. Includes the asset checks targeting the returned assets. """ - return ParentSourcesAssetSelection(self) + return ParentSourcesAssetSelection(child=self) @public def without_checks(self) -> "AssetSelection": @@ -291,7 +297,7 @@ def __or__(self, other: "AssetSelection") -> "OrAssetSelection": else: operands.append(selection) - return OrAssetSelection(operands) + return OrAssetSelection(operands=operands) def __and__(self, other: "AssetSelection") -> "AndAssetSelection": check.inst_param(other, "other", AssetSelection) @@ -303,7 +309,7 @@ def __and__(self, other: "AssetSelection") -> "AndAssetSelection": else: operands.append(selection) - return AndAssetSelection(operands) + return AndAssetSelection(operands=operands) def __bool__(self): # Ensure that even if a subclass is a NamedTuple with no fields, it is still truthy @@ -311,7 +317,7 @@ def __bool__(self): def __sub__(self, other: "AssetSelection") -> "SubtractAssetSelection": check.inst_param(other, "other", AssetSelection) - return SubtractAssetSelection(self, other) + return SubtractAssetSelection(left=self, right=other) def resolve( self, all_assets: Union[Iterable[Union[AssetsDefinition, SourceAsset]], AssetGraph] @@ -396,9 +402,16 @@ def from_coercible(cls, selection: CoercibleToAssetSelection) -> "AssetSelection def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": return AssetSelection.keys(*self.resolve(asset_graph)) + def replace(self, **kwargs): + if pydantic.__version__ >= "2": + func = getattr(BaseModel, "model_copy") + else: + func = getattr(BaseModel, "copy") + return func(self, update=kwargs) + @whitelist_for_serdes -class AllSelection(AssetSelection, NamedTuple("_AllSelection", [])): +class AllSelection(AssetSelection, frozen=True): def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return asset_graph.materializable_asset_keys @@ -410,7 +423,7 @@ def __str__(self) -> str: @whitelist_for_serdes -class AllAssetCheckSelection(AssetSelection, NamedTuple("_AllAssetChecksSelection", [])): +class AllAssetCheckSelection(AssetSelection, frozen=True): def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return set() @@ -425,10 +438,9 @@ def __str__(self) -> str: @whitelist_for_serdes -class AssetChecksForAssetKeysSelection( - NamedTuple("_AssetChecksForAssetKeysSelection", [("selected_asset_keys", Sequence[AssetKey])]), - AssetSelection, -): +class AssetChecksForAssetKeysSelection(AssetSelection, frozen=True): + selected_asset_keys: Sequence[AssetKey] + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return set() @@ -444,12 +456,9 @@ def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSele @whitelist_for_serdes -class AssetCheckKeysSelection( - NamedTuple( - "_AssetCheckKeysSelection", [("selected_asset_check_keys", Sequence[AssetCheckKey])] - ), - AssetSelection, -): +class AssetCheckKeysSelection(AssetSelection, frozen=True): + selected_asset_check_keys: Sequence[AssetCheckKey] + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return set() @@ -467,8 +476,11 @@ def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSele @whitelist_for_serdes class AndAssetSelection( AssetSelection, - NamedTuple("_AndAssetSelection", [("operands", Sequence[AssetSelection])]), + frozen=True, + arbitrary_types_allowed=True, ): + operands: Sequence[AssetSelection] + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return reduce( operator.and_, (selection.resolve_inner(asset_graph) for selection in self.operands) @@ -481,7 +493,7 @@ def resolve_checks_inner(self, asset_graph: InternalAssetGraph) -> AbstractSet[A ) def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": - return self._replace( + return self.replace( operands=[ operand.to_serializable_asset_selection(asset_graph) for operand in self.operands ] @@ -491,8 +503,11 @@ def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSele @whitelist_for_serdes class OrAssetSelection( AssetSelection, - NamedTuple("_OrAssetSelection", [("operands", Sequence[AssetSelection])]), + frozen=True, + arbitrary_types_allowed=True, ): + operands: Sequence[AssetSelection] + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return reduce( operator.or_, (selection.resolve_inner(asset_graph) for selection in self.operands) @@ -505,7 +520,7 @@ def resolve_checks_inner(self, asset_graph: InternalAssetGraph) -> AbstractSet[A ) def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": - return self._replace( + return self.replace( operands=[ operand.to_serializable_asset_selection(asset_graph) for operand in self.operands ] @@ -515,8 +530,12 @@ def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSele @whitelist_for_serdes class SubtractAssetSelection( AssetSelection, - NamedTuple("_SubtractAssetSelection", [("left", AssetSelection), ("right", AssetSelection)]), + frozen=True, + arbitrary_types_allowed=True, ): + left: AssetSelection + right: AssetSelection + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return self.left.resolve_inner(asset_graph) - self.right.resolve_inner(asset_graph) @@ -526,7 +545,7 @@ def resolve_checks_inner(self, asset_graph: InternalAssetGraph) -> AbstractSet[A ) def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": - return self._replace( + return self.replace( left=self.left.to_serializable_asset_selection(asset_graph), right=self.right.to_serializable_asset_selection(asset_graph), ) @@ -535,21 +554,27 @@ def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSele @whitelist_for_serdes class SinksAssetSelection( AssetSelection, - NamedTuple("_SinksAssetSelection", [("child", AssetSelection)]), + frozen=True, + arbitrary_types_allowed=True, ): + child: AssetSelection + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: selection = self.child.resolve_inner(asset_graph) return fetch_sinks(asset_graph.asset_dep_graph, selection) def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": - return self._replace(child=self.child.to_serializable_asset_selection(asset_graph)) + return self.replace(child=self.child.to_serializable_asset_selection(asset_graph)) @whitelist_for_serdes class RequiredNeighborsAssetSelection( AssetSelection, - NamedTuple("_RequiredNeighborsAssetSelection", [("child", AssetSelection)]), + frozen=True, + arbitrary_types_allowed=True, ): + child: AssetSelection + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: selection = self.child.resolve_inner(asset_graph) output = set(selection) @@ -558,34 +583,35 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return output def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": - return self._replace(child=self.child.to_serializable_asset_selection(asset_graph)) + return self.replace(child=self.child.to_serializable_asset_selection(asset_graph)) @whitelist_for_serdes class RootsAssetSelection( AssetSelection, - NamedTuple("_RootsAssetSelection", [("child", AssetSelection)]), + frozen=True, + arbitrary_types_allowed=True, ): + child: AssetSelection + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: selection = self.child.resolve_inner(asset_graph) return fetch_sources(asset_graph.asset_dep_graph, selection) def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": - return self._replace(child=self.child.to_serializable_asset_selection(asset_graph)) + return self.replace(child=self.child.to_serializable_asset_selection(asset_graph)) @whitelist_for_serdes class DownstreamAssetSelection( AssetSelection, - NamedTuple( - "_DownstreamAssetSelection", - [ - ("child", AssetSelection), - ("depth", Optional[int]), - ("include_self", bool), - ], - ), + frozen=True, + arbitrary_types_allowed=True, ): + child: AssetSelection + depth: Optional[int] + include_self: bool + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: selection = self.child.resolve_inner(asset_graph) return operator.sub( @@ -606,20 +632,14 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: ) def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": - return self._replace(child=self.child.to_serializable_asset_selection(asset_graph)) + return self.replace(child=self.child.to_serializable_asset_selection(asset_graph)) @whitelist_for_serdes -class GroupsAssetSelection( - NamedTuple( - "_GroupsAssetSelection", - [ - ("selected_groups", Sequence[str]), - ("include_sources", bool), - ], - ), - AssetSelection, -): +class GroupsAssetSelection(AssetSelection, frozen=True): + selected_groups: Sequence[str] + include_sources: bool + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: base_set = ( asset_graph.all_asset_keys @@ -643,10 +663,9 @@ def __str__(self) -> str: @whitelist_for_serdes -class KeysAssetSelection( - NamedTuple("_KeysAssetSelection", [("selected_keys", Sequence[AssetKey])]), - AssetSelection, -): +class KeysAssetSelection(AssetSelection, frozen=True): + selected_keys: Sequence[AssetKey] + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: specified_keys = set(self.selected_keys) invalid_keys = {key for key in specified_keys if key not in asset_graph.all_asset_keys} @@ -666,13 +685,10 @@ def __str__(self) -> str: @whitelist_for_serdes -class KeyPrefixesAssetSelection( - NamedTuple( - "_KeyPrefixesAssetSelection", - [("selected_key_prefixes", Sequence[Sequence[str]]), ("include_sources", bool)], - ), - AssetSelection, -): +class KeyPrefixesAssetSelection(AssetSelection, frozen=True): + selected_key_prefixes: Sequence[Sequence[str]] + include_sources: bool + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: base_set = ( asset_graph.all_asset_keys @@ -724,15 +740,13 @@ def _fetch_all_upstream( @whitelist_for_serdes class UpstreamAssetSelection( AssetSelection, - NamedTuple( - "_UpstreamAssetSelection", - [ - ("child", AssetSelection), - ("depth", Optional[int]), - ("include_self", bool), - ], - ), + frozen=True, + arbitrary_types_allowed=True, ): + child: AssetSelection + depth: Optional[int] + include_self: bool + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: selection = self.child.resolve_inner(asset_graph) if len(selection) == 0: @@ -741,14 +755,17 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return {key for key in all_upstream if key not in asset_graph.source_asset_keys} def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": - return self._replace(child=self.child.to_serializable_asset_selection(asset_graph)) + return self.replace(child=self.child.to_serializable_asset_selection(asset_graph)) @whitelist_for_serdes class ParentSourcesAssetSelection( AssetSelection, - NamedTuple("_ParentSourcesAssetSelection", [("child", AssetSelection)]), + frozen=True, + arbitrary_types_allowed=True, ): + child: AssetSelection + def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: selection = self.child.resolve_inner(asset_graph) if len(selection) == 0: @@ -757,4 +774,4 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return {key for key in all_upstream if key in asset_graph.source_asset_keys} def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection": - return self._replace(child=self.child.to_serializable_asset_selection(asset_graph)) + return self.replace(child=self.child.to_serializable_asset_selection(asset_graph)) diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_selection.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_selection.py index 92da96400c903..b20794ec38ef6 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_selection.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_selection.py @@ -19,11 +19,28 @@ multi_asset, ) from dagster._core.definitions import AssetSelection, asset +from dagster._core.definitions.asset_check_spec import AssetCheckKey from dagster._core.definitions.asset_graph import AssetGraph -from dagster._core.definitions.asset_selection import AndAssetSelection, OrAssetSelection +from dagster._core.definitions.asset_selection import ( + AndAssetSelection, + AssetCheckKeysSelection, + AssetChecksForAssetKeysSelection, + DownstreamAssetSelection, + GroupsAssetSelection, + KeyPrefixesAssetSelection, + KeysAssetSelection, + OrAssetSelection, + ParentSourcesAssetSelection, + RequiredNeighborsAssetSelection, + RootsAssetSelection, + SinksAssetSelection, + SubtractAssetSelection, + UpstreamAssetSelection, +) from dagster._core.definitions.assets import AssetsDefinition from dagster._core.definitions.events import AssetKey from dagster._serdes.serdes import _WHITELIST_MAP +from pydantic import ValidationError from typing_extensions import TypeAlias earth = SourceAsset(["celestial", "earth"], group_name="planets") @@ -386,17 +403,110 @@ def test_multi_operand_selection(): bar = AssetSelection.keys("bar") baz = AssetSelection.keys("baz") - assert foo & bar & baz == AndAssetSelection([foo, bar, baz]) - assert (foo & bar) & baz == AndAssetSelection([foo, bar, baz]) - assert foo & (bar & baz) == AndAssetSelection([foo, bar, baz]) - assert foo | bar | baz == OrAssetSelection([foo, bar, baz]) - assert (foo | bar) | baz == OrAssetSelection([foo, bar, baz]) - assert foo | (bar | baz) == OrAssetSelection([foo, bar, baz]) + assert foo & bar & baz == AndAssetSelection(operands=[foo, bar, baz]) + assert (foo & bar) & baz == AndAssetSelection(operands=[foo, bar, baz]) + assert foo & (bar & baz) == AndAssetSelection(operands=[foo, bar, baz]) + assert foo | bar | baz == OrAssetSelection(operands=[foo, bar, baz]) + assert (foo | bar) | baz == OrAssetSelection(operands=[foo, bar, baz]) + assert foo | (bar | baz) == OrAssetSelection(operands=[foo, bar, baz]) + + assert (foo & bar) | baz == OrAssetSelection( + operands=[AndAssetSelection(operands=[foo, bar]), baz] + ) + assert foo & (bar | baz) == AndAssetSelection( + operands=[foo, OrAssetSelection(operands=[bar, baz])] + ) + assert (foo | bar) & baz == AndAssetSelection( + operands=[OrAssetSelection(operands=[foo, bar]), baz] + ) + assert foo | (bar & baz) == OrAssetSelection( + operands=[foo, AndAssetSelection(operands=[bar, baz])] + ) + + +def test_asset_selection_type_checking(): + valid_asset_selection = AssetSelection.keys("foo") + valid_asset_selection_sequence = [valid_asset_selection] + valid_asset_key = AssetKey("bar") + valid_asset_key_sequence = [valid_asset_key] + valid_string_sequence = ["string"] + valid_string_sequence_sequence = [valid_string_sequence] + valid_asset_check_key = AssetCheckKey(asset_key=valid_asset_key, name="test_name") + valid_asset_check_key_sequence = [valid_asset_check_key] + + invalid_argument = "invalid_argument" + + with pytest.raises(ValidationError): + AssetChecksForAssetKeysSelection(selected_asset_keys=invalid_argument) + test = AssetChecksForAssetKeysSelection(selected_asset_keys=valid_asset_key_sequence) + assert isinstance(test, AssetChecksForAssetKeysSelection) + + with pytest.raises(ValidationError): + AssetCheckKeysSelection(selected_asset_check_keys=invalid_argument) + test = AssetCheckKeysSelection(selected_asset_check_keys=valid_asset_check_key_sequence) + assert isinstance(test, AssetCheckKeysSelection) + + with pytest.raises(ValidationError): + AndAssetSelection(operands=invalid_argument) + test = AndAssetSelection(operands=valid_asset_selection_sequence) + assert isinstance(test, AndAssetSelection) + + with pytest.raises(ValidationError): + OrAssetSelection(operands=invalid_argument) + test = OrAssetSelection(operands=valid_asset_selection_sequence) + assert isinstance(test, OrAssetSelection) + + with pytest.raises(ValidationError): + SubtractAssetSelection(left=invalid_argument, right=invalid_argument) + test = SubtractAssetSelection(left=valid_asset_selection, right=valid_asset_selection) + assert isinstance(test, SubtractAssetSelection) + + with pytest.raises(ValidationError): + SinksAssetSelection(child=invalid_argument) + test = SinksAssetSelection(child=valid_asset_selection) + assert isinstance(test, SinksAssetSelection) + + with pytest.raises(ValidationError): + RequiredNeighborsAssetSelection(child=invalid_argument) + test = RequiredNeighborsAssetSelection(child=valid_asset_selection) + assert isinstance(test, RequiredNeighborsAssetSelection) + + with pytest.raises(ValidationError): + RootsAssetSelection(child=invalid_argument) + test = RootsAssetSelection(child=valid_asset_selection) + assert isinstance(test, RootsAssetSelection) + + with pytest.raises(ValidationError): + DownstreamAssetSelection(child=invalid_argument, depth=0, include_self=False) + test = DownstreamAssetSelection(child=valid_asset_selection, depth=0, include_self=False) + assert isinstance(test, DownstreamAssetSelection) + + with pytest.raises(ValidationError): + GroupsAssetSelection(selected_groups=invalid_argument, include_sources=False) + test = GroupsAssetSelection(selected_groups=valid_string_sequence, include_sources=False) + assert isinstance(test, GroupsAssetSelection) + + with pytest.raises(ValidationError): + KeysAssetSelection(selected_keys=invalid_argument) + test = KeysAssetSelection(selected_keys=valid_asset_key_sequence) + assert isinstance(test, KeysAssetSelection) + + with pytest.raises(ValidationError): + KeyPrefixesAssetSelection(selected_key_prefixes=invalid_argument, include_sources=False) + test = KeyPrefixesAssetSelection( + selected_key_prefixes=valid_string_sequence_sequence, include_sources=False + ) + assert isinstance(test, KeyPrefixesAssetSelection) + + with pytest.raises(ValidationError): + UpstreamAssetSelection(child=invalid_argument, depth=0, include_self=False) + test = UpstreamAssetSelection(child=valid_asset_selection, depth=0, include_self=False) + assert isinstance(test, UpstreamAssetSelection) - assert (foo & bar) | baz == OrAssetSelection([AndAssetSelection([foo, bar]), baz]) - assert foo & (bar | baz) == AndAssetSelection([foo, OrAssetSelection([bar, baz])]) - assert (foo | bar) & baz == AndAssetSelection([OrAssetSelection([foo, bar]), baz]) - assert foo | (bar & baz) == OrAssetSelection([foo, AndAssetSelection([bar, baz])]) + with pytest.raises(ValidationError): + ParentSourcesAssetSelection(child=invalid_argument) + test = ParentSourcesAssetSelection(child=valid_asset_selection) + assert isinstance(test, ParentSourcesAssetSelection) def test_all_asset_selection_subclasses_serializable(): @@ -416,7 +526,7 @@ def test_all_asset_selection_subclasses_serializable(): def test_to_serializable_asset_selection(): - class UnserializableAssetSelection(AssetSelection): + class UnserializableAssetSelection(AssetSelection, frozen=True): def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return asset_graph.materializable_asset_keys - {AssetKey("asset2")} diff --git a/python_modules/dagster/dagster_tests/core_tests/host_representation_tests/test_external_sensor_data.py b/python_modules/dagster/dagster_tests/core_tests/host_representation_tests/test_external_sensor_data.py index 7e5db3f06a685..fa85a672093be 100644 --- a/python_modules/dagster/dagster_tests/core_tests/host_representation_tests/test_external_sensor_data.py +++ b/python_modules/dagster/dagster_tests/core_tests/host_representation_tests/test_external_sensor_data.py @@ -27,7 +27,7 @@ def asset1(): def asset2(): ... - class MySpecialAssetSelection(AssetSelection): + class MySpecialAssetSelection(AssetSelection, frozen=True): def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: return asset_graph.materializable_asset_keys - {AssetKey("asset2")} diff --git a/python_modules/dagster/dagster_tests/definitions_tests/decorators_tests/test_asset_decorator_with_check_specs.py b/python_modules/dagster/dagster_tests/definitions_tests/decorators_tests/test_asset_decorator_with_check_specs.py index ccde4f52f98ba..4ec8f8f7ed6ac 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/decorators_tests/test_asset_decorator_with_check_specs.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/decorators_tests/test_asset_decorator_with_check_specs.py @@ -675,7 +675,7 @@ def foo(context: AssetExecutionContext): result = materialize( [foo], selection=AssetCheckKeysSelection( - [AssetCheckKey(asset_key=AssetKey("asset1"), name="check1")] + selected_asset_check_keys=[AssetCheckKey(asset_key=AssetKey("asset1"), name="check1")] ), ) diff --git a/python_modules/libraries/dagster-dbt/dagster_dbt/asset_utils.py b/python_modules/libraries/dagster-dbt/dagster_dbt/asset_utils.py index cf2e6a2d9b8fa..79c9d694036f7 100644 --- a/python_modules/libraries/dagster-dbt/dagster_dbt/asset_utils.py +++ b/python_modules/libraries/dagster-dbt/dagster_dbt/asset_utils.py @@ -227,7 +227,7 @@ def all_dbt_assets(): manifest, dagster_dbt_translator = get_manifest_and_translator_from_dbt_assets(dbt_assets) from .dbt_manifest_asset_selection import DbtManifestAssetSelection - return DbtManifestAssetSelection( + return DbtManifestAssetSelection.build( manifest=manifest, dagster_dbt_translator=dagster_dbt_translator, select=dbt_select, diff --git a/python_modules/libraries/dagster-dbt/dagster_dbt/dbt_manifest_asset_selection.py b/python_modules/libraries/dagster-dbt/dagster_dbt/dbt_manifest_asset_selection.py index 2a57e200affd5..b631182770c52 100644 --- a/python_modules/libraries/dagster-dbt/dagster_dbt/dbt_manifest_asset_selection.py +++ b/python_modules/libraries/dagster-dbt/dagster_dbt/dbt_manifest_asset_selection.py @@ -1,4 +1,4 @@ -from typing import AbstractSet, Optional +from typing import AbstractSet, Any, Mapping, Optional from dagster import ( AssetKey, @@ -18,7 +18,11 @@ ) -class DbtManifestAssetSelection(AssetSelection): +class DbtManifestAssetSelection( + AssetSelection, + frozen=True, + arbitrary_types_allowed=True, +): """Defines a selection of assets from a dbt manifest wrapper and a dbt selection string. Args: @@ -40,22 +44,30 @@ class DbtManifestAssetSelection(AssetSelection): my_selection = DbtManifestAssetSelection(manifest=manifest, select="tag:foo") """ - def __init__( - self, + manifest: Mapping[str, Any] + select: str + dagster_dbt_translator: DagsterDbtTranslator + exclude: str + + @classmethod + def build( + cls, manifest: DbtManifestParam, select: str = "fqn:*", *, dagster_dbt_translator: Optional[DagsterDbtTranslator] = None, exclude: Optional[str] = None, - ) -> None: - self.manifest = validate_manifest(manifest) - self.select = check.str_param(select, "select") - self.exclude = check.opt_str_param(exclude, "exclude", default="") - self.dagster_dbt_translator = check.opt_inst_param( - dagster_dbt_translator, - "dagster_dbt_translator", - DagsterDbtTranslator, - DagsterDbtTranslator(), + ): + return cls( + manifest=validate_manifest(manifest), + select=check.str_param(select, "select"), + dagster_dbt_translator=check.opt_inst_param( + dagster_dbt_translator, + "dagster_dbt_translator", + DagsterDbtTranslator, + DagsterDbtTranslator(), + ), + exclude=check.opt_str_param(exclude, "exclude", default=""), ) def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]: