Skip to content

Commit

Permalink
Pass through FlyteFile and FlyteDirectory if created from a remote so…
Browse files Browse the repository at this point in the history
…urce (#436)
  • Loading branch information
wild-endeavor authored Mar 26, 2021
1 parent e76b28d commit b4608ef
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 0 deletions.
5 changes: 5 additions & 0 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ def to_literal(
# There are two kinds of literals we handle, either an actual FlyteDirectory, or a string path to a directory.
# Handle the FlyteDirectory case
if isinstance(python_val, FlyteDirectory):
# If the object has a remote source, then we just convert it back.
if python_val._remote_source is not None:
meta = BlobMetadata(type=self._blob_type(format=self.get_format(python_type)))
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source)))

source_path = python_val.path
if python_val.remote_directory is False:
# If the user specified the remote_path to be False, that means no matter what, do not upload
Expand Down
5 changes: 5 additions & 0 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ def to_literal(
if python_val is None:
raise AssertionError("None value cannot be converted to a file.")
if isinstance(python_val, FlyteFile):
# If the object has a remote source, then we just convert it back.
if python_val._remote_source is not None:
meta = BlobMetadata(type=self._blob_type(format=self.get_format(python_type)))
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source)))

source_path = python_val.path
if python_val.remote_path is False:
# If the user specified the remote_path to be False, that means no matter what, do not upload
Expand Down
32 changes: 32 additions & 0 deletions tests/flytekit/unit/core/test_flyte_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@

import flytekit
from flytekit.core import context_manager
from flytekit.core.context_manager import ExecutionState, Image, ImageConfig
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.interfaces.data.data_proxy import FileAccessProvider
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer


Expand Down Expand Up @@ -133,3 +136,32 @@ def wf2() -> int:

x = wf2()
assert x == 5


def test_dont_convert_remotes():
@task
def t1(in1: FlyteDirectory):
print(in1)

@dynamic
def dyn(in1: FlyteDirectory):
t1(in1=in1)

fd = FlyteDirectory("s3://anything")

with context_manager.FlyteContext.current_context().new_serialization_settings(
serialization_settings=context_manager.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
) as ctx:
with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
lit = TypeEngine.to_literal(
ctx, fd, FlyteDirectory, BlobType("", dimensionality=BlobType.BlobDimensionality.MULTIPART)
)
lm = LiteralMap(literals={"in1": lit})
wf = dyn.dispatch_execute(ctx, lm)
assert wf.nodes[0].inputs[0].binding.scalar.blob.uri == "s3://anything"
34 changes: 34 additions & 0 deletions tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

import flytekit
from flytekit.core import context_manager
from flytekit.core.context_manager import ExecutionState, Image, ImageConfig
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.interfaces.data.data_proxy import FileAccessProvider
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.types.file.file import FlyteFile


Expand Down Expand Up @@ -207,3 +212,32 @@ def my_wf() -> FlyteFile:

# The file name is maintained on download.
assert str(workflow_output).endswith(os.path.split(SAMPLE_DATA)[1])


def test_dont_convert_remotes():
@task
def t1(in1: FlyteFile):
print(in1)

@dynamic
def dyn(in1: FlyteFile):
t1(in1=in1)

fd = FlyteFile("s3://anything")

with context_manager.FlyteContext.current_context().new_serialization_settings(
serialization_settings=context_manager.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
) as ctx:
with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
lit = TypeEngine.to_literal(
ctx, fd, FlyteFile, BlobType("", dimensionality=BlobType.BlobDimensionality.SINGLE)
)
lm = LiteralMap(literals={"in1": lit})
wf = dyn.dispatch_execute(ctx, lm)
assert wf.nodes[0].inputs[0].binding.scalar.blob.uri == "s3://anything"

0 comments on commit b4608ef

Please sign in to comment.