Review note: the added lines below were revised during review, so the post-image
blob hashes on the "index" lines no longer describe the patched files.

diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py
index fba84187b3..47e5b146c8 100644
--- a/flytekit/core/artifact.py
+++ b/flytekit/core/artifact.py
@@ -273,6 +273,22 @@ def __init__(
         self.reference_artifact: Optional[Artifact] = None
         self.granularity = granularity
 
+    def __rich_repr__(self):
+        # Always yield exactly one (label, text) pair, whatever shape self.value has.
+        value = self.value
+        if isinstance(value, art_id.LabelValue) and value.HasField("time_value"):
+            yield "Time Partition", str(value.time_value.ToDatetime())
+        elif isinstance(value, art_id.LabelValue) and value.HasField("input_binding"):
+            yield "Time Partition (bound to)", value.input_binding.var
+        else:
+            yield "Time Partition", "unspecified"
+
+    def _repr_html_(self):
+        """
+        Jupyter notebook rendering.
+        """
+        return "".join([f"{x[0]}: {x[1]}" for x in self.__rich_repr__()])
+
     def __add__(self, other: timedelta) -> TimePartition:
         tp = TimePartition(self.value, op=Op.PLUS, other=other, granularity=self.granularity)
         tp.reference_artifact = self.reference_artifact
@@ -293,6 +309,15 @@ def __init__(self, value: Optional[art_id.LabelValue], name: str):
         self.value = value
         self.reference_artifact: Optional[Artifact] = None
 
+    def __rich_repr__(self):
+        yield self.name, self.value
+
+    def _repr_html_(self):
+        """
+        Jupyter notebook rendering.
+        """
+        return "".join([f"{x[0]}: {x[1]}" for x in self.__rich_repr__()])
+
 
 class Partitions(object):
     def __init__(self, partitions: Optional[typing.Mapping[str, Union[str, art_id.InputBindingData, Partition]]]):
@@ -307,6 +332,18 @@ def __init__(self, partitions: Optional[typing.Mapping[str, Union[str, art_id.In
                     self._partitions[k] = Partition(art_id.LabelValue(static_value=v), name=k)
         self.reference_artifact: Optional[Artifact] = None
 
+    def __rich_repr__(self):
+        # No partitions: yield nothing rather than a bare "", so every yielded item is a (name, value) pair.
+        if self.partitions:
+            ps = [str(next(v.__rich_repr__())) for v in self.partitions.values()]
+            yield "Partitions", ", ".join(ps)
+
+    def _repr_html_(self):
+        """
+        Jupyter notebook rendering.
+        """
+        return ", ".join([f"{x[0]}: {x[1]}" for x in self.__rich_repr__()])
+
     @property
     def partitions(self) -> Optional[typing.Dict[str, Partition]]:
         return self._partitions
@@ -562,7 +599,8 @@ def embed_as_query(
         op: Optional[Op] = None,
     ) -> art_id.ArtifactQuery:
         """
-        This should only be called in the context of a Trigger
+        This should only be called in the context of a Trigger. The type of query this returns is different from the
+        query() function. This type of query is used to reference the triggering artifact, rather than running a query.
         :param partition: Can embed a time partition
         :param bind_to_time_partition: Set to true if you want to bind to a time partition
         :param expr: Only valid if there's a time partition.
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 5b0eb62c65..15c03059bb 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -2023,8 +2023,6 @@ class LiteralsResolver(collections.UserDict):
     LiteralsResolver is a helper class meant primarily for use with the FlyteRemote experience or any other situation
     where you might be working with LiteralMaps. This object allows the caller to specify the Python type that should
     correspond to an element of the map.
-
-    TODO: Consider inheriting from collections.UserDict instead of manually having the _native_values cache
     """
 
     def __init__(
diff --git a/flytekit/models/types.py b/flytekit/models/types.py
index 23f818e7a6..9fac15fa79 100644
--- a/flytekit/models/types.py
+++ b/flytekit/models/types.py
@@ -283,12 +283,38 @@ def __init__(
         self._enum_type = enum_type
         self._union_type = union_type
         self._structured_dataset_type = structured_dataset_type
-        self._metadata = metadata
         self._structure = structure
-        self._structured_dataset_type = structured_dataset_type
         self._metadata = metadata
         self._annotation = annotation
 
+    def __rich_repr__(self):
+        # NOTE(review): assumes an unset field is None and a set simple type may be falsy (e.g. 0) - confirm.
+        if self.simple is not None:
+            yield "Simple"
+        elif self.schema:
+            yield "Schema"
+        elif self.collection_type:
+            sub = next(self.collection_type.__rich_repr__())
+            yield f"List[{sub}]"
+        elif self.map_value_type:
+            sub = next(self.map_value_type.__rich_repr__())
+            yield f"Dict[str, {sub}]"
+        elif self.blob:
+            if self.blob.dimensionality == _types_pb2.BlobType.BlobDimensionality.SINGLE:
+                yield "File"
+            elif self.blob.dimensionality == _types_pb2.BlobType.BlobDimensionality.MULTIPART:
+                yield "Directory"
+            else:
+                yield "Unknown Blob Type"
+        elif self.enum_type:
+            yield "Enum"
+        elif self.union_type:
+            yield "Union"
+        elif self.structured_dataset_type:
+            yield f"StructuredDataset(format={self.structured_dataset_type.format})"
+        else:
+            yield "Unknown Type"
+
     @property
     def simple(self) -> SimpleType:
         return self._simple
diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py
index 9437d16add..34d19f50cb 100644
--- a/tests/flytekit/unit/core/test_artifacts.py
+++ b/tests/flytekit/unit/core/test_artifacts.py
@@ -615,6 +615,29 @@ def test_tp_math():
     assert tp2 is not tp
 
 
+def test_tp_printing():
+    d = datetime.datetime(2063, 4, 5, 0, 0)
+    pt = Timestamp()
+    pt.FromDatetime(d)
+    tp = TimePartition(value=art_id.LabelValue(time_value=pt), granularity=Granularity.HOUR)
+    txt = "".join([str(x) for x in tp.__rich_repr__()])
+    # should show something like ('Time Partition', '2063-04-05 00:00:00')
+    # just check that we don't accidentally fail to evaluate a generator
+    assert "generator" not in txt
+    assert "2063-04-05" in txt
+
+
+def test_partition_printing():
+    a1_b = Artifact(name="my_data", partition_keys=["b"])
+    spec = a1_b(b="my_b_value")
+    ps = spec.partitions
+    txt = "".join([str(x) for x in ps.__rich_repr__()])
+    # should look something like ('Partitions', '(\'b\', static_value: "my_b_value"\n)')
+    # just check that we don't accidentally fail to evaluate a generator
+    assert "generator" not in txt
+    assert "my_b_value" in txt
+
+
 def test_lims():
     # test an artifact with 11 partition keys
     with pytest.raises(ValueError):