Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Dec 14, 2021
1 parent 97ab89f commit cd98150
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
6 changes: 5 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,12 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None)
# dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
# JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the
# path in FlyteFile and FlyteDirectory could be a local path. Therefore, We reset the python value here,
# path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here,
# so that dataclass_json can always get a remote path.
# In other words, the file transformer has special code that handles the fact that if remote_source is
# set, then the real uri in the literal should be the remote source, not the path (which may be an
# auto-generated random local path). To be sure we're writing the right path to the json, use the uri
# as determined by the transformer.
if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory):
python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri))

Expand Down
13 changes: 10 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def t1(path: str) -> FileStruct:
@dynamic
def dyn(fs: FileStruct):
t2(fs=fs)
t3(fs=fs)

@task
def t2(fs: FileStruct) -> os.PathLike:
Expand All @@ -384,13 +385,19 @@ def t2(fs: FileStruct) -> os.PathLike:

return fs.a.path

@task
def t3(fs: FileStruct) -> FlyteFile:
return fs.a

@workflow
def wf(path: str) -> os.PathLike:
def wf(path: str) -> (os.PathLike, FlyteFile):
n1 = t1(path=path)
dyn(fs=n1)
return t2(fs=n1)
return t2(fs=n1), t3(fs=n1)

assert "/tmp/flyte/" in wf(path="s3://somewhere").path
assert "/tmp/flyte/" in wf(path="s3://somewhere")[0].path
assert "/tmp/flyte/" in wf(path="s3://somewhere")[1].path
assert "s3://somewhere" == wf(path="s3://somewhere")[1].remote_source


def test_flyte_directory_in_dataclass():
Expand Down

0 comments on commit cd98150

Please sign in to comment.