diff --git a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py
index 7e0c0b77e7..e769540aea 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py
@@ -20,4 +20,4 @@
 from .pyspark_transformers import PySparkPipelineModelTransformer
 from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer  # noqa
 from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler
-from .task import Spark, new_spark_session  # noqa
+from .task import Databricks, Spark, new_spark_session  # noqa
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py
index 180a28bb87..7b32e9f28b 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/task.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -118,7 +118,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
             spark_type=SparkType.PYTHON,
         )
         if isinstance(self.task_config, Databricks):
-            cfg = typing.cast(self.task_config, Databricks)
+            cfg = typing.cast(Databricks, self.task_config)
             job._databricks_conf = cfg.databricks_conf
             job._databricks_token = cfg.databricks_token
             job._databricks_instance = cfg.databricks_instance
@@ -150,3 +150,4 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
 
 # Inject the Spark plugin into flytekits dynamic plugin loading system
 TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask)
+TaskPlugins.register_pythontask_plugin(Databricks, PysparkFunctionTask)
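For context, a minimal sketch of how a task author could use the newly exported Databricks config once this lands. The diff only guarantees that Databricks carries databricks_conf, databricks_token, and databricks_instance and is handled by PysparkFunctionTask; the spark_conf field, the cluster settings, and all values below are illustrative assumptions, not part of this change:

from flytekit import task
from flytekitplugins.spark import Databricks


# Hypothetical example task: the Databricks job payload, instance URL,
# and token are placeholder values, not taken from this diff.
@task(
    task_config=Databricks(
        spark_conf={"spark.driver.memory": "1000M"},  # assumed: inherited from the Spark config
        databricks_conf={
            "run_name": "flytekit databricks example",
            "new_cluster": {
                "spark_version": "11.0.x-scala2.12",
                "node_type_id": "r3.xlarge",
                "num_workers": 4,
            },
            "timeout_seconds": 3600,
        },
        databricks_instance="<DATABRICKS_INSTANCE>.cloud.databricks.com",  # placeholder
        databricks_token="<DATABRICKS_TOKEN>",  # placeholder; prefer a secret store
    )
)
def hello_spark(i: int) -> float:
    # Executes as an ordinary PySpark task body; per the registration added
    # in this diff, Databricks tasks are run by the same PysparkFunctionTask,
    # with the extra Databricks fields injected during serialization.
    return 3.14 + i

Registering Databricks against the existing PysparkFunctionTask (rather than a new task type) keeps one execution path for both configs; get_custom then branches on isinstance(self.task_config, Databricks) to attach the Databricks-specific fields, which is why the corrected typing.cast(Databricks, self.task_config) argument order matters there.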