diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 0a4ec05c23..44b573d66b 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -15,7 +15,6 @@ from pyspark import SparkConf from pyspark.sql import SparkSession from pytz import utc -from sdk.python.feast.infra.utils import aws_utils from feast import FeatureView, OnDemandFeatureView from feast.data_source import DataSource @@ -32,6 +31,7 @@ RetrievalMetadata, ) from feast.infra.registry.registry import Registry +from feast.infra.utils import aws_utils from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage from feast.type_map import spark_schema_to_np_dtypes @@ -368,7 +368,7 @@ def to_remote_storage(self) -> List[str]: sdf: pyspark.sql.DataFrame = self.to_spark_df() - if self._config.offline_store.staging_location.startswith("file://"): + if self._config.offline_store.staging_location.startswith("/"): local_file_staging_location = os.path.abspath( self._config.offline_store.staging_location ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py index 64a2a01cee..71c07b20c2 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py @@ -58,8 +58,8 @@ def create_offline_store_config(self): self.spark_offline_store_config = SparkOfflineStoreConfig() self.spark_offline_store_config.type = "spark" self.spark_offline_store_config.spark_conf = self.spark_conf - self.spark_offline_store_config.staging_location = "file://" + str( - tempfile.TemporaryDirectory() + self.spark_offline_store_config.staging_location = ( + tempfile.TemporaryDirectory().name ) self.spark_offline_store_config.region = "eu-west-1" return self.spark_offline_store_config