Skip to content

Commit

Permalink
[Artifacts/Elastic] Skip partitions (#2620)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Jul 29, 2024
1 parent 5bc5d5c commit 3f5ba98
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
2 changes: 2 additions & 0 deletions flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def set_reference_artifact(self, artifact: Artifact):
p.reference_artifact = artifact

def __getattr__(self, item):
if item == "partitions" or item == "_partitions":
raise AttributeError("Partitions in an uninitialized state, skipping partitions")
if self.partitions and item in self.partitions:
return self.partitions[item]
raise AttributeError(f"Partition {item} not found in {self}")
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class ElasticWorkerResult(NamedTuple):

return_value: Any
decks: List[flytekit.Deck]
om: OutputMetadata
om: Optional[OutputMetadata] = None


def spawn_helper(
Expand Down Expand Up @@ -435,7 +435,7 @@ def fn_partial():
if isinstance(e, FlyteRecoverableException):
create_recoverable_error_file()
raise
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks)
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=None)

launcher_target_func = fn_partial
launcher_args = ()
Expand Down
12 changes: 12 additions & 0 deletions tests/flytekit/unit/core/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,15 @@ def test_lims():
# test an artifact with 11 partition keys
with pytest.raises(ValueError):
Artifact(name="test artifact", time_partitioned=True, partition_keys=[f"key_{i}" for i in range(11)])


def test_cloudpickle():
a1_b = Artifact(name="my_data", partition_keys=["b"])

spec = a1_b(b="my_b_value")
import cloudpickle

d = cloudpickle.dumps(spec)
spec2 = cloudpickle.loads(d)

assert spec2.partitions.b.value.static_value == "my_b_value"

0 comments on commit 3f5ba98

Please sign in to comment.