Skip to content

Commit

Permalink
make getting the filesystem to be async and await it
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Oct 22, 2024
1 parent 0d9e298 commit 7641697
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def get_filesystem(

return fsspec.filesystem(protocol, **kwargs)

def get_async_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> AsyncFileSystem:
async def get_async_filesystem_for_path(
self, path: str = "", anonymous: bool = False, **kwargs
) -> Union[AsyncFileSystem, fsspec.AbstractFileSystem]:
protocol = get_protocol(path)
loop = asyncio.get_running_loop()

Expand Down Expand Up @@ -293,7 +295,7 @@ def exists(self, path: str) -> bool:

@retry_request
async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
file_system = self.get_async_filesystem_for_path(from_path)
file_system = await self.get_async_filesystem_for_path(from_path)
if recursive:
from_path, to_path = self.recursive_paths(from_path, to_path)
try:
Expand Down Expand Up @@ -330,7 +332,7 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw
More of an internal function to be called by put_data and put_raw_data
This does not need a separate sync function.
"""
file_system = self.get_async_filesystem_for_path(to_path)
file_system = await self.get_async_filesystem_for_path(to_path)
from_path = self.strip_file_header(from_path)
if recursive:
# Only check this for the local filesystem
Expand Down Expand Up @@ -419,7 +421,7 @@ async def async_put_raw_data(

# raw bytes
if isinstance(lpath, bytes):
fs = self.get_async_filesystem_for_path(to_path)
fs = await self.get_async_filesystem_for_path(to_path)

Check warning on line 424 in flytekit/core/data_persistence.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/data_persistence.py#L424

Added line #L424 was not covered by tests
if isinstance(fs, AsyncFileSystem):
async with fs.open_async(to_path, "wb", **kwargs) as s:
s.write(lpath)

Check warning on line 427 in flytekit/core/data_persistence.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/data_persistence.py#L426-L427

Added lines #L426 - L427 were not covered by tests
Expand All @@ -433,7 +435,7 @@ async def async_put_raw_data(
if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO):
if not lpath.readable():
raise FlyteAssertion("Buffered reader must be readable")
fs = self.get_async_filesystem_for_path(to_path)
fs = await self.get_async_filesystem_for_path(to_path)

Check warning on line 438 in flytekit/core/data_persistence.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/data_persistence.py#L438

Added line #L438 was not covered by tests
lpath.seek(0)
if isinstance(fs, AsyncFileSystem):
async with fs.open_async(to_path, "wb", **kwargs) as s:

Check warning on line 441 in flytekit/core/data_persistence.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/data_persistence.py#L441

Added line #L441 was not covered by tests
Expand All @@ -448,7 +450,7 @@ async def async_put_raw_data(
if isinstance(lpath, io.StringIO):
if not lpath.readable():
raise FlyteAssertion("Buffered reader must be readable")
fs = self.get_async_filesystem_for_path(to_path)
fs = await self.get_async_filesystem_for_path(to_path)
lpath.seek(0)
if isinstance(fs, AsyncFileSystem):
async with fs.open_async(to_path, "wb", **kwargs) as s:

Check warning on line 456 in flytekit/core/data_persistence.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/data_persistence.py#L456

Added line #L456 was not covered by tests
Expand Down

0 comments on commit 7641697

Please sign in to comment.