From 1a4844e49120815bd5420b9ef2e5cd19cf5cb4f9 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 6 Dec 2021 21:30:08 +0800 Subject: [PATCH 1/3] Fetch pickle value from flytekit remote Signed-off-by: Kevin Su --- flytekit/configuration/aws.py | 6 ++-- flytekit/remote/remote.py | 8 ++++- flytekit/types/file/file.py | 7 +++- flytekit/types/pickle/pickle.py | 6 ++++ .../tests/test_persist.py | 8 ++--- .../mock_flyte_repo/workflows/basic/pickle.py | 34 +++++++++++++++++++ .../integration/remote/test_remote.py | 9 +++++ tests/flytekit/unit/core/test_type_engine.py | 27 +++++++++++++++ 8 files changed, 96 insertions(+), 9 deletions(-) create mode 100644 tests/flytekit/integration/remote/mock_flyte_repo/workflows/basic/pickle.py diff --git a/flytekit/configuration/aws.py b/flytekit/configuration/aws.py index a6af62bf6f..785ee961de 100644 --- a/flytekit/configuration/aws.py +++ b/flytekit/configuration/aws.py @@ -4,11 +4,11 @@ S3_SHARD_STRING_LENGTH = _config_common.FlyteIntegerConfigurationEntry("aws", "s3_shard_string_length", default=2) -S3_ENDPOINT = _config_common.FlyteStringConfigurationEntry("aws", "endpoint", default=None) +S3_ENDPOINT = _config_common.FlyteStringConfigurationEntry("aws", "endpoint", default="http://localhost:30084") -S3_ACCESS_KEY_ID = _config_common.FlyteStringConfigurationEntry("aws", "access_key_id", default=None) +S3_ACCESS_KEY_ID = _config_common.FlyteStringConfigurationEntry("aws", "access_key_id", default="minio") -S3_SECRET_ACCESS_KEY = _config_common.FlyteStringConfigurationEntry("aws", "secret_access_key", default=None) +S3_SECRET_ACCESS_KEY = _config_common.FlyteStringConfigurationEntry("aws", "secret_access_key", default="miniostorage") S3_ACCESS_KEY_ID_ENV_NAME = "AWS_ACCESS_KEY_ID" diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 117abf27e3..2c3e8ab41d 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -237,7 +237,13 @@ def __init__( # Save the file access object locally, but also make it available for use from the context. FlyteContextManager.with_context(FlyteContextManager.current_context().with_file_access(file_access).build()) - self._file_access = file_access + raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() or os.path.join( + sdk_config.LOCAL_SANDBOX.get(), "control_plane_raw" + ) + self._file_access = file_access or FileAccessProvider( + local_sandbox_dir=os.path.join(sdk_config.LOCAL_SANDBOX.get(), "control_plane_metadata"), + raw_output_prefix=raw_output_data_prefix, + ) # TODO: Reconsider whether we want this. Probably best to not cache. self._serialized_entity_cache = OrderedDict() diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index eb64e5943d..9bff3a5dec 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -10,6 +10,7 @@ from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType +from flytekit.types.pickle.pickle import FlytePickleTransformer def noop(): @@ -348,7 +349,11 @@ def _downloader(): return ff def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]: - if literal_type.blob is not None and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format != FlytePickleTransformer.PYTHON_PICKLE_FORMAT + ): return FlyteFile.__class_getitem__(literal_type.blob.format) raise ValueError(f"Transformer {self} cannot reverse {literal_type}") diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 9219d3a8b4..57ead7ea09 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -78,6 +78,12 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp ctx.file_access.put_data(uri, remote_path, is_multipart=False) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickle[typing.Any]]: + if literal_type.blob is not None and literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT: + return FlytePickle + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( diff --git a/plugins/flytekit-data-fsspec/tests/test_persist.py b/plugins/flytekit-data-fsspec/tests/test_persist.py index d2ac50a7f8..ce3a4d7e41 100644 --- a/plugins/flytekit-data-fsspec/tests/test_persist.py +++ b/plugins/flytekit-data-fsspec/tests/test_persist.py @@ -10,15 +10,15 @@ def test_s3_setup_args(): kwargs = s3_setup_args() - assert kwargs == {} + assert kwargs == {"client_kwargs": {"endpoint_url": "http://localhost:30084"}} - with aws.S3_ENDPOINT.get_patcher("http://localhost:30084"): + with aws.S3_ENDPOINT.get_patcher("http://flyte:30084"): kwargs = s3_setup_args() - assert kwargs == {"client_kwargs": {"endpoint_url": "http://localhost:30084"}} + assert kwargs == {"client_kwargs": {"endpoint_url": "http://flyte:30084"}} with aws.S3_ACCESS_KEY_ID.get_patcher("access"): kwargs = s3_setup_args() - assert kwargs == {} + assert kwargs == {"client_kwargs": {"endpoint_url": "http://localhost:30084"}} assert os.environ[aws.S3_ACCESS_KEY_ID_ENV_NAME] == "access" diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/basic/pickle.py b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/basic/pickle.py new file mode 100644 index 0000000000..3a5b6a3b2a --- /dev/null +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/basic/pickle.py @@ -0,0 +1,34 @@ +import typing + +from flytekit import task, workflow + + +class Foo(object): + def __init__(self, number: int): + self.number = number + + +@task +def t1(a: int) -> Foo: + return Foo(number=a) + + +@task +def t2(a: Foo) -> typing.List[Foo]: + return [a, a] + + +@task +def t3(a: typing.List[Foo]) -> typing.Dict[str, Foo]: + return {"hello": a[0]} + + +@workflow +def wf(a: int) -> typing.Dict[str, Foo]: + o1 = t1(a=a) + o2 = t2(a=o1) + return t3(a=o2) + + +if __name__ == "__main__": + print(f"Running wf(a=3) {wf(a=3)}") diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index a45c037279..689187784e 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -286,6 +286,15 @@ def test_execute_joblib_workflow(flyteclient, flyte_workflows_register, flyte_re assert output_obj == input_obj +def test_execute_pickle_workflow(flyteclient, flyte_workflows_register, flyte_remote_env): + remote = FlyteRemote.from_config(PROJECT, "development") + flyte_workflow = remote.fetch_workflow(name="workflows.basic.pickle.wf", version=f"v{VERSION}") + input_obj = 3 + execution = remote.execute(flyte_workflow, {"a": input_obj}, wait=True) + output = execution.outputs["o0"] + assert output.number == 3 + + def test_execute_with_default_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env): from mock_flyte_repo.workflows.basic.subworkflows import parent_wf diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 50a77c43d9..9f88a80fb2 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -356,6 +356,14 @@ def test_guessing_basic(): pt = TypeEngine.guess_python_type(lt) assert pt is None + lt = model_types.LiteralType( + blob=BlobType( + format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE + ) + ) + pt = TypeEngine.guess_python_type(lt) + assert pt is FlytePickle + def test_guessing_containers(): b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN) @@ -552,6 +560,25 @@ def test_enum_type(): TypeEngine.to_literal_type(UnsupportedEnumValues) +def test_pickle_type(): + class Foo(object): + def __init__(self, number: int): + self.number = number + + lt = TypeEngine.to_literal_type(FlytePickle) + assert lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT + assert lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + + ctx = FlyteContextManager.current_context() + lv = TypeEngine.to_literal(ctx, Foo(1), FlytePickle, lt) + assert "/tmp/flyte/" in lv.scalar.blob.uri + + transformer = FlytePickleTransformer() + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert Foo(1).number == pv.number + + def test_enum_in_dataclass(): @dataclass_json @dataclass From 972729fba66011068c592594f47d6c8ce9b731de Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 6 Dec 2021 23:54:32 +0800 Subject: [PATCH 2/3] Fix tests Signed-off-by: Kevin Su --- .../mock_flyte_repo/workflows/basic/pickle.py | 34 ------------------- .../integration/remote/test_remote.py | 9 ----- 2 files changed, 43 deletions(-) delete mode 100644 tests/flytekit/integration/remote/mock_flyte_repo/workflows/basic/pickle.py diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/basic/pickle.py b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/basic/pickle.py deleted file mode 100644 index 3a5b6a3b2a..0000000000 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/basic/pickle.py +++ /dev/null @@ -1,34 +0,0 @@ -import typing - -from flytekit import task, workflow - - -class Foo(object): - def __init__(self, number: int): - self.number = number - - -@task -def t1(a: int) -> Foo: - return Foo(number=a) - - -@task -def t2(a: Foo) -> typing.List[Foo]: - return [a, a] - - -@task -def t3(a: typing.List[Foo]) -> typing.Dict[str, Foo]: - return {"hello": a[0]} - - -@workflow -def wf(a: int) -> typing.Dict[str, Foo]: - o1 = t1(a=a) - o2 = t2(a=o1) - return t3(a=o2) - - -if __name__ == "__main__": - print(f"Running wf(a=3) {wf(a=3)}") diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 689187784e..a45c037279 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -286,15 +286,6 @@ def test_execute_joblib_workflow(flyteclient, flyte_workflows_register, flyte_re assert output_obj == input_obj -def test_execute_pickle_workflow(flyteclient, flyte_workflows_register, flyte_remote_env): - remote = FlyteRemote.from_config(PROJECT, "development") - flyte_workflow = remote.fetch_workflow(name="workflows.basic.pickle.wf", version=f"v{VERSION}") - input_obj = 3 - execution = remote.execute(flyte_workflow, {"a": input_obj}, wait=True) - output = execution.outputs["o0"] - assert output.number == 3 - - def test_execute_with_default_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env): from mock_flyte_repo.workflows.basic.subworkflows import parent_wf From f7b9450a8bee91f4626fdbfff94dd0ed21a9b6e0 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 7 Dec 2021 02:23:22 +0800 Subject: [PATCH 3/3] Remove default value Signed-off-by: Kevin Su --- flytekit/configuration/aws.py | 6 +++--- flytekit/remote/remote.py | 6 ++++-- flytekit/types/pickle/pickle.py | 6 +++++- plugins/flytekit-data-fsspec/tests/test_persist.py | 8 ++++---- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/flytekit/configuration/aws.py b/flytekit/configuration/aws.py index 785ee961de..a6af62bf6f 100644 --- a/flytekit/configuration/aws.py +++ b/flytekit/configuration/aws.py @@ -4,11 +4,11 @@ S3_SHARD_STRING_LENGTH = _config_common.FlyteIntegerConfigurationEntry("aws", "s3_shard_string_length", default=2) -S3_ENDPOINT = _config_common.FlyteStringConfigurationEntry("aws", "endpoint", default="http://localhost:30084") +S3_ENDPOINT = _config_common.FlyteStringConfigurationEntry("aws", "endpoint", default=None) -S3_ACCESS_KEY_ID = _config_common.FlyteStringConfigurationEntry("aws", "access_key_id", default="minio") +S3_ACCESS_KEY_ID = _config_common.FlyteStringConfigurationEntry("aws", "access_key_id", default=None) -S3_SECRET_ACCESS_KEY = _config_common.FlyteStringConfigurationEntry("aws", "secret_access_key", default="miniostorage") +S3_SECRET_ACCESS_KEY = _config_common.FlyteStringConfigurationEntry("aws", "secret_access_key", default=None) S3_ACCESS_KEY_ID_ENV_NAME = "AWS_ACCESS_KEY_ID" diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 2c3e8ab41d..60ae7bbf3e 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -235,8 +235,6 @@ def __init__( # Not exposing this as a property for now. self._entrypoint_settings = entrypoint_settings - # Save the file access object locally, but also make it available for use from the context. - FlyteContextManager.with_context(FlyteContextManager.current_context().with_file_access(file_access).build()) raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() or os.path.join( sdk_config.LOCAL_SANDBOX.get(), "control_plane_raw" ) @@ -244,6 +242,10 @@ def __init__( local_sandbox_dir=os.path.join(sdk_config.LOCAL_SANDBOX.get(), "control_plane_metadata"), raw_output_prefix=raw_output_data_prefix, ) + # Save the file access object locally, but also make it available for use from the context. + FlyteContextManager.with_context( + FlyteContextManager.current_context().with_file_access(self._file_access).build() + ) # TODO: Reconsider whether we want this. Probably best to not cache. self._serialized_entity_cache = OrderedDict() diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 57ead7ea09..3472dec7e6 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -79,7 +79,11 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickle[typing.Any]]: - if literal_type.blob is not None and literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT + ): return FlytePickle raise ValueError(f"Transformer {self} cannot reverse {literal_type}") diff --git a/plugins/flytekit-data-fsspec/tests/test_persist.py b/plugins/flytekit-data-fsspec/tests/test_persist.py index ce3a4d7e41..d2ac50a7f8 100644 --- a/plugins/flytekit-data-fsspec/tests/test_persist.py +++ b/plugins/flytekit-data-fsspec/tests/test_persist.py @@ -10,15 +10,15 @@ def test_s3_setup_args(): kwargs = s3_setup_args() - assert kwargs == {"client_kwargs": {"endpoint_url": "http://localhost:30084"}} + assert kwargs == {} - with aws.S3_ENDPOINT.get_patcher("http://flyte:30084"): + with aws.S3_ENDPOINT.get_patcher("http://localhost:30084"): kwargs = s3_setup_args() - assert kwargs == {"client_kwargs": {"endpoint_url": "http://flyte:30084"}} + assert kwargs == {"client_kwargs": {"endpoint_url": "http://localhost:30084"}} with aws.S3_ACCESS_KEY_ID.get_patcher("access"): kwargs = s3_setup_args() - assert kwargs == {"client_kwargs": {"endpoint_url": "http://localhost:30084"}} + assert kwargs == {} assert os.environ[aws.S3_ACCESS_KEY_ID_ENV_NAME] == "access"