Skip to content

Commit

Permalink
Add anonymous retry (#854)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Feb 16, 2022
1 parent b175324 commit f6e38dd
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 9 deletions.
22 changes: 17 additions & 5 deletions flytekit/extras/persistence/s3_awscli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os as _os
import re as _re
import string as _string
Expand All @@ -10,8 +9,11 @@
from flytekit.configuration import aws
from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins
from flytekit.exceptions.user import FlyteUserException
from flytekit.loggers import logger
from flytekit.tools import subprocess

S3_ANONYMOUS_FLAG = "--no-sign-request"


def _update_cmd_config_and_execute(cmd: List[str]):
env = _os.environ.copy()
Expand All @@ -32,16 +34,26 @@ def _update_cmd_config_and_execute(cmd: List[str]):
retry = 0
while True:
try:
return subprocess.check_call(cmd, env=env)
try:
return subprocess.check_call(cmd, env=env)
except Exception as e:
if retry > 0:
logger.info(f"AWS command failed with error {e}, command: {cmd}, retry {retry}")

logger.debug(f"Appending anonymous flag and retrying command {cmd}")
anonymous_cmd = cmd[:] # strings only, so this is deep enough
anonymous_cmd.insert(1, S3_ANONYMOUS_FLAG)
return subprocess.check_call(anonymous_cmd, env=env)

except Exception as e:
logging.error(f"Exception when trying to execute {cmd}, reason: {str(e)}")
logger.error(f"Exception when trying to execute {cmd}, reason: {str(e)}")
retry += 1
if retry > aws.RETRIES.get():
raise
secs = aws.BACKOFF_SECONDS.get()
logging.info(f"Sleeping before retrying again, after {secs} seconds")
logger.info(f"Sleeping before retrying again, after {secs} seconds")
time.sleep(secs)
logging.info("Retrying again")
logger.info("Retrying again")


def _extra_args(extra_args: Dict[str, str]) -> List[str]:
Expand Down
26 changes: 23 additions & 3 deletions plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,34 @@ def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]:
return f, t

def exists(self, path: str) -> bool:
fs = self._get_filesystem(path)
return fs.exists(path)
try:
fs = self._get_filesystem(path)
return fs.exists(path)
except OSError as oe:
logger.debug(f"Error in exists checking {path} {oe}")
protocol = FSSpecPersistence._get_protocol(path)
if protocol == "s3":
logger.debug("S3 source detected, attempting anonymous S3 exists check")
kwargs = s3_setup_args()
anonymous_fs = fsspec.filesystem(protocol, anon=True, **kwargs) # type: ignore
return anonymous_fs.exists(path)
raise oe

def get(self, from_path: str, to_path: str, recursive: bool = False):
fs = self._get_filesystem(from_path)
if recursive:
from_path, to_path = self.recursive_paths(from_path, to_path)
return fs.get(from_path, to_path, recursive=recursive)
try:
return fs.get(from_path, to_path, recursive=recursive)
except OSError as oe:
logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}")
protocol = FSSpecPersistence._get_protocol(from_path)
if protocol == "s3":
logger.debug("S3 source detected, attempting anonymous S3 access")
kwargs = s3_setup_args()
anonymous_fs = fsspec.filesystem(protocol, anon=True, **kwargs) # type: ignore
return anonymous_fs.get(from_path, to_path, recursive=recursive)
raise oe

def put(self, from_path: str, to_path: str, recursive: bool = False):
fs = self._get_filesystem(to_path)
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/extras/persistence/test_s3_awscli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_retries(mock_subprocess, mock_delay, mock_check):

proxy = S3Persistence()
assert proxy.exists("s3://test/fdsa/fdsa") is False
assert mock_subprocess.check_call.call_count == 4
assert mock_subprocess.check_call.call_count == 8


def test_extra_args():
Expand Down

0 comments on commit f6e38dd

Please sign in to comment.