diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index c8ab6eb438..e7ca1c2abd 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -2,4 +2,4 @@
 
 import flytekit.plugins  # noqa: F401
 
-__version__ = "0.12.0"
+__version__ = "0.12.1"
diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py
index 7ad447de4a..98e1a464ac 100644
--- a/flytekit/bin/entrypoint.py
+++ b/flytekit/bin/entrypoint.py
@@ -53,7 +53,7 @@ def _map_job_index_to_child_index(local_input_dir, datadir, index):
 
 
 @_scopes.system_entry_point
-def _execute_task(task_module, task_name, inputs, output_prefix, test):
+def _execute_task(task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test):
     with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
         with _utils.AutoDeletingTempDir("input_dir") as input_dir:
             # Load user code
@@ -83,7 +83,8 @@ def _execute_task(task_module, task_name, inputs, output_prefix, test):
                 _data_proxy.Data.get_data(inputs, local_inputs_file)
                 input_proto = _utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file)
                 _engine_loader.get_engine().get_task(task_def).execute(
-                    _literal_models.LiteralMap.from_flyte_idl(input_proto), context={"output_prefix": output_prefix},
+                    _literal_models.LiteralMap.from_flyte_idl(input_proto),
+                    context={"output_prefix": output_prefix, "raw_output_data_prefix": raw_output_data_prefix},
                 )
 
 
@@ -97,10 +98,17 @@ def _pass_through():
 @_click.option("--task-name", required=True)
 @_click.option("--inputs", required=True)
 @_click.option("--output-prefix", required=True)
+@_click.option("--raw-output-data-prefix", required=False)
 @_click.option("--test", is_flag=True)
-def execute_task_cmd(task_module, task_name, inputs, output_prefix, test):
+def execute_task_cmd(task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test):
     _click.echo(_utils.get_version_message())
-    _execute_task(task_module, task_name, inputs, output_prefix, test)
+    # Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original
+    # template string, so let's explicitly set it to None so that the downstream functions will know to fall back
+    # to the original shard formatter/prefix config.
+    if raw_output_data_prefix == "{{.rawOutputDataPrefix}}":
+        raw_output_data_prefix = None
+
+    _execute_task(task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test)
 
 
 if __name__ == "__main__":
diff --git a/flytekit/common/tasks/sdk_runnable.py b/flytekit/common/tasks/sdk_runnable.py
index f868475037..08dd815fa5 100644
--- a/flytekit/common/tasks/sdk_runnable.py
+++ b/flytekit/common/tasks/sdk_runnable.py
@@ -447,6 +447,8 @@ def _get_container_definition(
             "{{.input}}",
             "--output-prefix",
             "{{.outputPrefix}}",
+            "--raw-output-data-prefix",
+            "{{.rawOutputDataPrefix}}",
         ],
         resources=_task_models.Resources(limits=limits, requests=requests),
         env=environment,
diff --git a/flytekit/contrib/notebook/tasks.py b/flytekit/contrib/notebook/tasks.py
index e78ae859ed..ca47eabd62 100644
--- a/flytekit/contrib/notebook/tasks.py
+++ b/flytekit/contrib/notebook/tasks.py
@@ -329,6 +329,8 @@ def container(self):
             "{{.input}}",
             "--output-prefix",
             "{{.outputPrefix}}",
+            "--raw-output-data-prefix",
+            "{{.rawOutputDataPrefix}}",
         ]
         return self._container
 
diff --git a/flytekit/engines/common.py b/flytekit/engines/common.py
index 68c8c6f1d4..48d7864c77 100644
--- a/flytekit/engines/common.py
+++ b/flytekit/engines/common.py
@@ -396,12 +396,13 @@ def fetch_workflow(self, workflow_id):
 
 
 class EngineContext(object):
-    def __init__(self, execution_date, tmp_dir, stats, execution_id, logging):
+    def __init__(self, execution_date, tmp_dir, stats, execution_id, logging, raw_output_data_prefix=None):
         self._stats = stats
         self._execution_date = execution_date
         self._working_directory = tmp_dir
         self._execution_id = execution_id
         self._logging = logging
+        self._raw_output_data_prefix = raw_output_data_prefix
 
     @property
     def stats(self):
@@ -437,3 +438,7 @@ def execution_id(self):
         :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier
         """
         return self._execution_id
+
+    @property
+    def raw_output_data_prefix(self) -> str:
+        return self._raw_output_data_prefix
diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py
index 458eca15c8..32b7c39ea4 100644
--- a/flytekit/engines/flyte/engine.py
+++ b/flytekit/engines/flyte/engine.py
@@ -270,11 +270,11 @@ def execute(self, inputs, context=None):
         :param dict[Text, Text] context:
         :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
         """
-
         with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
             with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
                 with _data_proxy.LocalWorkingDirectoryContext(task_dir):
-                    with _data_proxy.RemoteDataContext():
+                    raw_output_data_prefix = context.get("raw_output_data_prefix", None)
+                    with _data_proxy.RemoteDataContext(raw_output_data_prefix_override=raw_output_data_prefix):
                         output_file_dict = dict()
 
                         # This sets the logging level for user code and is the only place an sdk setting gets
@@ -311,6 +311,9 @@ def execute(self, inputs, context=None):
                                     ),
                                     logging=_logging,
                                     tmp_dir=task_dir,
+                                    raw_output_data_prefix=context["raw_output_data_prefix"]
+                                    if "raw_output_data_prefix" in context
+                                    else None,
                                 ),
                                 inputs,
                             )
diff --git a/flytekit/interfaces/data/data_proxy.py b/flytekit/interfaces/data/data_proxy.py
index e51659cdd0..9a24135204 100644
--- a/flytekit/interfaces/data/data_proxy.py
+++ b/flytekit/interfaces/data/data_proxy.py
@@ -64,22 +64,23 @@ def __init__(self, sandbox):
 
 class RemoteDataContext(_OutputDataContext):
     _CLOUD_PROVIDER_TO_PROXIES = {
-        _constants.CloudProvider.AWS: _s3proxy.AwsS3Proxy(),
-        _constants.CloudProvider.GCP: _gcs_proxy.GCSProxy(),
+        _constants.CloudProvider.AWS: _s3proxy.AwsS3Proxy,
+        _constants.CloudProvider.GCP: _gcs_proxy.GCSProxy,
     }
 
-    def __init__(self, cloud_provider=None):
+    def __init__(self, cloud_provider=None, raw_output_data_prefix_override=None):
         """
         :param Optional[Text] cloud_provider: From flytekit.common.constants.CloudProvider enum
         """
         cloud_provider = cloud_provider or _platform_config.CLOUD_PROVIDER.get()
-        proxy = type(self)._CLOUD_PROVIDER_TO_PROXIES.get(cloud_provider, None)
-        if proxy is None:
+        proxy_class = type(self)._CLOUD_PROVIDER_TO_PROXIES.get(cloud_provider, None)
+        if proxy_class is None:
             raise _user_exception.FlyteAssertion(
                 "Configured cloud provider is not supported for data I/O. Received: {}, expected one of: {}".format(
                     cloud_provider, list(type(self)._CLOUD_PROVIDER_TO_PROXIES.keys())
                 )
             )
+        proxy = proxy_class(raw_output_data_prefix_override)
         super(RemoteDataContext, self).__init__(proxy)
 
 
diff --git a/flytekit/interfaces/data/gcs/gcs_proxy.py b/flytekit/interfaces/data/gcs/gcs_proxy.py
index 2313b38136..3b40e664d3 100644
--- a/flytekit/interfaces/data/gcs/gcs_proxy.py
+++ b/flytekit/interfaces/data/gcs/gcs_proxy.py
@@ -28,6 +28,19 @@ def _amend_path(path):
 class GCSProxy(_common_data.DataProxy):
     _GS_UTIL_CLI = "gsutil"
 
+    def __init__(self, raw_output_data_prefix_override: str = None):
+        """
+        :param raw_output_data_prefix_override: Instead of relying on the GCS configuration setting
+            (GCS_PREFIX) when computing a random output path (see get_random_path), use this prefix
+            as the base. This code assumes that the path passed in is correct. That is, an S3 path
+            won't be passed in when running on GCP.
+        """
+        self._raw_output_data_prefix_override = raw_output_data_prefix_override
+
+    @property
+    def raw_output_data_prefix_override(self) -> str:
+        return self._raw_output_data_prefix_override
+
     @staticmethod
     def _check_binary():
         """
@@ -119,12 +132,14 @@ def upload_directory(self, local_path, remote_path):
         )
         return _update_cmd_config_and_execute(cmd)
 
-    def get_random_path(self):
+    def get_random_path(self) -> str:
         """
-        :rtype: Text
+        If this object was created with a raw output data prefix, usually set by Propeller/Plugins at execution time
+        and piped all the way here, it will be used instead of referencing the GCS_PREFIX configuration.
         """
         key = _uuid.UUID(int=_flyte_random.random.getrandbits(128)).hex
-        return _os.path.join(_gcp_config.GCS_PREFIX.get(), key)
+        prefix = self.raw_output_data_prefix_override or _gcp_config.GCS_PREFIX.get()
+        return _os.path.join(prefix, key)
 
     def get_random_directory(self):
         """
diff --git a/flytekit/interfaces/data/s3/s3proxy.py b/flytekit/interfaces/data/s3/s3proxy.py
index 8ed560b35d..5948b7d59f 100644
--- a/flytekit/interfaces/data/s3/s3proxy.py
+++ b/flytekit/interfaces/data/s3/s3proxy.py
@@ -41,6 +41,19 @@ class AwsS3Proxy(_common_data.DataProxy):
     _AWS_CLI = "aws"
     _SHARD_CHARACTERS = [_text_type(x) for x in _six_moves.range(10)] + list(_string.ascii_lowercase)
 
+    def __init__(self, raw_output_data_prefix_override: str = None):
+        """
+        :param raw_output_data_prefix_override: Instead of relying on the AWS or GCS configuration (see
+            S3_SHARD_FORMATTER for AWS and GCS_PREFIX for GCP) setting when computing the shard
+            path (_get_shard_path), use this prefix instead as a base. This code assumes that the
+            path passed in is correct. That is, an S3 path won't be passed in when running on GCP.
+        """
+        self._raw_output_data_prefix_override = raw_output_data_prefix_override
+
+    @property
+    def raw_output_data_prefix_override(self) -> str:
+        return self._raw_output_data_prefix_override
+
     @staticmethod
     def _check_binary():
         """
@@ -179,10 +192,14 @@ def get_random_directory(self):
         """
         return self.get_random_path() + "/"
 
-    def _get_shard_path(self):
+    def _get_shard_path(self) -> str:
         """
-        :rtype: Text
+        If this object was created with a raw output data prefix, usually set by Propeller/Plugins at execution time
+        and piped all the way here, it will be used instead of referencing the S3 shard configuration.
         """
+        if self.raw_output_data_prefix_override:
+            return self.raw_output_data_prefix_override
+
         shard = ""
         for _ in _six_moves.range(_aws_config.S3_SHARD_STRING_LENGTH.get()):
             shard += _flyte_random.random.choice(self._SHARD_CHARACTERS)
diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py
index ea2ac10568..1060a3658e 100644
--- a/tests/flytekit/unit/bin/test_python_entrypoint.py
+++ b/tests/flytekit/unit/bin/test_python_entrypoint.py
@@ -2,6 +2,7 @@
 
 import os
 
+import mock
 import six
 from click.testing import CliRunner
 from flyteidl.core import literals_pb2 as _literals_pb2
@@ -38,6 +39,7 @@ def test_single_step_entrypoint_in_proc():
             _task_defs.add_one.task_function_name,
             input_file,
             output_dir.name,
+            output_dir.name,
             False,
         )
 
@@ -113,7 +115,12 @@ def test_arrayjob_entrypoint_in_proc():
         os.environ["AWS_BATCH_JOB_ARRAY_INDEX"] = "0"
 
         _execute_task(
-            _task_defs.add_one.task_module, _task_defs.add_one.task_function_name, dir.name, dir.name, False,
+            _task_defs.add_one.task_module,
+            _task_defs.add_one.task_function_name,
+            dir.name,
+            dir.name,
+            dir.name,
+            False,
         )
 
         raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std(
@@ -132,3 +139,26 @@
         os.environ["BATCH_JOB_ARRAY_INDEX_VAR_NAME"] = orig_env_index_var_name
     if orig_env_array_index:
         os.environ["AWS_BATCH_JOB_ARRAY_INDEX"] = orig_env_array_index
+
+
+@mock.patch("flytekit.bin.entrypoint._execute_task")
+def test_backwards_compatible_replacement(mock_execute_task):
+    def return_args(*args, **kwargs):
+        assert args[4] is None
+
+    mock_execute_task.side_effect = return_args
+
+    with _TemporaryConfiguration(
+        os.path.join(os.path.dirname(__file__), "fake.config"),
+        internal_overrides={"project": "test", "domain": "development"},
+    ):
+        with _utils.AutoDeletingTempDir("in"):
+            with _utils.AutoDeletingTempDir("out"):
+                cmd = []
+                cmd.extend(["--task-module", "fake"])
+                cmd.extend(["--task-name", "fake"])
+                cmd.extend(["--inputs", "fake"])
+                cmd.extend(["--output-prefix", "fake"])
+                cmd.extend(["--raw-output-data-prefix", "{{.rawOutputDataPrefix}}"])
+                result = CliRunner().invoke(execute_task_cmd, cmd)
+                assert result.exit_code == 0
diff --git a/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py b/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py
index d1b29232d1..ea799ccf4e 100644
--- a/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py
+++ b/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py
@@ -1,5 +1,3 @@
-from __future__ import absolute_import
-
 import os as _os
 
 import mock as _mock
@@ -69,3 +67,14 @@ def test_upload_directory_with_parallelism(mock_update_cmd_config_and_execute, g
     local_path, remote_path = "/foo/*", "gs://bar/0/"
     gcs_proxy.upload_directory(local_path, remote_path)
     mock_update_cmd_config_and_execute.assert_called_once_with(["gsutil", "-m", "cp", "-r", local_path, remote_path])
+
+
+def test_raw_prefix_property(mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy):
+    gcs_with_raw_prefix = _gcs_proxy.GCSProxy("gcs://stuff")
+    assert gcs_with_raw_prefix.raw_output_data_prefix_override == "gcs://stuff"
+
+
+def test_random_path(mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy):
+    gcs_with_raw_prefix = _gcs_proxy.GCSProxy("gcs://stuff")
+    result = gcs_with_raw_prefix.get_random_path()
+    assert result.startswith("gcs://stuff")
diff --git a/tests/flytekit/unit/interfaces/data/s3/test_s3_proxy.py b/tests/flytekit/unit/interfaces/data/s3/test_s3_proxy.py
new file mode 100644
index 0000000000..38978aa38e
--- /dev/null
+++ b/tests/flytekit/unit/interfaces/data/s3/test_s3_proxy.py
@@ -0,0 +1,23 @@
+import mock as _mock
+
+from flytekit.interfaces.data.s3.s3proxy import AwsS3Proxy as _AwsS3Proxy
+
+
+def test_property():
+    aws = _AwsS3Proxy("s3://raw-output")
+    assert aws.raw_output_data_prefix_override == "s3://raw-output"
+
+
+@_mock.patch("flytekit.configuration.aws.S3_SHARD_FORMATTER")
+def test_random_path(mock_formatter):
+    mock_formatter.get.return_value = "s3://flyte/{}/"
+
+    # Without raw output data prefix override
+    aws = _AwsS3Proxy()
+    p = str(aws.get_random_path())
+    assert p.startswith("s3://flyte")
+
+    # With override
+    aws = _AwsS3Proxy("s3://raw-output")
+    p = str(aws.get_random_path())
+    assert p.startswith("s3://raw-output")
diff --git a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py
index be903a9c82..532a92d5a8 100644
--- a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py
+++ b/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py
@@ -1,5 +1,3 @@
-from __future__ import absolute_import
-
 import mock
 from k8s.io.api.core.v1 import generated_pb2
 
@@ -62,6 +60,8 @@ def test_dynamic_sidecar_task():
             "{{.input}}",
             "--output-prefix",
             "{{.outputPrefix}}",
+            "--raw-output-data-prefix",
+            "{{.rawOutputDataPrefix}}",
         ]
         assert primary_container["volumeMounts"] == [{"mountPath": "/scratch", "name": "scratch"}]
         assert {"name": "foo", "value": "bar"} in primary_container["env"]
diff --git a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py
index 72989eeffc..503202706f 100644
--- a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py
+++ b/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py
@@ -58,6 +58,8 @@ def test_sidecar_task():
         "{{.input}}",
         "--output-prefix",
         "{{.outputPrefix}}",
+        "--raw-output-data-prefix",
+        "{{.rawOutputDataPrefix}}",
     ]
     assert primary_container["volumeMounts"] == [{"mountPath": "some/where", "name": "volume mount"}]
     assert {"name": "foo", "value": "bar"} in primary_container["env"]
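
Supplementary usage sketch (not part of the patch): the snippet below illustrates the intended effect of the new raw output data prefix override, mirroring the unit tests added above. It assumes flytekit 0.12.1 (this patch) is installed; the s3://my-bucket/raw-outputs prefix is a made-up placeholder.

# Illustrative sketch only - mirrors tests/flytekit/unit/interfaces/data/s3/test_s3_proxy.py above.
# Assumes this patch is installed as flytekit 0.12.1; the bucket/prefix is a hypothetical placeholder.
from flytekit.interfaces.data.s3.s3proxy import AwsS3Proxy

# Without an override, random output paths still come from the configured shard formatter
# (flytekit.configuration.aws.S3_SHARD_FORMATTER), exactly as before this change.
default_proxy = AwsS3Proxy()
assert default_proxy.raw_output_data_prefix_override is None

# With an override (the value Propeller supplies through the new --raw-output-data-prefix
# container argument), random output paths are rooted at the supplied prefix instead.
override_proxy = AwsS3Proxy(raw_output_data_prefix_override="s3://my-bucket/raw-outputs")
assert override_proxy.get_random_path().startswith("s3://my-bucket/raw-outputs")

At container run time the same value arrives via the new --raw-output-data-prefix flag; an unsubstituted {{.rawOutputDataPrefix}} template is treated as None so that older Propeller versions fall back to the configured shard formatter/prefix, as exercised by test_backwards_compatible_replacement above.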