Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Master thread safe loader #54878

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 84 additions & 4 deletions salt/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import salt.utils.odict
import salt.utils.platform
import salt.utils.stringutils
import salt.utils.thread_local_proxy as thread_local_proxy
import salt.utils.versions
from salt.exceptions import LoaderError

Expand Down Expand Up @@ -1089,6 +1090,75 @@ def _mod_type(module_path):
return "ext"


def _inject_into_mod(mod, name, value, force_lock=False):
"""
Inject a variable into a module. This is used to inject "globals" like
``__salt__``, ``__pillar``, or ``grains``.

Instead of injecting the value directly, a ``ThreadLocalProxy`` is created.
If such a proxy is already present under the specified name, it is updated
with the new value. This update only affects the current thread, so that
the same name can refer to different values depending on the thread of
execution.

This is important for data that is not truly global. For example, pillar
data might be dynamically overriden through function parameters and thus
the actual values available in pillar might depend on the thread that is
calling a module.

mod:
module object into which the value is going to be injected.

name:
name of the variable that is injected into the module.

value:
value that is injected into the variable. The value is not injected
directly, but instead set as the new reference of the proxy that has
been created for the variable.

force_lock:
whether the lock should be acquired before checking whether a proxy
object for the specified name has already been injected into the
module. If ``False`` (the default), this function checks for the
module's variable without acquiring the lock and only acquires the lock
if a new proxy has to be created and injected.
"""
old_value = getattr(mod, name, None)
# We use a double-checked locking scheme in order to avoid taking the lock
# when a proxy object has already been injected.
# In most programming languages, double-checked locking is considered
# unsafe when used without explicit memory barriers because one might read
# an uninitialized value. In CPython it is safe due to the global
# interpreter lock (GIL). In Python implementations that do not have the
# GIL, it could be unsafe, but at least Jython also guarantees that (for
# Python objects) memory is not corrupted when writing and reading without
# explicit synchronization
# (http://www.jython.org/jythonbook/en/1.0/Concurrency.html).
# Please note that in order to make this code safe in a runtime environment
# that does not make this guarantees, it is not sufficient. The
# ThreadLocalProxy must also be created with fallback_to_shared set to
# False or a lock must be added to the ThreadLocalProxy.
if force_lock:
with _inject_into_mod.lock:
if isinstance(old_value, thread_local_proxy.ThreadLocalProxy):
thread_local_proxy.ThreadLocalProxy.set_reference(old_value, value)
else:
setattr(mod, name, thread_local_proxy.ThreadLocalProxy(value, True))
else:
if isinstance(old_value, thread_local_proxy.ThreadLocalProxy):
thread_local_proxy.ThreadLocalProxy.set_reference(old_value, value)
else:
_inject_into_mod(mod, name, value, True)


# Lock used when injecting globals. This is needed to avoid a race condition
# when two threads try to load the same module concurrently. This must be
# outside the loader because there might be more than one loader for the same
# namespace.
_inject_into_mod.lock = threading.RLock()


# TODO: move somewhere else?
class FilterDictWrapper(MutableMapping):
"""
Expand Down Expand Up @@ -1185,7 +1255,11 @@ def __init__(

for k, v in six.iteritems(self.pack):
if v is None: # if the value of a pack is None, lets make an empty dict
self.context_dict.setdefault(k, {})
value = thread_local_proxy.ThreadLocalProxy.unproxy(
self.context_dict.get(k, {})
)

self.context_dict[k] = value
self.pack[k] = salt.utils.context.NamespacedDictWrapper(
self.context_dict, k
)
Expand Down Expand Up @@ -1468,13 +1542,19 @@ def __prep_mod_opts(self, opts):
Strip out of the opts any logger instance
"""
if "__grains__" not in self.pack:
self.context_dict["grains"] = opts.get("grains", {})
_grains = thread_local_proxy.ThreadLocalProxy.unproxy(
opts.get("grains", {})
)

self.context_dict["grains"] = _grains
self.pack["__grains__"] = salt.utils.context.NamespacedDictWrapper(
self.context_dict, "grains"
)

if "__pillar__" not in self.pack:
self.context_dict["pillar"] = opts.get("pillar", {})
pillar = thread_local_proxy.ThreadLocalProxy.unproxy(opts.get("pillar", {}))

self.context_dict["pillar"] = pillar
self.pack["__pillar__"] = salt.utils.context.NamespacedDictWrapper(
self.context_dict, "pillar"
)
Expand Down Expand Up @@ -1670,7 +1750,7 @@ def _load_module(self, name):

# pack whatever other globals we were asked to
for p_name, p_value in six.iteritems(self.pack):
setattr(mod, p_name, p_value)
_inject_into_mod(mod, p_name, p_value)

module_name = mod.__name__.rsplit(".", 1)[-1]

Expand Down
19 changes: 17 additions & 2 deletions salt/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# Import Salt libs
import salt.utils.data
import salt.utils.stringutils
import salt.utils.thread_local_proxy as thread_local_proxy

# Import 3rd-party libs
from salt.ext import six
Expand Down Expand Up @@ -119,11 +120,18 @@ def dump(obj, fp, **kwargs):
using the _json_module argument)
"""
json_module = kwargs.pop("_json_module", json)
orig_enc_func = kwargs.pop("default", lambda x: x)

def _enc_func(_obj):
return orig_enc_func(thread_local_proxy.ThreadLocalProxy.unproxy(_obj))

if "ensure_ascii" not in kwargs:
kwargs["ensure_ascii"] = False
if six.PY2:
obj = salt.utils.data.encode(obj)
return json_module.dump(obj, fp, **kwargs) # future lint: blacklisted-function
return json_module.dump(
obj, fp, default=_enc_func, **kwargs
) # future lint: blacklisted-function


def dumps(obj, **kwargs):
Expand All @@ -142,8 +150,15 @@ def dumps(obj, **kwargs):
using the _json_module argument)
"""
json_module = kwargs.pop("_json_module", json)
orig_enc_func = kwargs.pop("default", lambda x: x)

def _enc_func(_obj):
return orig_enc_func(thread_local_proxy.ThreadLocalProxy.unproxy(_obj))

if "ensure_ascii" not in kwargs:
kwargs["ensure_ascii"] = False
if six.PY2:
obj = salt.utils.data.encode(obj)
return json_module.dumps(obj, **kwargs) # future lint: blacklisted-function
return json_module.dumps(
obj, default=_enc_func, **kwargs
) # future lint: blacklisted-function
17 changes: 15 additions & 2 deletions salt/utils/msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

import logging

# Import Salt libs
import salt.utils.thread_local_proxy as thread_local_proxy

log = logging.getLogger(__name__)

# Import 3rd party libs
Expand Down Expand Up @@ -94,8 +97,13 @@ def pack(o, stream, **kwargs):
By default, this function uses the msgpack module and falls back to
msgpack_pure, if the msgpack is not available.
"""
orig_enc_func = kwargs.pop("default", lambda x: x)

def _enc_func(obj):
return orig_enc_func(thread_local_proxy.ThreadLocalProxy.unproxy(obj))

# Writes to a stream, there is no return
msgpack.pack(o, stream, **_sanitize_msgpack_kwargs(kwargs))
msgpack.pack(o, stream, default=_enc_func, **_sanitize_msgpack_kwargs(kwargs))


def packb(o, **kwargs):
Expand All @@ -108,7 +116,12 @@ def packb(o, **kwargs):
By default, this function uses the msgpack module and falls back to
msgpack_pure, if the msgpack is not available.
"""
return msgpack.packb(o, **_sanitize_msgpack_kwargs(kwargs))
orig_enc_func = kwargs.pop("default", lambda x: x)

def _enc_func(obj):
return orig_enc_func(thread_local_proxy.ThreadLocalProxy.unproxy(obj))

return msgpack.packb(o, default=_enc_func, **_sanitize_msgpack_kwargs(kwargs))


def unpack(stream, **kwargs):
Expand Down
Loading