Skip to content

Commit

Permalink
MONGOID-5823 Use proper thread-local variables instead of fiber-local…
Browse files Browse the repository at this point in the history
… variables (#5891)

* MONGOID-5823 Use proper thread-local variables

Using fiber-local variables instead of thread-local variables has
the potential to introduce difficult bugs when Mongoid's internal
state is not visible to Fiber-wrapped cascading callbacks.

* remove cruft from an earlier experient

* *grumble* rubocop *grumble*

* fix test failures

* compensate for jruby
  • Loading branch information
jamis committed Oct 24, 2024
1 parent a9ccd04 commit 04a4432
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 53 deletions.
12 changes: 7 additions & 5 deletions lib/mongoid/persistence_context.rb
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ def clear(object, cluster = nil, original_context = nil)
# @api private
PERSISTENCE_CONTEXT_KEY = :"[mongoid]:persistence_context"

def context_store
Threaded.get(PERSISTENCE_CONTEXT_KEY) { {} }
end

# Get the persistence context for a given object from the thread local
# storage.
#
Expand All @@ -295,8 +299,7 @@ def clear(object, cluster = nil, original_context = nil)
#
# @api private
def get_context(object)
Thread.current[PERSISTENCE_CONTEXT_KEY] ||= {}
Thread.current[PERSISTENCE_CONTEXT_KEY][object.object_id]
context_store[object.object_id]
end

# Store persistence context for a given object in the thread local
Expand All @@ -308,10 +311,9 @@ def get_context(object)
# @api private
def store_context(object, context)
if context.nil?
Thread.current[PERSISTENCE_CONTEXT_KEY]&.delete(object.object_id)
context_store.delete(object.object_id)
else
Thread.current[PERSISTENCE_CONTEXT_KEY] ||= {}
Thread.current[PERSISTENCE_CONTEXT_KEY][object.object_id] = context
context_store[object.object_id] = context
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions lib/mongoid/railties/controller_runtime.rb
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _completed e
#
# @return [ Integer ] The runtime value.
def self.runtime
Thread.current[VARIABLE_NAME] ||= 0
Threaded.get(VARIABLE_NAME) { 0 }
end

# Set the runtime value on the current thread.
Expand All @@ -87,7 +87,7 @@ def self.runtime
#
# @return [ Integer ] The runtime value.
def self.runtime= value
Thread.current[VARIABLE_NAME] = value
Threaded.set(VARIABLE_NAME, value)
end

# Reset the runtime value to zero the current thread.
Expand Down
119 changes: 94 additions & 25 deletions lib/mongoid/threaded.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,75 @@ module Threaded

extend self

# Queries the thread-local variable with the given name. If a block is
# given, and the variable does not already exist, the return value of the
# block will be set as the value of the variable before returning it.
#
# It is very important that applications (and espcially Mongoid)
# use this method instead of Thread#[], since Thread#[] is actually for
# fiber-local variables, and Mongoid uses Fibers as an implementation
# detail in some callbacks. Putting thread-local state in a fiber-local
# store will result in the state being invisible when relevant callbacks are
# run in a different fiber.
#
# Affected callbacks are cascading callbacks on embedded children.
#
# @param [ String | Symbol ] key the name of the variable to query
# @param [ Proc ] default an optional block that must return the default
# (initial) value of this variable.
#
# @return [ Object | nil ] the value of the queried variable, or nil if
# it is not set and no default was given.
def get(key, &default)
result = Thread.current.thread_variable_get(key)

if result.nil? && default
result = yield
set(key, result)
end

result
end

# Sets a thread-local variable with the given name to the given value.
# See #get for a discussion of why this method is necessary, and why
# Thread#[]= should be avoided in cascading callbacks on embedded children.
#
# @param [ String | Symbol ] key the name of the variable to set.
# @param [ Object | nil ] value the value of the variable to set (or `nil`
# if you wish to unset the variable)
def set(key, value)
Thread.current.thread_variable_set(key, value)
end

# Removes the named variable from thread-local storage.
#
# @param [ String | Symbol ] key the name of the variable to remove.
def delete(key)
set(key, nil)
end

# Queries the presence of a named variable in thread-local storage.
#
# @param [ String | Symbol ] key the name of the variable to query.
#
# @return [ true | false ] whether the given variable is present or not.
def has?(key)
# Here we have a classic example of JRuby not behaving like MRI. In
# MRI, if you set a thread variable to nil, it removes it from the list
# and subsequent calls to thread_variable?(key) will return false. Not
# so with JRuby. Once set, you cannot unset the thread variable.
#
# However, because setting a variable to nil is supposed to remove it,
# we can assume a nil-valued variable doesn't actually exist.

# So, instead of this:
# Thread.current.thread_variable?(key)

# We have to do this:
!get(key).nil?
end

# Begin entry into a named thread local stack.
#
# @example Begin entry into the stack.
Expand All @@ -56,7 +125,7 @@ def begin_execution(name)
#
# @return [ String | Symbol ] The override.
def database_override
Thread.current[DATABASE_OVERRIDE_KEY]
get(DATABASE_OVERRIDE_KEY)
end

# Set the global database override.
Expand All @@ -68,7 +137,7 @@ def database_override
#
# @return [ String | Symbol ] The override.
def database_override=(name)
Thread.current[DATABASE_OVERRIDE_KEY] = name
set(DATABASE_OVERRIDE_KEY, name)
end

# Are in the middle of executing the named stack
Expand Down Expand Up @@ -104,7 +173,7 @@ def exit_execution(name)
#
# @return [ Array ] The stack.
def stack(name)
Thread.current[STACK_KEYS[name]] ||= []
get(STACK_KEYS[name]) { [] }
end

# Begin autosaving a document on the current thread.
Expand Down Expand Up @@ -178,7 +247,7 @@ def exit_without_default_scope(klass)
#
# @return [ String | Symbol ] The override.
def client_override
Thread.current[CLIENT_OVERRIDE_KEY]
get(CLIENT_OVERRIDE_KEY)
end

# Set the global client override.
Expand All @@ -190,7 +259,7 @@ def client_override
#
# @return [ String | Symbol ] The override.
def client_override=(name)
Thread.current[CLIENT_OVERRIDE_KEY] = name
set(CLIENT_OVERRIDE_KEY, name)
end

# Get the current Mongoid scope.
Expand All @@ -203,12 +272,12 @@ def client_override=(name)
#
# @return [ Criteria ] The scope.
def current_scope(klass = nil)
if klass && Thread.current[CURRENT_SCOPE_KEY].respond_to?(:keys)
Thread.current[CURRENT_SCOPE_KEY][
Thread.current[CURRENT_SCOPE_KEY].keys.find { |k| k <= klass }
]
current_scope = get(CURRENT_SCOPE_KEY)

if klass && current_scope.respond_to?(:keys)
current_scope[current_scope.keys.find { |k| k <= klass }]
else
Thread.current[CURRENT_SCOPE_KEY]
current_scope
end
end

Expand All @@ -221,7 +290,7 @@ def current_scope(klass = nil)
#
# @return [ Criteria ] The scope.
def current_scope=(scope)
Thread.current[CURRENT_SCOPE_KEY] = scope
set(CURRENT_SCOPE_KEY, scope)
end

# Set the current Mongoid scope. Safe for multi-model scope chaining.
Expand All @@ -237,8 +306,8 @@ def set_current_scope(scope, klass)
if scope.nil?
unset_current_scope(klass)
else
Thread.current[CURRENT_SCOPE_KEY] ||= {}
Thread.current[CURRENT_SCOPE_KEY][klass] = scope
current_scope = get(CURRENT_SCOPE_KEY) { {} }
current_scope[klass] = scope
end
end

Expand Down Expand Up @@ -285,7 +354,7 @@ def validated?(document)
#
# @return [ Hash ] The current autosaves.
def autosaves
Thread.current[AUTOSAVES_KEY] ||= {}
get(AUTOSAVES_KEY) { {} }
end

# Get all validations on the current thread.
Expand All @@ -295,7 +364,7 @@ def autosaves
#
# @return [ Hash ] The current validations.
def validations
Thread.current[VALIDATIONS_KEY] ||= {}
get(VALIDATIONS_KEY) { {} }
end

# Get all autosaves on the current thread for the class.
Expand Down Expand Up @@ -389,8 +458,8 @@ def clear_modified_documents(session)
# @return [ true | false ] Whether or not document callbacks should be
# executed by default.
def execute_callbacks?
if Thread.current.key?(EXECUTE_CALLBACKS)
Thread.current[EXECUTE_CALLBACKS]
if has?(EXECUTE_CALLBACKS)
get(EXECUTE_CALLBACKS)
else
true
end
Expand All @@ -403,7 +472,7 @@ def execute_callbacks?
# @param flag [ true | false ] Whether or not document callbacks should be
# executed by default.
def execute_callbacks=(flag)
Thread.current[EXECUTE_CALLBACKS] = flag
set(EXECUTE_CALLBACKS, flag)
end

# Returns the thread store of sessions.
Expand All @@ -412,7 +481,7 @@ def execute_callbacks=(flag)
#
# @api private
def sessions
Thread.current[SESSIONS_KEY] ||= {}.compare_by_identity
get(SESSIONS_KEY) { {}.compare_by_identity }
end

# Returns the thread store of modified documents.
Expand All @@ -422,9 +491,7 @@ def sessions
#
# @api private
def modified_documents
Thread.current[MODIFIED_DOCUMENTS_KEY] ||= Hash.new do |h, k|
h[k] = Set.new
end
get(MODIFIED_DOCUMENTS_KEY) { Hash.new { |h, k| h[k] = Set.new } }
end

private
Expand All @@ -434,10 +501,12 @@ def modified_documents
#
# @param klass [ Class ] the class to remove from the current scope.
def unset_current_scope(klass)
return unless Thread.current[CURRENT_SCOPE_KEY]
return unless has?(CURRENT_SCOPE_KEY)

scope = get(CURRENT_SCOPE_KEY)
scope.delete(klass)

Thread.current[CURRENT_SCOPE_KEY].delete(klass)
Thread.current[CURRENT_SCOPE_KEY] = nil if Thread.current[CURRENT_SCOPE_KEY].empty?
delete(CURRENT_SCOPE_KEY) if scope.empty?
end
end
end
5 changes: 4 additions & 1 deletion lib/mongoid/timestamps/timeless.rb
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,17 @@ def timeless?
class << self
extend Forwardable

# The key to use to store the timeless table
TIMELESS_TABLE_KEY = '[mongoid]:timeless'

# Returns the in-memory thread cache of classes
# for which to skip timestamping.
#
# @return [ Hash ] The timeless table.
#
# @api private
def timeless_table
Thread.current['[mongoid]:timeless'] ||= Hash.new
Threaded.get(TIMELESS_TABLE_KEY) { Hash.new }
end

def_delegators :timeless_table, :[]=, :[]
Expand Down
2 changes: 1 addition & 1 deletion lib/mongoid/touchable.rb
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def touch_callbacks_suppressed?(name)
# @return [ Hash ] The hash that contains touch callback suppression
# statuses
def touch_callback_statuses
Thread.current[SUPPRESS_TOUCH_CALLBACKS_KEY] ||= {}
Threaded.get(SUPPRESS_TOUCH_CALLBACKS_KEY) { {} }
end

# Define the method that will get called for touching belongs_to
Expand Down
12 changes: 12 additions & 0 deletions spec/mongoid/interceptable_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,12 @@ class TestClass
context 'with around callbacks' do
config_override :around_callbacks_for_embeds, true

after do
Mongoid::Threaded.stack('interceptable').clear
end

let(:stack) { Mongoid::Threaded.stack('interceptable') }

let(:expected) do
[
[InterceptableSpec::CbCascadedChild, :before_validation],
Expand Down Expand Up @@ -1824,6 +1830,12 @@ class TestClass
parent.save!
expect(registry.calls).to eq expected
end

it 'shows that cascaded callbacks can access Mongoid state' do
expect(stack).to be_empty
parent.save!
expect(stack).not_to be_empty
end
end

context 'without around callbacks' do
Expand Down
12 changes: 12 additions & 0 deletions spec/mongoid/interceptable_spec_models.rb
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,19 @@ def initialize(callback_registry, options)

attr_accessor :callback_registry

before_save :test_mongoid_state

include CallbackTracking

private

# Helps test that cascading child callbacks have access to the Mongoid
# state objects; if the implementation uses fiber-local (instead of truly
# thread-local) variables, the related tests will fail because the
# cascading child callbacks use fibers to linearize the recursion.
def test_mongoid_state
Mongoid::Threaded.stack('interceptable').push(self)
end
end
end

Expand Down
10 changes: 5 additions & 5 deletions spec/mongoid/threaded_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
context "when the stack has elements" do

before do
Thread.current["[mongoid]:load-stack"] = [ true ]
described_class.stack('load').push(true)
end

after do
Thread.current["[mongoid]:load-stack"] = []
described_class.stack('load').clear
end

it "returns true" do
Expand All @@ -51,7 +51,7 @@
context "when the stack has no elements" do

before do
Thread.current["[mongoid]:load-stack"] = []
described_class.stack('load').clear
end

it "returns false" do
Expand All @@ -76,15 +76,15 @@
context "when a stack has been initialized" do

before do
Thread.current["[mongoid]:load-stack"] = [ true ]
described_class.stack('load').push(true)
end

let(:loading) do
described_class.stack("load")
end

after do
Thread.current["[mongoid]:load-stack"] = []
described_class.stack('load').clear
end

it "returns the stack" do
Expand Down
Loading

0 comments on commit 04a4432

Please sign in to comment.