diff --git a/flytekit/extras/persistence/s3_awscli.py b/flytekit/extras/persistence/s3_awscli.py index 3b24fef94b..64e09e219c 100644 --- a/flytekit/extras/persistence/s3_awscli.py +++ b/flytekit/extras/persistence/s3_awscli.py @@ -1,4 +1,3 @@ -import logging import os as _os import re as _re import string as _string @@ -10,8 +9,11 @@ from flytekit.configuration import aws from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.exceptions.user import FlyteUserException +from flytekit.loggers import logger from flytekit.tools import subprocess +S3_ANONYMOUS_FLAG = "--no-sign-request" + def _update_cmd_config_and_execute(cmd: List[str]): env = _os.environ.copy() @@ -32,16 +34,26 @@ def _update_cmd_config_and_execute(cmd: List[str]): retry = 0 while True: try: - return subprocess.check_call(cmd, env=env) + try: + return subprocess.check_call(cmd, env=env) + except Exception as e: + if retry > 0: + logger.info(f"AWS command failed with error {e}, command: {cmd}, retry {retry}") + + logger.debug(f"Appending anonymous flag and retrying command {cmd}") + anonymous_cmd = cmd[:] # strings only, so this is deep enough + anonymous_cmd.insert(1, S3_ANONYMOUS_FLAG) + return subprocess.check_call(anonymous_cmd, env=env) + except Exception as e: - logging.error(f"Exception when trying to execute {cmd}, reason: {str(e)}") + logger.error(f"Exception when trying to execute {cmd}, reason: {str(e)}") retry += 1 if retry > aws.RETRIES.get(): raise secs = aws.BACKOFF_SECONDS.get() - logging.info(f"Sleeping before retrying again, after {secs} seconds") + logger.info(f"Sleeping before retrying again, after {secs} seconds") time.sleep(secs) - logging.info("Retrying again") + logger.info("Retrying again") def _extra_args(extra_args: Dict[str, str]) -> List[str]: diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py index d2ea879ce0..68bc92b493 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py @@ -67,14 +67,34 @@ def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]: return f, t def exists(self, path: str) -> bool: - fs = self._get_filesystem(path) - return fs.exists(path) + try: + fs = self._get_filesystem(path) + return fs.exists(path) + except OSError as oe: + logger.debug(f"Error in exists checking {path} {oe}") + protocol = FSSpecPersistence._get_protocol(path) + if protocol == "s3": + logger.debug("S3 source detected, attempting anonymous S3 exists check") + kwargs = s3_setup_args() + anonymous_fs = fsspec.filesystem(protocol, anon=True, **kwargs) # type: ignore + return anonymous_fs.exists(path) + raise oe def get(self, from_path: str, to_path: str, recursive: bool = False): fs = self._get_filesystem(from_path) if recursive: from_path, to_path = self.recursive_paths(from_path, to_path) - return fs.get(from_path, to_path, recursive=recursive) + try: + return fs.get(from_path, to_path, recursive=recursive) + except OSError as oe: + logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") + protocol = FSSpecPersistence._get_protocol(from_path) + if protocol == "s3": + logger.debug("S3 source detected, attempting anonymous S3 access") + kwargs = s3_setup_args() + anonymous_fs = fsspec.filesystem(protocol, anon=True, **kwargs) # type: ignore + return anonymous_fs.get(from_path, to_path, recursive=recursive) + raise oe def put(self, from_path: str, to_path: str, recursive: bool = False): fs = self._get_filesystem(to_path) diff --git a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py b/tests/flytekit/unit/extras/persistence/test_s3_awscli.py index 78f7d67b88..bcf1fd3495 100644 --- a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py +++ b/tests/flytekit/unit/extras/persistence/test_s3_awscli.py @@ -25,7 +25,7 @@ def test_retries(mock_subprocess, mock_delay, mock_check): proxy = S3Persistence() assert proxy.exists("s3://test/fdsa/fdsa") is False - assert mock_subprocess.check_call.call_count == 4 + assert mock_subprocess.check_call.call_count == 8 def test_extra_args():