diff --git a/spyt-package/src/main/python/spyt/client.py b/spyt-package/src/main/python/spyt/client.py
index ab0a010d..41b56d7d 100644
--- a/spyt-package/src/main/python/spyt/client.py
+++ b/spyt-package/src/main/python/spyt/client.py
@@ -15,8 +15,9 @@ from .arcadia import checked_extract_spark  # noqa: E402
 from .utils import default_token, default_discovery_dir, get_spark_master, set_conf, \
     SparkDiscovery, parse_memory, format_memory, base_spark_conf, parse_bool, get_spyt_home  # noqa: E402
-from .conf import read_remote_conf, read_global_conf, validate_versions_compatibility, \
-    read_cluster_conf, SELF_VERSION  # noqa: E402
+from .conf import validate_versions_compatibility, \
+    read_cluster_conf, SELF_VERSION, SpytDistributions  # noqa: E402
+from .standalone import get_spyt_distributions
 
 logger = logging.getLogger(__name__)
 
 
@@ -83,6 +84,7 @@ def spark_session(num_executors=None,
                   dynamic_allocation=False,
                   spark_conf_args=None,
                   local_conf_path=Defaults.LOCAL_CONF_PATH,
+                  spyt_distributions: SpytDistributions = None,
                   client=None,
                   spyt_version=None):
     def do_create_inner_cluster_session():
@@ -97,6 +99,7 @@ def do_create_inner_cluster_session():
             dynamic_allocation=dynamic_allocation,
             spark_conf_args=spark_conf_args,
             local_conf_path=local_conf_path,
+            spyt_distributions=spyt_distributions,
             client=client,
             spyt_version=spyt_version,
         )
@@ -246,6 +249,7 @@ def _build_spark_conf(num_executors=None,
                       dynamic_allocation=None,
                       spark_conf_args=None,
                       local_conf_path=None,
+                      spyt_distributions: SpytDistributions = None,
                       client=None,
                       spyt_version=None):
     is_client_mode = os.getenv("IS_SPARK_CLUSTER") is None
@@ -269,8 +273,9 @@ def _build_spark_conf(num_executors=None,
         num_executors, cores_per_executor, executor_memory_per_core,
         driver_memory, dynamic_allocation)
 
-    global_conf = read_global_conf(client=client)
-    remote_conf = read_remote_conf(global_conf, spark_cluster_version, client=client)
+    if spyt_distributions is None:
+        spyt_distributions = get_spyt_distributions(client)
+    remote_conf = spyt_distributions.read_remote_conf(spark_cluster_version)
     set_conf(spark_conf, remote_conf["spark_conf"])
 
     if is_client_mode:
@@ -304,6 +309,7 @@ def connect(num_executors=5,
             dynamic_allocation=True,
             spark_conf_args=None,
             local_conf_path=Defaults.LOCAL_CONF_PATH,
+            spyt_distributions: SpytDistributions = None,
             client=None,
             spyt_version=None):
     conf = _build_spark_conf(
@@ -317,6 +323,7 @@ def connect(num_executors=5,
         dynamic_allocation=dynamic_allocation,
         spark_conf_args=spark_conf_args,
         local_conf_path=local_conf_path,
+        spyt_distributions=spyt_distributions,
         client=client,
         spyt_version=spyt_version,
     )
diff --git a/spyt-package/src/main/python/spyt/conf.py b/spyt-package/src/main/python/spyt/conf.py
index bfd7b8cf..5e7a1266 100644
--- a/spyt-package/src/main/python/spyt/conf.py
+++ b/spyt-package/src/main/python/spyt/conf.py
@@ -3,19 +3,11 @@ from spyt.dependency_utils import require_yt_client
 require_yt_client()
 
-from yt.wrapper import get, YPath, list as yt_list, exists  # noqa: E402
+from yt.wrapper import get, YPath, list as yt_list, exists, YtClient  # noqa: E402
 from yt.wrapper.common import update_inplace  # noqa: E402
 from .version import __scala_version__  # noqa: E402
 from pyspark import __version__ as spark_version  # noqa: E402
 
-SPARK_BASE_PATH = YPath("//home/spark")
-
-CONF_BASE_PATH = SPARK_BASE_PATH.join("conf")
-GLOBAL_CONF_PATH = CONF_BASE_PATH.join("global")
-
-SPYT_BASE_PATH = SPARK_BASE_PATH.join("spyt")
-DISTRIB_BASE_PATH = SPARK_BASE_PATH.join("distrib")
-
 RELEASES_SUBDIR = "releases"
 SNAPSHOTS_SUBDIR = "snapshots"
 
@@ -26,6 +18,69 @@ logger = logging.getLogger(__name__)
 
 
+class SpytDistributions:
+    def __init__(self, client: YtClient, yt_root: str):
+        self.client = client
+        self.yt_root = YPath(yt_root)
+        self.conf_base_path = self.yt_root.join("conf")
+        self.global_conf = client.get(self.conf_base_path.join("global"))
+        self.distrib_base_path = self.yt_root.join("distrib")
+        self.spyt_base_path = self.yt_root.join("spyt")
+
+    def read_remote_conf(self, cluster_version):
+        version_conf = self.client.get(self._get_version_conf_path(cluster_version))
+        version_conf["cluster_version"] = cluster_version
+        return update_inplace(self.global_conf, version_conf)  # TODO(alex-shishkin): Might cause undefined behaviour
+
+    def latest_ytserver_proxy_path(self, cluster_version):
+        if cluster_version:
+            return None
+        symlink_path = self.global_conf.get("ytserver_proxy_path")
+        if symlink_path is None:
+            return None
+        return self.client.get("{}&/@target_path".format(symlink_path))
+
+    def validate_cluster_version(self, spark_cluster_version):
+        if not exists(self._get_version_conf_path(spark_cluster_version), client=self.client):
+            raise RuntimeError("Unknown SPYT cluster version: {}. Available release versions are: {}".format(
+                spark_cluster_version, self.get_available_cluster_versions()
+            ))
+        spyt_minor_version = SpytVersion(SELF_VERSION).get_minor()
+        cluster_minor_version = SpytVersion(spark_cluster_version).get_minor()
+        if spyt_minor_version < cluster_minor_version:
+            logger.warning("You required SPYT version {} which is older than your local ytsaurus-spyt version {}."
+                           "Please update your local ytsaurus-spyt".format(spark_cluster_version, SELF_VERSION))
+
+    def get_available_cluster_versions(self):
+        subdirs = yt_list(self.conf_base_path.join(RELEASES_SUBDIR), client=self.client)
+        return [x for x in subdirs if x != "spark-launch-conf"]
+
+    def latest_compatible_spyt_version(self, version):
+        minor_version = SpytVersion(version).get_minor()
+        spyt_versions = self.get_available_spyt_versions()
+        compatible_spyt_versions = [x for x in spyt_versions if SpytVersion(x).get_minor() == minor_version]
+        if not compatible_spyt_versions:
+            raise RuntimeError(f"No compatible SPYT versions found for specified version {version}")
+        return max(compatible_spyt_versions, key=SpytVersion)
+
+    def get_available_spyt_versions(self):
+        return yt_list(self.spyt_base_path.join(RELEASES_SUBDIR), client=self.client)
+
+    def get_spark_distributive(self):
+        distrib_root = self.distrib_base_path.join(spark_version.replace('.', '/'))
+        distrib_root_contents = yt_list(distrib_root, client=self.client)
+        spark_tgz = [x for x in distrib_root_contents if x.endswith('.tgz')]
+        if len(spark_tgz) == 0:
+            raise RuntimeError(f"Spark {spark_version} tgz distributive doesn't exist "
+                               f"at path {distrib_root} on cluster {self.client.config['proxy']['url']}")
+        return (spark_tgz[0], distrib_root.join(spark_tgz[0]))
+
+    def _get_version_conf_path(self, version):
+        return self.conf_base_path.join(
+            SNAPSHOTS_SUBDIR if "SNAPSHOT" in version or "beta" in version or "dev" in version else RELEASES_SUBDIR
+        ).join(version).join("spark-launch-conf")
+
+
 class SpytVersion:
     def __init__(self, version=None, major=0, minor=0, patch=0):
         if version is not None:
@@ -58,18 +113,6 @@ def __str__(self):
         return f"{self.major}.{self.minor}.{self.patch}"
 
 
-def validate_cluster_version(spark_cluster_version, client=None):
-    if not check_cluster_version_exists(spark_cluster_version, client=client):
-        raise RuntimeError("Unknown SPYT cluster version: {}. Available release versions are: {}".format(
-            spark_cluster_version, get_available_cluster_versions(client=client)
-        ))
-    spyt_minor_version = SpytVersion(SELF_VERSION).get_minor()
-    cluster_minor_version = SpytVersion(spark_cluster_version).get_minor()
-    if spyt_minor_version < cluster_minor_version:
-        logger.warning("You required SPYT version {} which is older than your local ytsaurus-spyt version {}."
-                       "Please update your local ytsaurus-spyt".format(spark_cluster_version, SELF_VERSION))
-
-
 def validate_versions_compatibility(spyt_version, spark_cluster_version):
     spyt_minor_version = SpytVersion(spyt_version).get_minor()
     spark_cluster_minor_version = SpytVersion(spark_cluster_version).get_minor()
@@ -84,15 +127,6 @@ def validate_mtn_config(enablers, network_project, tvm_id, tvm_secret):
         raise RuntimeError("When using MTN, network_project arg must be set.")
 
 
-def latest_compatible_spyt_version(version, client=None):
-    minor_version = SpytVersion(version).get_minor()
-    spyt_versions = get_available_spyt_versions(client)
-    compatible_spyt_versions = [x for x in spyt_versions if SpytVersion(x).get_minor() == minor_version]
-    if not compatible_spyt_versions:
-        raise RuntimeError(f"No compatible SPYT versions found for specified version {version}")
-    return max(compatible_spyt_versions, key=SpytVersion)
-
-
 def python_bin_path(global_conf, version):
     return global_conf["python_cluster_paths"].get(version)
 
@@ -115,26 +149,6 @@ def validate_ssd_config(disk_limit, disk_account):
         raise RuntimeError("Disk account must be provided to use disk limit, please add --worker-disk-account option")
 
 
-def get_available_cluster_versions(client=None):
-    subdirs = yt_list(CONF_BASE_PATH.join(RELEASES_SUBDIR), client=client)
-    return [x for x in subdirs if x != "spark-launch-conf"]
-
-
-def check_cluster_version_exists(cluster_version, client=None):
-    return exists(_get_version_conf_path(cluster_version), client=client)
-
-
-def read_global_conf(client=None):
-    return client.get(GLOBAL_CONF_PATH)
-
-
-def read_remote_conf(global_conf, cluster_version, client=None):
-    version_conf_path = _get_version_conf_path(cluster_version)
-    version_conf = get(version_conf_path, client=client)
-    version_conf["cluster_version"] = cluster_version
-    return update_inplace(global_conf, version_conf)  # TODO(alex-shishkin): Might cause undefined behaviour
-
-
 def read_cluster_conf(path=None, client=None):
     if path is None:
         return {}
@@ -156,41 +170,9 @@ def validate_custom_params(params):
                            "Use argument 'enablers' instead")
 
 
-def get_available_spyt_versions(client=None):
-    return yt_list(SPYT_BASE_PATH.join(RELEASES_SUBDIR), client=client)
-
-
-def latest_ytserver_proxy_path(cluster_version, client=None):
-    if cluster_version:
-        return None
-    global_conf = read_global_conf(client=client)
-    symlink_path = global_conf.get("ytserver_proxy_path")
-    if symlink_path is None:
-        return None
-    return get("{}&/@target_path".format(symlink_path), client=client)
-
-
 def ytserver_proxy_attributes(path, client=None):
     return get("{}/@user_attributes".format(path), client=client)
 
 
-def get_spark_distributive(client):
-    distrib_root = DISTRIB_BASE_PATH.join(spark_version.replace('.', '/'))
-    distrib_root_contents = yt_list(distrib_root, client=client)
-    spark_tgz = [x for x in distrib_root_contents if x.endswith('.tgz')]
-    if len(spark_tgz) == 0:
-        raise RuntimeError(f"Spark {spark_version} tgz distributive doesn't exist "
-                           f"at path {distrib_root} on cluster {client.config['proxy']['url']}")
-    return (spark_tgz[0], distrib_root.join(spark_tgz[0]))
-
-
 def _get_or_else(d, key, default):
     return d.get(key) or default
-
-
-def _version_subdir(version):
-    return SNAPSHOTS_SUBDIR if "SNAPSHOT" in version or "beta" in version or "dev" in version else RELEASES_SUBDIR
-
-
-def _get_version_conf_path(cluster_version):
-    return CONF_BASE_PATH.join(_version_subdir(cluster_version)).join(cluster_version).join("spark-launch-conf")
diff --git a/spyt-package/src/main/python/spyt/spec.py b/spyt-package/src/main/python/spyt/spec.py
index 6d565b71..e6ddf257 100644
--- a/spyt-package/src/main/python/spyt/spec.py
+++ b/spyt-package/src/main/python/spyt/spec.py
@@ -12,7 +12,7 @@ from yt.wrapper.http_helpers import get_token, get_user_name  # noqa: E402
 from yt.wrapper.spec_builders import VanillaSpecBuilder  # noqa: E402
 
-from .conf import ytserver_proxy_attributes, get_spark_distributive  # noqa: E402
+from .conf import ytserver_proxy_attributes, SpytDistributions  # noqa: E402
 from .utils import SparkDiscovery, call_get_proxy_address_url, parse_memory  # noqa: E402
 from .enabler import SpytEnablers  # noqa: E402
 from .version import __version__  # noqa: E402
@@ -43,6 +43,7 @@ class SparkDefaultArguments(object):
     @staticmethod
     def get_params():
         return {
+            "spyt_distributions": { "yt_root": "//home/spark" },
            "operation_spec": {
                "annotations": {
                    "is_spark": True,
@@ -362,15 +363,16 @@ def _script_absolute_path(script):
             .end_task()
 
 
-def build_spark_operation_spec(config: dict, client: YtClient,
+def build_spark_operation_spec(config: dict, spyt_distributions: SpytDistributions,
                                job_types: List[str], common_config: CommonComponentConfig,
                                master_config: MasterConfig = None, worker_config: WorkerConfig = None,
                                hs_config: HistoryServerConfig = None, livy_config: LivyConfig = None):
+    client = spyt_distributions.client
     if job_types == [] or job_types is None:
         job_types = ['master', 'history', 'worker']
 
     spark_home = "./tmpfs" if common_config.enable_tmpfs else "."
-    spark_distributive_tgz, spark_distributive_path = get_spark_distributive(client)
+    spark_distributive_tgz, spark_distributive_path = spyt_distributions.get_spark_distributive()
     extra_java_opts = ["-Dlog4j.loglevel={}".format(common_config.cluster_log_level)]
     if common_config.enablers.enable_preference_ipv6:
diff --git a/spyt-package/src/main/python/spyt/standalone.py b/spyt-package/src/main/python/spyt/standalone.py
index 5b2bb514..fc1e2f05 100644
--- a/spyt-package/src/main/python/spyt/standalone.py
+++ b/spyt-package/src/main/python/spyt/standalone.py
@@ -23,9 +23,8 @@ from yt.wrapper.operation_commands \
     import process_operation_unsuccesful_finish_state as process_operation_unsuccessful_finish_state
 
-from .conf import read_remote_conf, validate_cluster_version, \
-    latest_compatible_spyt_version, update_config_inplace, validate_custom_params, validate_mtn_config, \
-    latest_ytserver_proxy_path, read_global_conf, python_bin_path, \
+from .conf import update_config_inplace, validate_custom_params, validate_mtn_config, \
+    python_bin_path, SpytDistributions, \
     worker_num_limit, validate_worker_num, read_cluster_conf, validate_ssd_config, cuda_toolkit_version  # noqa: E402
 from .utils import get_spark_master, base_spark_conf, SparkDiscovery, SparkCluster, call_get_proxy_address_url, \
     parse_bool, _add_conf  # noqa: E402
@@ -147,6 +146,7 @@ def raw_submit(discovery_path, spark_home, spark_args,
             'No permission for reading cluster, actual permission status is ' + str(permission_status))
 
     discovery = SparkDiscovery(discovery_path=discovery_path)
+    spyt_distributions = get_spyt_distributions(client)
     cluster_conf = read_cluster_conf(str(discovery.conf()), client)
     spark_conf = cluster_conf['spark_conf']
     dedicated_driver_op = parse_bool(spark_conf.get('spark.dedicated_operation_mode'))
@@ -155,7 +155,7 @@ def raw_submit(discovery_path, spark_home, spark_args,
     _add_master(discovery, spark_base_args, rest=True, client=client)
     _add_shs_option(discovery, spark_base_args, client=client)
     _add_base_spark_conf(client, discovery, spark_base_args)
-    _add_python_version(python_version, spark_base_args, client)
+    _add_python_version(python_version, spark_base_args, spyt_distributions)
     _add_dedicated_driver_op_conf(spark_base_args, dedicated_driver_op)
     _add_ipv6_preference(ipv6_preference_enabled, spark_base_args)
     spark_env = _create_spark_env(client, spark_home)
@@ -203,10 +203,9 @@ def _add_dedicated_driver_op_conf(spark_args, dedicated_driver_op):
         }, spark_args)
 
 
-def _add_python_version(python_version, spark_args, client):
+def _add_python_version(python_version, spark_args, spyt_distributions: SpytDistributions):
     if python_version is not None:
-        global_conf = read_global_conf(client=client)
-        python_path = python_bin_path(global_conf, python_version)
+        python_path = python_bin_path(spyt_distributions.global_conf, python_version)
         if python_path:
             _add_conf({
                 "spark.pyspark.python": python_path
@@ -273,9 +272,18 @@ def abort_spark_operations(spark_discovery, client):
         raise error
 
 
-def get_base_cluster_config(global_conf, spark_cluster_version, params, base_discovery_path=None, client=None):
+def get_spyt_distributions(client, params=None) -> SpytDistributions:
+    if params is None:
+        params = {}
+    return SpytDistributions(
+        client=client,
+        **params.get('spyt_distributions', SparkDefaultArguments.get_params()['spyt_distributions']),
+    )
+
+
+def get_base_cluster_config(spyt_distributions: SpytDistributions, spark_cluster_version, params, base_discovery_path=None):
     dynamic_config = SparkDefaultArguments.get_params()
-    update_config_inplace(dynamic_config, read_remote_conf(global_conf, spark_cluster_version, client=client))
+    update_config_inplace(dynamic_config, spyt_distributions.read_remote_conf(spark_cluster_version))
     update_config_inplace(dynamic_config, params)
     if base_discovery_path is not None:
         dynamic_config['spark_conf']['spark.base.discovery.path'] = base_discovery_path
@@ -316,15 +324,16 @@ def start_livy_server(operation_alias=None, discovery_path=None, pool=None, enab
                            "If you want use direct submit, "
                           "please provide the option `--spark-master-address ytsaurus://`")
 
+    spyt_distributions = get_spyt_distributions(client, params)
+
     if spark_cluster_version is None:
-        spark_cluster_version = latest_compatible_spyt_version(__scala_version__, client=client)
+        spark_cluster_version = spyt_distributions.latest_compatible_spyt_version(__scala_version__)
 
-    validate_cluster_version(spark_cluster_version, client=client)
+    spyt_distributions.validate_cluster_version(spark_cluster_version)
     validate_custom_params(params)
     validate_mtn_config(enablers, network_project, tvm_id, tvm_secret)
 
-    global_conf = read_global_conf(client=client)
-    dynamic_config = get_base_cluster_config(global_conf, spark_cluster_version, params, discovery_path, client)
+    dynamic_config = get_base_cluster_config(spyt_distributions, spark_cluster_version, params, discovery_path)
     enablers.apply_config(dynamic_config)
 
     spark_discovery = None
@@ -341,7 +350,7 @@ def start_livy_server(operation_alias=None, discovery_path=None, pool=None, enab
     livy_config = LivyConfig(
         livy_driver_cores, livy_driver_memory, livy_max_sessions, spark_master_address, master_group_id
     )
-    livy_builder = build_spark_operation_spec(config=dynamic_config, client=client, job_types=['livy'],
+    livy_builder = build_spark_operation_spec(config=dynamic_config, spyt_distributions=spyt_distributions, job_types=['livy'],
                                               common_config=common_config, livy_config=livy_config)
     address_path = spark_discovery.livy() if spark_discovery is not None else None
     return run_operation_wrapper(livy_builder, address_path, client)
@@ -360,16 +369,17 @@ def start_history_server(operation_alias=None, discovery_path=None, pool=None, e
     spark_discovery = SparkDiscovery(discovery_path=discovery_path)
 
-    global_conf = read_global_conf(client=client)
+    spyt_distributions = get_spyt_distributions(client, params)
     if spark_cluster_version is None:
-        spark_cluster_version = latest_compatible_spyt_version(__scala_version__, client=client)
+        spark_cluster_version = spyt_distributions.latest_compatible_spyt_version(__scala_version__)
 
-    validate_cluster_version(spark_cluster_version, client=client)
+    spyt_distributions.validate_cluster_version(spark_cluster_version)
     validate_custom_params(params)
     validate_mtn_config(enablers, network_project, tvm_id, tvm_secret)
 
-    dynamic_config = get_base_cluster_config(global_conf, spark_cluster_version, params,
-                                             spark_discovery.base_discovery_path, client)
+    dynamic_config = get_base_cluster_config(
+        spyt_distributions, spark_cluster_version, params, spark_discovery.base_discovery_path
+    )
 
     enablers.apply_config(dynamic_config)
 
@@ -384,7 +394,7 @@ def start_history_server(operation_alias=None, discovery_path=None, pool=None, e
         history_server_memory_limit, history_server_cpu_limit, history_server_memory_overhead,
         shs_location, advanced_event_log
     )
-    hs_builder = build_spark_operation_spec(config=dynamic_config, client=client, job_types=['history'],
+    hs_builder = build_spark_operation_spec(config=dynamic_config, spyt_distributions=spyt_distributions, job_types=['history'],
                                             common_config=common_config, hs_config=hs_config)
 
     return run_operation_wrapper(hs_builder, spark_discovery.shs(), client)
@@ -511,10 +521,11 @@ def start_spark_cluster(worker_cores, worker_memory, worker_num, worker_cores_ov
         else:
             raise RuntimeError("This spark cluster is started already, use --abort-existing for auto restarting")
 
-    ytserver_proxy_path = latest_ytserver_proxy_path(spark_cluster_version, client=client)
-    global_conf = read_global_conf(client=client)
+    spyt_distributions = get_spyt_distributions(client, params)
+    ytserver_proxy_path = spyt_distributions.latest_ytserver_proxy_path(spark_cluster_version)
+    global_conf = spyt_distributions.global_conf
     if spark_cluster_version is None:
-        spark_cluster_version = latest_compatible_spyt_version(__scala_version__, client=client)
+        spark_cluster_version = spyt_distributions.latest_compatible_spyt_version(__scala_version__)
     logger.info(f"{spark_cluster_version} cluster version will be launched")
 
     if enable_history_server and not advanced_event_log:
@@ -536,14 +547,15 @@ def start_spark_cluster(worker_cores, worker_memory, worker_num, worker_cores_ov
     else:
         logger.info("Launcher files will be placed to node disk with no guarantees on free space")
 
-    validate_cluster_version(spark_cluster_version, client=client)
+    spyt_distributions.validate_cluster_version(spark_cluster_version)
     validate_custom_params(params)
     validate_mtn_config(enablers, network_project, tvm_id, tvm_secret)
     validate_worker_num(worker_res.num, worker_num_limit(global_conf))
     validate_ssd_config(worker_disk_limit, worker_disk_account)
 
-    dynamic_config = get_base_cluster_config(global_conf, spark_cluster_version, params,
-                                             spark_discovery.base_discovery_path, client)
+    dynamic_config = get_base_cluster_config(
+        spyt_distributions, spark_cluster_version, params, spark_discovery.base_discovery_path
+    )
     if ytserver_proxy_path:
         dynamic_config["ytserver_proxy_path"] = ytserver_proxy_path
     dynamic_config['spark_conf']['spark.dedicated_operation_mode'] = dedicated_operation_mode
@@ -587,16 +599,16 @@ def start_spark_cluster(worker_cores, worker_memory, worker_num, worker_cores_ov
         advanced_event_log
     )
     livy_config = LivyConfig(livy_driver_cores, livy_driver_memory, livy_max_sessions)
-    args = {
-        'config': dynamic_config,
-        'client': client,
-        'job_types': job_types,
-        'common_config': common_config,
-        'master_config': master_config,
-        'worker_config': worker_config,
-        'hs_config': hs_config,
-        'livy_config': livy_config,
-    }
+    args = dict(
+        config=dynamic_config,
+        spyt_distributions=spyt_distributions,
+        job_types=job_types,
+        common_config=common_config,
+        master_config=master_config,
+        worker_config=worker_config,
+        hs_config=hs_config,
+        livy_config=livy_config,
+    )
     master_args = args.copy()
     master_builder = build_spark_operation_spec(**master_args)
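
Usage sketch for the new spyt_distributions plumbing introduced by this patch: a minimal example assuming a reachable YTsaurus cluster. The proxy address "my-cluster.example.com" and the root "//home/my_team/spark" are placeholders; SpytDistributions, spark_session(spyt_distributions=...), and get_spyt_distributions come from the changed modules, while the Spark action at the end is illustrative only.

    from yt.wrapper import YtClient

    from spyt.client import spark_session
    from spyt.conf import SpytDistributions
    from spyt.standalone import get_spyt_distributions

    yt_client = YtClient(proxy="my-cluster.example.com")  # placeholder proxy

    # Point SPYT at a non-default distribution root instead of //home/spark.
    # The constructor reads <yt_root>/conf/global eagerly, so it needs a live cluster.
    distributions = SpytDistributions(client=yt_client, yt_root="//home/my_team/spark")

    # Client side: the new keyword argument is threaded through connect() into
    # _build_spark_conf(), which now calls spyt_distributions.read_remote_conf().
    with spark_session(spyt_distributions=distributions, client=yt_client) as spark:
        spark.range(3).count()  # placeholder action

    # Cluster side: the same root travels via `params`; get_spyt_distributions()
    # falls back to SparkDefaultArguments.get_params()["spyt_distributions"],
    # i.e. {"yt_root": "//home/spark"}, when the key is absent.
    same_distributions = get_spyt_distributions(
        yt_client, params={"spyt_distributions": {"yt_root": "//home/my_team/spark"}}
    )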