diff --git a/xla/python/pytree.cc b/xla/python/pytree.cc index a374c2df6bff9..d5799b8695cb7 100644 --- a/xla/python/pytree.cc +++ b/xla/python/pytree.cc @@ -96,6 +96,7 @@ void PyTreeRegistry::Register( registration->to_iterable = std::move(to_iterable); registration->from_iterable = std::move(from_iterable); registration->to_iterable_with_keys = std::move(to_iterable_with_keys); + nb::ft_lock_guard lock(mu_); auto it = registrations_.emplace(type, std::move(registration)); if (!it.second) { throw std::invalid_argument( @@ -112,6 +113,7 @@ void PyTreeRegistry::RegisterDataclass(nb::object type, registration->type = type; registration->data_fields = std::move(data_fields); registration->meta_fields = std::move(meta_fields); + nb::ft_lock_guard lock(mu_); auto it = registrations_.emplace(type, std::move(registration)); if (!it.second) { throw std::invalid_argument(absl::StrFormat( @@ -222,6 +224,7 @@ PyTreeKind PyTreeRegistry::KindOfObject( /*static*/ const PyTreeRegistry::Registration* PyTreeRegistry::Lookup( nb::handle type) const { + nb::ft_lock_guard lock(mu_); auto it = registrations_.find(type); return it == registrations_.end() ? nullptr : it->second.get(); } @@ -419,6 +422,7 @@ nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, void* arg) { PyTreeRegistry* registry = nb::inst_ptr(self); Py_VISIT(Py_TYPE(self)); + nb::ft_lock_guard lock(registry->mu_); for (const auto& [key, value] : registry->registrations_) { Py_VISIT(key.ptr()); int rval = value->tp_traverse(visit, arg); @@ -431,6 +435,7 @@ nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, /* static */ int PyTreeRegistry::tp_clear(PyObject* self) { PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); registry->registrations_.clear(); return 0; } diff --git a/xla/python/pytree.h b/xla/python/pytree.h index fc16fdd40136c..f526893d8dc81 100644 --- a/xla/python/pytree.h +++ b/xla/python/pytree.h @@ -143,9 +143,10 @@ class PyTreeRegistry { return a.ptr() == b.ptr(); } }; + mutable nanobind::ft_mutex mu_; absl::flat_hash_map, TypeHash, TypeEq> - registrations_; + registrations_; // Guarded by mu_ bool enable_namedtuple_; static int tp_traverse(PyObject* self, visitproc visit, void* arg);