diff --git a/spyt-package/src/main/python/spyt/client.py b/spyt-package/src/main/python/spyt/client.py index 2cb0efa8..41b56d7d 100644 --- a/spyt-package/src/main/python/spyt/client.py +++ b/spyt-package/src/main/python/spyt/client.py @@ -16,7 +16,7 @@ from .utils import default_token, default_discovery_dir, get_spark_master, set_conf, \ SparkDiscovery, parse_memory, format_memory, base_spark_conf, parse_bool, get_spyt_home # noqa: E402 from .conf import validate_versions_compatibility, \ - read_cluster_conf, SELF_VERSION # noqa: E402 + read_cluster_conf, SELF_VERSION, SpytDistributions # noqa: E402 from .standalone import get_spyt_distributions @@ -84,6 +84,7 @@ def spark_session(num_executors=None, dynamic_allocation=False, spark_conf_args=None, local_conf_path=Defaults.LOCAL_CONF_PATH, + spyt_distributions: SpytDistributions = None, client=None, spyt_version=None): def do_create_inner_cluster_session(): @@ -98,6 +99,7 @@ def do_create_inner_cluster_session(): dynamic_allocation=dynamic_allocation, spark_conf_args=spark_conf_args, local_conf_path=local_conf_path, + spyt_distributions=spyt_distributions, client=client, spyt_version=spyt_version, ) @@ -247,6 +249,7 @@ def _build_spark_conf(num_executors=None, dynamic_allocation=None, spark_conf_args=None, local_conf_path=None, + spyt_distributions: SpytDistributions = None, client=None, spyt_version=None): is_client_mode = os.getenv("IS_SPARK_CLUSTER") is None @@ -270,7 +273,8 @@ def _build_spark_conf(num_executors=None, num_executors, cores_per_executor, executor_memory_per_core, driver_memory, dynamic_allocation) - spyt_distributions = get_spyt_distributions(client) + if spyt_distributions is None: + spyt_distributions = get_spyt_distributions(client) remote_conf = spyt_distributions.read_remote_conf(spark_cluster_version) set_conf(spark_conf, remote_conf["spark_conf"]) @@ -305,6 +309,7 @@ def connect(num_executors=5, dynamic_allocation=True, spark_conf_args=None, local_conf_path=Defaults.LOCAL_CONF_PATH, + spyt_distributions: SpytDistributions = None, client=None, spyt_version=None): conf = _build_spark_conf( @@ -318,6 +323,7 @@ def connect(num_executors=5, dynamic_allocation=dynamic_allocation, spark_conf_args=spark_conf_args, local_conf_path=local_conf_path, + spyt_distributions=spyt_distributions, client=client, spyt_version=spyt_version, )