Commit 97ce5b9: SpytDistributions

faucct committed Oct 3, 2024
1 parent c5d489b commit 97ce5b9
Showing 4 changed files with 123 additions and 124 deletions.
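
At a glance: the module-level config helpers in spyt/conf.py, which were hard-wired to //home/spark, move onto a new SpytDistributions class whose Cypress root is a parameter. A minimal sketch of the resulting API, assuming a YtClient pointed at your cluster (the proxy URL and version string below are hypothetical):

    from yt.wrapper import YtClient
    from spyt.conf import SpytDistributions

    client = YtClient(proxy="my-cluster.example.com")  # hypothetical proxy address
    spyt = SpytDistributions(client, yt_root="//home/spark")  # default root, see spec.py below
    remote_conf = spyt.read_remote_conf("2.4.0")  # hypothetical cluster version
    tgz_name, tgz_path = spyt.get_spark_distributive()
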
7 changes: 4 additions & 3 deletions spyt-package/src/main/python/spyt/client.py
@@ -15,8 +15,9 @@
from .arcadia import checked_extract_spark # noqa: E402
from .utils import default_token, default_discovery_dir, get_spark_master, set_conf, \
    SparkDiscovery, parse_memory, format_memory, base_spark_conf, parse_bool, get_spyt_home  # noqa: E402
-from .conf import read_remote_conf, read_global_conf, validate_versions_compatibility, \
+from .conf import validate_versions_compatibility, \
    read_cluster_conf, SELF_VERSION  # noqa: E402
+from .standalone import get_spyt_distributions


logger = logging.getLogger(__name__)
@@ -269,8 +270,8 @@ def _build_spark_conf(num_executors=None,
num_executors, cores_per_executor, executor_memory_per_core,
driver_memory, dynamic_allocation)

-    global_conf = read_global_conf(client=client)
-    remote_conf = read_remote_conf(global_conf, spark_cluster_version, client=client)
+    spyt_distributions = get_spyt_distributions(client)
+    remote_conf = spyt_distributions.read_remote_conf(spark_cluster_version)
    set_conf(spark_conf, remote_conf["spark_conf"])

    if is_client_mode:
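client.py now gets its SpytDistributions from get_spyt_distributions(client) in .standalone; that file's diff is the one that did not render below, so the following is only a hedged guess at the factory's shape, assuming it reads the new "spyt_distributions" parameter (see spec.py below) and falls back to the old default root:

    # Hypothetical sketch -- the actual standalone.py change is not shown in this diff.
    from spyt.conf import SpytDistributions

    def get_spyt_distributions(client, params=None):
        # Assumed to honor the "spyt_distributions" block from SparkDefaultArguments.get_params().
        yt_root = ((params or {}).get("spyt_distributions") or {}).get("yt_root", "//home/spark")
        return SpytDistributions(client=client, yt_root=yt_root)
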
148 changes: 66 additions & 82 deletions spyt-package/src/main/python/spyt/conf.py
@@ -3,19 +3,11 @@
from spyt.dependency_utils import require_yt_client
require_yt_client()

-from yt.wrapper import get, YPath, list as yt_list, exists  # noqa: E402
+from yt.wrapper import get, YPath, list as yt_list, exists, YtClient  # noqa: E402
from yt.wrapper.common import update_inplace # noqa: E402
from .version import __scala_version__ # noqa: E402
from pyspark import __version__ as spark_version # noqa: E402

-SPARK_BASE_PATH = YPath("//home/spark")
-
-CONF_BASE_PATH = SPARK_BASE_PATH.join("conf")
-GLOBAL_CONF_PATH = CONF_BASE_PATH.join("global")
-
-SPYT_BASE_PATH = SPARK_BASE_PATH.join("spyt")
-DISTRIB_BASE_PATH = SPARK_BASE_PATH.join("distrib")
-
RELEASES_SUBDIR = "releases"
SNAPSHOTS_SUBDIR = "snapshots"

@@ -26,6 +18,71 @@
logger = logging.getLogger(__name__)


+class SpytDistributions:
+    def __init__(self, client: YtClient, yt_root: str):
+        self.client = client
+        self.yt_root = YPath(yt_root)
+        self.conf_base_path = self.yt_root.join("conf")
+        self.global_conf = client.get(self.conf_base_path.join("global"))
+        self.distrib_base_path = self.yt_root.join("distrib")
+        self.spyt_base_path = self.yt_root.join("spyt")
+
+    def read_remote_conf(self, cluster_version):
+        version_conf = self.client.get(self._get_version_conf_path(cluster_version))
+        version_conf["cluster_version"] = cluster_version
+        return update_inplace(self.global_conf, version_conf)  # TODO(alex-shishkin): Might cause undefined behaviour
+
+    def latest_ytserver_proxy_path(self, cluster_version):
+        if cluster_version:
+            return None
+        global_conf = self.global_conf
+        symlink_path = global_conf.get("ytserver_proxy_path")
+        if symlink_path is None:
+            return None
+        return get("{}&/@target_path".format(symlink_path), client=self.client)
+
+    def validate_cluster_version(self, spark_cluster_version):
+        if not exists(self._get_version_conf_path(spark_cluster_version), client=self.client):
+            raise RuntimeError("Unknown SPYT cluster version: {}. Available release versions are: {}".format(
+                spark_cluster_version, self.get_available_cluster_versions()
+            ))
+        spyt_minor_version = SpytVersion(SELF_VERSION).get_minor()
+        cluster_minor_version = SpytVersion(spark_cluster_version).get_minor()
+        if spyt_minor_version < cluster_minor_version:
+            logger.warning("You required SPYT version {} which is older than your local ytsaurus-spyt version {}. "
+                           "Please update your local ytsaurus-spyt".format(spark_cluster_version, SELF_VERSION))
+
+    def get_available_cluster_versions(self):
+        subdirs = yt_list(self.conf_base_path.join(RELEASES_SUBDIR), client=self.client)
+        return [x for x in subdirs if x != "spark-launch-conf"]
+
+    def latest_compatible_spyt_version(self, version):
+        minor_version = SpytVersion(version).get_minor()
+        spyt_versions = self.get_available_spyt_versions()
+        compatible_spyt_versions = [x for x in spyt_versions if SpytVersion(x).get_minor() == minor_version]
+        if not compatible_spyt_versions:
+            raise RuntimeError(f"No compatible SPYT versions found for specified version {version}")
+        return max(compatible_spyt_versions, key=SpytVersion)
+
+    def get_available_spyt_versions(self):
+        return yt_list(self.spyt_base_path.join(RELEASES_SUBDIR), client=self.client)
+
+    def get_spark_distributive(self):
+        distrib_root = self.distrib_base_path.join(spark_version.replace('.', '/'))
+        distrib_root_contents = yt_list(distrib_root, client=self.client)
+        spark_tgz = [x for x in distrib_root_contents if x.endswith('.tgz')]
+        if len(spark_tgz) == 0:
+            raise RuntimeError(f"Spark {spark_version} tgz distributive doesn't exist "
+                               f"at path {distrib_root} on cluster {self.client.config['proxy']['url']}")
+        return (spark_tgz[0], distrib_root.join(spark_tgz[0]))
+
+    def _get_version_conf_path(self, cluster_version):
+        return self.conf_base_path.join(self._version_subdir(cluster_version)).join(cluster_version).join("spark-launch-conf")
+
+    def _version_subdir(self, version):
+        return SNAPSHOTS_SUBDIR if "SNAPSHOT" in version or "beta" in version or "dev" in version else RELEASES_SUBDIR

class SpytVersion:
    def __init__(self, version=None, major=0, minor=0, patch=0):
        if version is not None:
@@ -58,18 +115,6 @@ def __str__(self):
return f"{self.major}.{self.minor}.{self.patch}"


-def validate_cluster_version(spark_cluster_version, client=None):
-    if not check_cluster_version_exists(spark_cluster_version, client=client):
-        raise RuntimeError("Unknown SPYT cluster version: {}. Available release versions are: {}".format(
-            spark_cluster_version, get_available_cluster_versions(client=client)
-        ))
-    spyt_minor_version = SpytVersion(SELF_VERSION).get_minor()
-    cluster_minor_version = SpytVersion(spark_cluster_version).get_minor()
-    if spyt_minor_version < cluster_minor_version:
-        logger.warning("You required SPYT version {} which is older than your local ytsaurus-spyt version {}."
-                       "Please update your local ytsaurus-spyt".format(spark_cluster_version, SELF_VERSION))
-

def validate_versions_compatibility(spyt_version, spark_cluster_version):
    spyt_minor_version = SpytVersion(spyt_version).get_minor()
    spark_cluster_minor_version = SpytVersion(spark_cluster_version).get_minor()
@@ -84,15 +129,6 @@ def validate_mtn_config(enablers, network_project, tvm_id, tvm_secret):
raise RuntimeError("When using MTN, network_project arg must be set.")


-def latest_compatible_spyt_version(version, client=None):
-    minor_version = SpytVersion(version).get_minor()
-    spyt_versions = get_available_spyt_versions(client)
-    compatible_spyt_versions = [x for x in spyt_versions if SpytVersion(x).get_minor() == minor_version]
-    if not compatible_spyt_versions:
-        raise RuntimeError(f"No compatible SPYT versions found for specified version {version}")
-    return max(compatible_spyt_versions, key=SpytVersion)


def python_bin_path(global_conf, version):
    return global_conf["python_cluster_paths"].get(version)

@@ -115,26 +151,6 @@ def validate_ssd_config(disk_limit, disk_account):
raise RuntimeError("Disk account must be provided to use disk limit, please add --worker-disk-account option")


-def get_available_cluster_versions(client=None):
-    subdirs = yt_list(CONF_BASE_PATH.join(RELEASES_SUBDIR), client=client)
-    return [x for x in subdirs if x != "spark-launch-conf"]
-
-
-def check_cluster_version_exists(cluster_version, client=None):
-    return exists(_get_version_conf_path(cluster_version), client=client)
-
-
-def read_global_conf(client=None):
-    return client.get(GLOBAL_CONF_PATH)
-
-
-def read_remote_conf(global_conf, cluster_version, client=None):
-    version_conf_path = _get_version_conf_path(cluster_version)
-    version_conf = get(version_conf_path, client=client)
-    version_conf["cluster_version"] = cluster_version
-    return update_inplace(global_conf, version_conf)  # TODO(alex-shishkin): Might cause undefined behaviour
-

def read_cluster_conf(path=None, client=None):
    if path is None:
        return {}
@@ -156,41 +172,9 @@ def validate_custom_params(params):
"Use argument 'enablers' instead")


-def get_available_spyt_versions(client=None):
-    return yt_list(SPYT_BASE_PATH.join(RELEASES_SUBDIR), client=client)
-
-
-def latest_ytserver_proxy_path(cluster_version, client=None):
-    if cluster_version:
-        return None
-    global_conf = read_global_conf(client=client)
-    symlink_path = global_conf.get("ytserver_proxy_path")
-    if symlink_path is None:
-        return None
-    return get("{}&/@target_path".format(symlink_path), client=client)


def ytserver_proxy_attributes(path, client=None):
    return get("{}/@user_attributes".format(path), client=client)


-def get_spark_distributive(client):
-    distrib_root = DISTRIB_BASE_PATH.join(spark_version.replace('.', '/'))
-    distrib_root_contents = yt_list(distrib_root, client=client)
-    spark_tgz = [x for x in distrib_root_contents if x.endswith('.tgz')]
-    if len(spark_tgz) == 0:
-        raise RuntimeError(f"Spark {spark_version} tgz distributive doesn't exist "
-                           f"at path {distrib_root} on cluster {client.config['proxy']['url']}")
-    return (spark_tgz[0], distrib_root.join(spark_tgz[0]))


def _get_or_else(d, key, default):
    return d.get(key) or default


-def _version_subdir(version):
-    return SNAPSHOTS_SUBDIR if "SNAPSHOT" in version or "beta" in version or "dev" in version else RELEASES_SUBDIR
-
-
-def _get_version_conf_path(cluster_version):
-    return CONF_BASE_PATH.join(_version_subdir(cluster_version)).join(cluster_version).join("spark-launch-conf")
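
As a worked example of the Cypress layout the new class encodes: _version_subdir routes any version containing "SNAPSHOT", "beta", or "dev" to snapshots/ and everything else to releases/, so with the default root (version strings hypothetical):

    spyt = SpytDistributions(client, "//home/spark")
    spyt._get_version_conf_path("2.2.0")
    # -> //home/spark/conf/releases/2.2.0/spark-launch-conf
    spyt._get_version_conf_path("2.3.0-SNAPSHOT")
    # -> //home/spark/conf/snapshots/2.3.0-SNAPSHOT/spark-launch-conf
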
8 changes: 5 additions & 3 deletions spyt-package/src/main/python/spyt/spec.py
@@ -12,7 +12,7 @@
from yt.wrapper.http_helpers import get_token, get_user_name # noqa: E402
from yt.wrapper.spec_builders import VanillaSpecBuilder # noqa: E402

-from .conf import ytserver_proxy_attributes, get_spark_distributive  # noqa: E402
+from .conf import ytserver_proxy_attributes, SpytDistributions  # noqa: E402
from .utils import SparkDiscovery, call_get_proxy_address_url, parse_memory # noqa: E402
from .enabler import SpytEnablers # noqa: E402
from .version import __version__ # noqa: E402
@@ -43,6 +43,7 @@ class SparkDefaultArguments(object):
    @staticmethod
    def get_params():
        return {
+            "spyt_distributions": { "yt_root": "//home/spark" },
            "operation_spec": {
                "annotations": {
                    "is_spark": True,
@@ -362,15 +363,16 @@ def _script_absolute_path(script):
.end_task()


-def build_spark_operation_spec(config: dict, client: YtClient,
+def build_spark_operation_spec(config: dict, spyt_distributions: SpytDistributions,
                               job_types: List[str], common_config: CommonComponentConfig,
                               master_config: MasterConfig = None, worker_config: WorkerConfig = None,
                               hs_config: HistoryServerConfig = None, livy_config: LivyConfig = None):
+    client = spyt_distributions.client
    if job_types == [] or job_types is None:
        job_types = ['master', 'history', 'worker']

    spark_home = "./tmpfs" if common_config.enable_tmpfs else "."
-    spark_distributive_tgz, spark_distributive_path = get_spark_distributive(client)
+    spark_distributive_tgz, spark_distributive_path = spyt_distributions.get_spark_distributive()

    extra_java_opts = ["-Dlog4j.loglevel={}".format(common_config.cluster_log_level)]
    if common_config.enablers.enable_preference_ipv6:
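The new "spyt_distributions" entry in SparkDefaultArguments.get_params() is what makes the distribution root configurable per operation, and build_spark_operation_spec now receives the SpytDistributions object itself, taking the YtClient from it rather than as a separate argument. Overriding the root could look like this (custom path hypothetical):

    params = SparkDefaultArguments.get_params()
    params["spyt_distributions"]["yt_root"] = "//home/my_team/spark"  # hypothetical alternative root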
