Skip to content

Commit

Permalink
tag asset selection
Browse files Browse the repository at this point in the history
branch-name: asset-selection-tags
  • Loading branch information
sryza committed Mar 18, 2024
1 parent cf3ba00 commit 35cdf77
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing_extensions import TypeAlias

import dagster._check as check
from dagster._annotations import deprecated, experimental_param, public
from dagster._annotations import deprecated, experimental, experimental_param, public
from dagster._core.definitions.asset_graph import AssetGraph
from dagster._core.definitions.resolved_asset_deps import resolve_similar_asset_names
from dagster._core.errors import DagsterInvalidSubsetError
Expand Down Expand Up @@ -171,6 +171,33 @@ def groups(*group_strs, include_sources: bool = False) -> "GroupsAssetSelection"
check.tuple_param(group_strs, "group_strs", of_type=str)
return GroupsAssetSelection(selected_groups=group_strs, include_sources=include_sources)

@public
@staticmethod
@experimental
def tag(key: str, value: str, include_sources: bool = False) -> "AssetSelection":
    """Returns a selection that includes materializable assets that have the provided tag, and
    all the asset checks that target them.

    Args:
        key (str): The tag key to look up on each asset.
        value (str): The value the tag must be equal to for the asset to be selected.
        include_sources (bool): If True, then include source assets matching the tag in the
            selection.
    """
    # This is a staticmethod, so no implicit first argument is passed: the signature must
    # start at `key`, otherwise `AssetSelection.tag("k", "v")` binds "k" to the wrong parameter.
    return TagAssetSelection(key=key, value=value, include_sources=include_sources)

@classmethod
def tag_string(cls, string: str, include_sources: bool = False) -> "AssetSelection":
    """Returns a selection that includes materializable assets that have the tag described by
    the provided ``key=value`` string, and all the asset checks that target them.

    Args:
        string (str): The tag to select on, written as ``key=value``.
        include_sources (bool): If True, then include source assets matching the tag in the
            selection.

    Raises:
        ValueError: If ``string`` does not contain exactly one ``=``.
    """
    # Validate up front so that a malformed string produces an actionable message rather
    # than a bare tuple-unpacking error.
    if string.count("=") != 1:
        raise ValueError(
            f"Invalid tag selection string: {string!r}. Expected the form 'key=value' "
            "with exactly one equals sign."
        )

    key, value = string.split("=")
    return TagAssetSelection(key=key, value=value, include_sources=include_sources)

@public
@staticmethod
def checks_for_assets(*assets_defs: AssetsDefinition) -> "AssetChecksForAssetKeysSelection":
Expand Down Expand Up @@ -389,6 +416,10 @@ def from_string(cls, string: str) -> "AssetSelection":
selection = key_selection
return selection

elif string.startswith("tag:"):
tag_str = string[len("tag:") :]
return cls.tag_string(tag_str)

check.failed(f"Invalid selection string: {string}")

@classmethod
Expand Down Expand Up @@ -779,6 +810,27 @@ def __str__(self) -> str:
return f"group:({' or '.join(self.selected_groups)})"


@whitelist_for_serdes
class TagAssetSelection(AssetSelection, frozen=True):
    """Selection of the assets whose tag ``key`` is set to exactly ``value``."""

    # Tag key looked up on each candidate asset.
    key: str
    # Value the tag must equal for the asset to be selected.
    value: str
    # When True, candidates are drawn from all asset keys rather than only materializable ones.
    include_sources: bool

    def resolve_inner(
        self, asset_graph: BaseAssetGraph, allow_missing: bool
    ) -> AbstractSet[AssetKey]:
        """Return the keys in ``asset_graph`` whose tags map ``self.key`` to ``self.value``."""
        if self.include_sources:
            candidates = asset_graph.all_asset_keys
        else:
            candidates = asset_graph.materializable_asset_keys

        matching = set()
        for asset_key in candidates:
            # An asset without the tag yields None from .get(), which never equals the value.
            if asset_graph.get(asset_key).tags.get(self.key) == self.value:
                matching.add(asset_key)
        return matching

    def __str__(self) -> str:
        """Render the selection in ``tag:key=value`` form."""
        return f"tag:{self.key}={self.value}"


@whitelist_for_serdes
class KeysAssetSelection(AssetSelection, frozen=True):
selected_keys: Sequence[AssetKey]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dagster import (
AssetIn,
AssetOut,
AssetSpec,
DailyPartitionsDefinition,
DimensionPartitionMapping,
IdentityPartitionMapping,
Expand Down Expand Up @@ -706,3 +707,39 @@ def test_deserialize_old_all_asset_selection():
old_serialized_value = '{"__class__": "AllSelection"}'
new_unserialized_value = deserialize_value(old_serialized_value, AllSelection)
assert not new_unserialized_value.include_sources


def test_from_string_tag():
    """The ``tag:key=value`` string form parses to the same selection as ``AssetSelection.tag``."""
    parsed = AssetSelection.from_string("tag:foo=bar")
    expected = AssetSelection.tag("foo", "bar")
    assert parsed == expected


def test_tag():
    """Only assets whose ``foo`` tag equals ``fooval`` are selected by ``AssetSelection.tag``."""
    tagged_specs = [
        AssetSpec("asset1", tags={"foo": "fooval"}),
        AssetSpec("asset2", tags={"foo": "fooval2"}),
        AssetSpec("asset3", tags={"foo": "fooval", "bar": "barval"}),
        AssetSpec("asset4", tags={"bar": "barval"}),
    ]

    @multi_asset(specs=tagged_specs)
    def assets(): ...

    selected = AssetSelection.tag("foo", "fooval").resolve([assets])
    assert selected == {AssetKey("asset1"), AssetKey("asset3")}


def test_tag_string():
    """A ``key=value`` string selects the same assets as the equivalent key/value pair."""
    tagged_specs = [
        AssetSpec("asset1", tags={"foo": "fooval"}),
        AssetSpec("asset2", tags={"foo": "fooval2"}),
        AssetSpec("asset3", tags={"foo": "fooval", "bar": "barval"}),
        AssetSpec("asset4", tags={"bar": "barval"}),
    ]

    @multi_asset(specs=tagged_specs)
    def assets(): ...

    selected = AssetSelection.tag_string("foo=fooval").resolve([assets])
    assert selected == {AssetKey("asset1"), AssetKey("asset3")}

0 comments on commit 35cdf77

Please sign in to comment.