Skip to content

Commit

Permalink
default_discovery_dir should use YT username
Browse files Browse the repository at this point in the history
  • Loading branch information
faucct committed Sep 17, 2024
1 parent 67b38d8 commit 6e2113a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
6 changes: 3 additions & 3 deletions spyt-package/src/main/python/spyt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ def _create_spark_session(do_create_spark_session):
stop(spark, exception)


def get_spark_discovery(discovery_path, conf):
def get_spark_discovery(discovery_path, conf, client=None):
discovery_path = discovery_path or conf.get("discovery_path") or conf.get(
"discovery_dir") or default_discovery_dir()
"discovery_dir") or default_discovery_dir(client=client)
return SparkDiscovery(discovery_path=discovery_path)


Expand All @@ -150,7 +150,7 @@ def _configure_client_mode(spark_conf,
local_conf,
client=None,
spyt_version=None):
discovery = get_spark_discovery(discovery_path, local_conf)
discovery = get_spark_discovery(discovery_path, local_conf, client=client)
master = get_spark_master(discovery, rest=False, yt_client=client)
set_conf(spark_conf, base_spark_conf(client=client, discovery=discovery))
spark_conf.set("spark.master", master)
Expand Down
13 changes: 7 additions & 6 deletions spyt-package/src/main/python/spyt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def scala_buffer_to_list(buffer):
return [buffer.apply(i) for i in range(buffer.length())]


def default_user():
return os.getenv("YT_USER") or getpass.getuser()
def default_user(client=None):
return os.getenv("YT_USER") or (get_user_name(client=client) if client else None) or getpass.getuser()


def default_token():
Expand Down Expand Up @@ -262,8 +262,9 @@ def set_conf(conf, dict_conf):
conf.set(key, value)


def default_discovery_dir():
return os.getenv("SPARK_YT_DISCOVERY_DIR") or YPath("//home").join(os.getenv("USER")).join("spark-tmp")
def default_discovery_dir(client=None):
return os.getenv("SPARK_YT_DISCOVERY_DIR") \
or YPath("//home").join(default_user(client=client)).join("spark-tmp")


def default_proxy():
Expand All @@ -287,11 +288,11 @@ def get_default_arg_parser(**kwargs):
return parser


def parse_args(parser=None, parser_arguments=None, raw_args=None):
def parse_args(parser=None, parser_arguments=None, raw_args=None, client=None):
parser_arguments = parser_arguments or {}
parser = parser or get_default_arg_parser(**parser_arguments)
args, unknown_args = parser.parse_known_args(args=raw_args)
args.discovery_path = args.discovery_path or args.discovery_dir or default_discovery_dir()
args.discovery_path = args.discovery_path or args.discovery_dir or default_discovery_dir(client=client)
return args, unknown_args


Expand Down

0 comments on commit 6e2113a

Please sign in to comment.