From 35bb5561537e45a32c6ae328684fe09cea318786 Mon Sep 17 00:00:00 2001 From: peridotml <106936600+peridotml@users.noreply.github.com> Date: Mon, 8 May 2023 13:03:39 -0700 Subject: [PATCH] fix PipelineModel transformer issue 3648 (#1623) Signed-off-by: esad --- .../flytekitplugins/spark/pyspark_transformers.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py index e48778ad70..4afb257f9d 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py @@ -1,4 +1,3 @@ -import pathlib from typing import Type from pyspark.ml import PipelineModel @@ -24,22 +23,17 @@ def to_literal( python_type: Type[PipelineModel], expected: LiteralType, ) -> Literal: - local_path = ctx.file_access.get_random_local_path() - pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - python_val.save(local_path) - + # Must write to remote directory remote_dir = ctx.file_access.get_random_remote_directory() - ctx.file_access.upload_directory(local_path, remote_dir) + python_val.write().overwrite().save(remote_dir) return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PipelineModel] ) -> PipelineModel: - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.download_directory(lv.scalar.blob.uri, local_dir) - - return PipelineModel.load(local_dir) + remote_dir = lv.scalar.blob.uri + return PipelineModel.load(remote_dir) TypeEngine.register(PySparkPipelineModelTransformer())