Skip to content

Commit

Permalink
[XLA:Python] Add locking around the pytree registry in free threading…
Browse files Browse the repository at this point in the history
… mode.

Fixes tsan races from JAX test suite under free threading.

PiperOrigin-RevId: 714793284
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Jan 13, 2025
1 parent 35cc516 commit 822c9b6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions xla/python/pytree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ void PyTreeRegistry::Register(
registration->to_iterable = std::move(to_iterable);
registration->from_iterable = std::move(from_iterable);
registration->to_iterable_with_keys = std::move(to_iterable_with_keys);
nb::ft_lock_guard lock(mu_);
auto it = registrations_.emplace(type, std::move(registration));
if (!it.second) {
throw std::invalid_argument(
Expand All @@ -112,6 +113,7 @@ void PyTreeRegistry::RegisterDataclass(nb::object type,
registration->type = type;
registration->data_fields = std::move(data_fields);
registration->meta_fields = std::move(meta_fields);
nb::ft_lock_guard lock(mu_);
auto it = registrations_.emplace(type, std::move(registration));
if (!it.second) {
throw std::invalid_argument(absl::StrFormat(
Expand Down Expand Up @@ -222,6 +224,7 @@ PyTreeKind PyTreeRegistry::KindOfObject(

/*static*/ const PyTreeRegistry::Registration* PyTreeRegistry::Lookup(
nb::handle type) const {
nb::ft_lock_guard lock(mu_);
auto it = registrations_.find(type);
return it == registrations_.end() ? nullptr : it->second.get();
}
Expand Down Expand Up @@ -419,6 +422,7 @@ nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x,
void* arg) {
PyTreeRegistry* registry = nb::inst_ptr<PyTreeRegistry>(self);
Py_VISIT(Py_TYPE(self));
nb::ft_lock_guard lock(registry->mu_);
for (const auto& [key, value] : registry->registrations_) {
Py_VISIT(key.ptr());
int rval = value->tp_traverse(visit, arg);
Expand All @@ -431,6 +435,7 @@ nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x,

/* static */ int PyTreeRegistry::tp_clear(PyObject* self) {
PyTreeRegistry* registry = nb::inst_ptr<PyTreeRegistry>(self);
nb::ft_lock_guard lock(registry->mu_);
registry->registrations_.clear();
return 0;
}
Expand Down
3 changes: 2 additions & 1 deletion xla/python/pytree.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ class PyTreeRegistry {
return a.ptr() == b.ptr();
}
};
mutable nanobind::ft_mutex mu_;
absl::flat_hash_map<nanobind::object, std::unique_ptr<Registration>, TypeHash,
TypeEq>
registrations_;
registrations_; // Guarded by mu_
bool enable_namedtuple_;

static int tp_traverse(PyObject* self, visitproc visit, void* arg);
Expand Down

0 comments on commit 822c9b6

Please sign in to comment.