Skip to content

Commit

Permalink
Preserve master credentials on spawning platforms
Browse files Browse the repository at this point in the history
Prevent spawning platform minions from having to re-authenticate on
every job when using multiprocessing=True
  • Loading branch information
dwoz committed Aug 11, 2023
1 parent a46d846 commit 6107897
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 11 deletions.
2 changes: 2 additions & 0 deletions changelog/64914.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Preserve credentials on spawning platforms, minions no longer re-authenticate
with every job when using `multiprocessing=True`.
10 changes: 7 additions & 3 deletions salt/metaproxy/deltaproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def subproxy_post_master_init(minion_id, uid, opts, main_proxy, main_utils):
return {"proxy_minion": _proxy_minion, "proxy_opts": proxyopts}


def target(cls, minion_instance, opts, data, connected):
def target(cls, minion_instance, opts, data, connected, creds_map):
"""
Handle targeting of the minion.
Expand All @@ -593,6 +593,8 @@ def target(cls, minion_instance, opts, data, connected):
minion_instance.opts["id"],
opts["id"],
)
if creds_map:
salt.crypt.AsyncAuth.creds_map = creds_map

if not hasattr(minion_instance, "proc_dir"):
uid = salt.utils.user.get_uid(user=opts.get("user", None))
Expand Down Expand Up @@ -1061,21 +1063,23 @@ def handle_decoded_payload(self, data):
instance = self
multiprocessing_enabled = self.opts.get("multiprocessing", True)
name = "ProcessPayload(jid={})".format(data["jid"])
creds_map = None
if multiprocessing_enabled:
if salt.utils.platform.spawning_platform():
# let python reconstruct the minion on the other side if we"re
# running on spawning platforms
instance = None
creds_map = salt.crypt.AsyncAuth.creds_map
with default_signals(signal.SIGINT, signal.SIGTERM):
process = SignalHandlingProcess(
target=target,
args=(self, instance, self.opts, data, self.connected),
args=(self, instance, self.opts, data, self.connected, creds_map),
name=name,
)
else:
process = threading.Thread(
target=target,
args=(self, instance, self.opts, data, self.connected),
args=(self, instance, self.opts, data, self.connected, creds_map),
name=name,
)

Expand Down
10 changes: 7 additions & 3 deletions salt/metaproxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,15 @@ def post_master_init(self, master):
self.ready = True


def target(cls, minion_instance, opts, data, connected):
def target(cls, minion_instance, opts, data, connected, creds_map):
"""
Handle targeting of the minion.
Calling _thread_multi_return or _thread_return
depending on a single or multiple commands.
"""
if creds_map:
salt.crypt.AsyncAuth.creds_map = creds_map
if not minion_instance:
minion_instance = cls(opts)
minion_instance.connected = connected
Expand Down Expand Up @@ -814,21 +816,23 @@ def handle_decoded_payload(self, data):
instance = self
multiprocessing_enabled = self.opts.get("multiprocessing", True)
name = "ProcessPayload(jid={})".format(data["jid"])
creds_map = None
if multiprocessing_enabled:
if salt.utils.platform.spawning_platform():
# let python reconstruct the minion on the other side if we're
# running on windows
instance = None
creds_map = salt.crypt.AsyncAuth.creds_map
with default_signals(signal.SIGINT, signal.SIGTERM):
process = SignalHandlingProcess(
target=self._target,
name=name,
args=(instance, self.opts, data, self.connected),
args=(instance, self.opts, data, self.connected, creds_map),
)
else:
process = threading.Thread(
target=self._target,
args=(instance, self.opts, data, self.connected),
args=(instance, self.opts, data, self.connected, creds_map),
name=name,
)

Expand Down
14 changes: 9 additions & 5 deletions salt/minion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,24 +1763,26 @@ def _handle_decoded_payload(self, data):
# python needs to be able to reconstruct the reference on the other
# side.
instance = self
creds_map = None
multiprocessing_enabled = self.opts.get("multiprocessing", True)
name = "ProcessPayload(jid={})".format(data["jid"])
if multiprocessing_enabled:
if salt.utils.platform.spawning_platform():
# let python reconstruct the minion on the other side if we're
# running on windows
instance = None
creds_map = salt.crypt.AsyncAuth.creds_map
with default_signals(signal.SIGINT, signal.SIGTERM):
process = SignalHandlingProcess(
target=self._target,
name=name,
args=(instance, self.opts, data, self.connected),
args=(instance, self.opts, data, self.connected, creds_map),
)
process.register_after_fork_method(salt.utils.crypt.reinit_crypto)
else:
process = threading.Thread(
target=self._target,
args=(instance, self.opts, data, self.connected),
args=(instance, self.opts, data, self.connected, creds_map),
name=name,
)

Expand All @@ -1804,7 +1806,9 @@ def ctx(self):
return exitstack

@classmethod
def _target(cls, minion_instance, opts, data, connected):
def _target(cls, minion_instance, opts, data, connected, creds_map):
if creds_map:
salt.crypt.AsyncAuth.creds_map = creds_map
if not minion_instance:
minion_instance = cls(opts, load_grains=False)
minion_instance.connected = connected
Expand Down Expand Up @@ -3879,10 +3883,10 @@ def _handle_decoded_payload(self, data):
return mp_call(self, data)

@classmethod
def _target(cls, minion_instance, opts, data, connected):
def _target(cls, minion_instance, opts, data, connected, creds_map):

mp_call = _metaproxy_call(opts, "target")
return mp_call(cls, minion_instance, opts, data, connected)
return mp_call(cls, minion_instance, opts, data, connected, creds_map)

@classmethod
def _thread_return(cls, minion_instance, opts, data):
Expand Down
49 changes: 49 additions & 0 deletions tests/pytests/integration/minion/test_reauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import time


def test_reauth(salt_master_factory, event_listener):
"""
Validate non of our platform need to re-authenticate when runing a job with
multiprocessing=True.
"""
sls_name = "issue-64941"
sls_contents = """
custom_test_state:
test.configurable_test_state:
- name: example
- changes: True
- result: True
- comment: "Nothing has acutally been changed"
"""
events = []

def handler(data):
events.append(data)

event_listener.register_auth_event_handler("test_reauth-master", handler)
master = salt_master_factory.salt_master_daemon(
"test_reauth-master",
overrides={"log_level": "info"},
)
sls_tempfile = master.state_tree.base.temp_file(
"{}.sls".format(sls_name), sls_contents
)
minion = master.salt_minion_daemon(
"test_reauth-minion",
overrides={"log_level": "info"},
)
cli = master.salt_cli()
start_time = time.time()
with master.started(), minion.started():
events = event_listener.get_events(
[(master.id, "salt/auth")],
after_time=start_time,
)
num_auth = len(events)
proc = cli.run("state.sls", sls_name, minion_tgt="*")
assert proc.returncode == 1
events = event_listener.get_events(
[(master.id, "salt/auth")],
after_time=start_time,
)
assert num_auth == len(events)

0 comments on commit 6107897

Please sign in to comment.