
[MLIR:Python] Fix race on PyOperations. #139721


Open · wants to merge 1 commit into main

Conversation

hawkinsp
Contributor

Joint work with @vfdev-5

We found the following TSAN race report in JAX's CI: jax-ml/jax#28551

WARNING: ThreadSanitizer: data race (pid=35893)
  Read of size 1 at 0x7fffca320cb9 by thread T57 (mutexes: read M0):
    #0 mlir::python::PyOperation::checkValid() const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1300:8 (libjax_common.so+0x41e8b1d) (BuildId: 55242ad732cdae54)
    #1 mlir::python::populateIRCore(nanobind::module_&)::$_57::operator()(mlir::python::PyOperationBase&) const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:3221:40 (libjax_common.so+0x41e8b1d)
    #2 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::operator()(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) const /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:275:24 (libjax_common.so+0x41e8b1d)
    #3 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:219:14 (libjax_common.so+0x41e8b1d)
...

  Previous write of size 1 at 0x7fffca320cb9 by thread T56 (mutexes: read M0):
    #0 mlir::python::PyOperation::setInvalid() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRModule.h:729:29 (libjax_common.so+0x419f012) (BuildId: 55242ad732cdae54)
    #1 mlir::python::PyMlirContext::clearOperation(MlirOperation) /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:741:10 (libjax_common.so+0x419f012)
    #2 mlir::python::PyOperation::~PyOperation() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1213:19 (libjax_common.so+0x41a414b) (BuildId: 55242ad732cdae54)
    #3 void nanobind::detail::wrap_destruct<mlir::python::PyOperation>(void*) /proc/self/cwd/external/nanobind/include/nanobind/nb_class.h:245:21 (libjax_common.so+0x41ecf21) (BuildId: 55242ad732cdae54)
    #4 nanobind::detail::inst_dealloc(_object*) /proc/self/cwd/external/nanobind/src/nb_type.cpp:255:13 (libjax_common.so+0x3284136) (BuildId: 55242ad732cdae54)
    #5 _Py_Dealloc /project/cpython/Objects/object.c:3025:5 (python3.14+0x2a2422) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #6 _Py_MergeZeroLocalRefcount /project/cpython/Objects/object.c (python3.14+0x2a2422)
    #7 Py_DECREF(_object*) /proc/self/cwd/external/python_x86_64-unknown-linux-gnu-freethreaded/include/python3.14t/refcount.h:387:13 (libjax_common.so+0x41aaadc) (BuildId: 55242ad732cdae54)
...

At the simplest level, the valid field of a PyOperation must be protected by a lock, because it may be concurrently accessed from multiple threads. Much more interesting, however, is how we get into the situation described by the two stack traces above in the first place.

The scenario that triggers this is the following:

  • Thread T56 holds the last Python reference on a PyOperation and decides to release it.
  • After T56 starts to release its reference, but before T56 removes the PyOperation from the liveOperations map, a second thread T57 comes along and looks up the same MlirOperation in the liveOperations map.
  • Finding the operation present, thread T57 increments the reference count of that PyOperation and returns it to the caller. This is illegal! Python is in the process of calling the destructor of that object, and once an object is in that state it cannot safely be revived.
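The hazard and the safe pattern can be sketched in pure Python (a minimal illustration with made-up names like `OpWrapper` and `lookup_or_create`, not the actual bindings code): a cache that holds only non-owning references needs a lookup step that can fail once destruction has begun. Here `weakref.ref` plays the role that the try-increment primitive plays in the C++ fix: calling the ref yields either a strong reference or `None`.

```python
# Sketch: a liveOperations-style cache that never revives a dying object.
# Names are illustrative; the asserts rely on CPython's immediate
# refcount-based collection.
import weakref

class OpWrapper:
    """Stand-in for a Python wrapper around a C-level operation."""
    def __init__(self, key):
        self.key = key

live_ops = {}  # key -> weakref.ref(OpWrapper); non-owning, like liveOperations

def lookup_or_create(key):
    ref = live_ops.get(key)
    if ref is not None:
        wrapper = ref()      # analogue of tryIncRef: fails if dying or dead
        if wrapper is not None:
            return wrapper   # safe: we obtained a real strong reference
    wrapper = OpWrapper(key) # stale or missing entry: build a fresh wrapper
    live_ops[key] = weakref.ref(wrapper)
    return wrapper

a = lookup_or_create("op1")
assert lookup_or_create("op1") is a  # cache hit while 'a' is alive
stale = live_ops["op1"]
del a                                # last strong reference dropped
assert stale() is None               # the cache can no longer revive it
b = lookup_or_create("op1")          # so lookup builds a new wrapper instead
assert b.key == "op1"
```

The C++ bindings cannot store weak references this way (the map holds a raw `nanobind::handle`), which is why they need a refcount-level primitive that can refuse the increment.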

To fix this, whenever we increment the reference count of a PyOperation that we found via the liveOperations map and to which we only hold a non-owning reference, we must use the Python 3.14+ API PyUnstable_TryIncRef, which exists precisely for this scenario (python/cpython#128844). That API does not exist under Python 3.13, so in that case we need a backport of it, for which we use the backport that both nanobind and pybind11 also use.
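The semantics PyUnstable_TryIncRef provides can be modeled with a small lock-based sketch (illustrative only; CPython's real implementation uses the lock-free biased refcounting shown in the diff): an increment that succeeds only while the count is still positive, so a thread racing with the final decref can never resurrect the object.

```python
# Model of try-increment semantics: once the count reaches zero, the
# object is logically dead and any later try_inc must fail.
import threading

class TryRefCount:
    def __init__(self):
        self._lock = threading.Lock()
        self._count = 1  # the creator holds the first reference

    def try_inc(self):
        """Analogue of PyUnstable_TryIncRef: refuse to revive a dying object."""
        with self._lock:
            if self._count == 0:
                return False  # destruction has already started
            self._count += 1
            return True

    def dec(self):
        """Drop one reference; returns True when this was the final one."""
        with self._lock:
            assert self._count > 0
            self._count -= 1
            return self._count == 0

rc = TryRefCount()
assert rc.try_inc()      # second reference taken while the object is alive
assert not rc.dec()      # dropping it is not final
assert rc.dec()          # final reference gone: destruction begins
assert not rc.try_inc()  # a racing lookup must fail rather than revive
```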

Fixes jax-ml/jax#28551

@llvmbot
Member

llvmbot commented May 13, 2025

@llvm/pr-subscribers-mlir

Author: Peter Hawkins (hawkinsp)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/139721.diff

3 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+128-33)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+28-7)
  • (modified) mlir/test/python/multithreaded_tests.py (+43)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b5720b7ad8b21..cc5a8bdb9a187 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,6 +635,75 @@ class PyOpOperandIterator {
   MlirOpOperand opOperand;
 };
 
+
+
+#if !defined(Py_GIL_DISABLED)
+inline void enableTryIncRef(nb::handle obj) noexcept { }
+inline bool tryIncRef(nb::handle obj) noexcept {
+    if (Py_REFCNT(obj.ptr()) > 0) {
+        Py_INCREF(obj.ptr());
+        return true;
+    }
+    return false;
+}
+
+#elif PY_VERSION_HEX >= 0x030E00A5
+
+// CPython 3.14 provides an unstable API for these.
+inline void enableTryIncRef(nb::handle obj) noexcept {
+  PyUnstable_EnableTryIncRef(obj.ptr());
+}
+inline bool tryIncRef(nb::handle obj) noexcept {
+  return PyUnstable_TryIncRef(obj.ptr());
+}
+
+#else
+
+// For CPython 3.13 there is no API for this, and so we must implement our own.
+// This code originates from https://github.com/wjakob/nanobind/pull/865/files.
+void enableTryIncRef(nb::handle h) noexcept {
+  // Since this is called during object construction, we know that we have
+  // the only reference to the object and can use a non-atomic write.
+  PyObject* obj = h.ptr();
+  assert(obj->ob_ref_shared == 0);
+  obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
+}
+
+bool tryIncRef(nb::handle h) noexcept {
+  PyObject *obj = h.ptr();
+  // See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
+  uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
+  local += 1;
+  if (local == 0) {
+      // immortal
+      return true;
+  }
+  if (_Py_IsOwnedByCurrentThread(obj)) {
+      _Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
+#ifdef Py_REF_DEBUG
+      _Py_INCREF_IncRefTotal();
+#endif
+      return true;
+  }
+  Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
+  for (;;) {
+      // If the shared refcount is zero and the object is either merged
+      // or may not have weak references, then we cannot incref it.
+      if (shared == 0 || shared == _Py_REF_MERGED) {
+          return false;
+      }
+
+      if (_Py_atomic_compare_exchange_ssize(
+              &obj->ob_ref_shared, &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
+#ifdef Py_REF_DEBUG
+          _Py_INCREF_IncRefTotal();
+#endif
+          return true;
+      }
+  }
+}
+#endif
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -706,11 +775,17 @@ size_t PyMlirContext::getLiveOperationCount() {
   return liveOperations.size();
 }
 
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
-  std::vector<PyOperation *> liveObjects;
+std::vector<nb::object> PyMlirContext::getLiveOperationObjects() {
+  std::vector<nb::object> liveObjects;
   nb::ft_lock_guard lock(liveOperationsMutex);
-  for (auto &entry : liveOperations)
-    liveObjects.push_back(entry.second.second);
+  for (auto &entry : liveOperations) {
+    // It is not safe to unconditionally increment the reference count here
+    // because an operation that is in the process of being deleted by another
+    // thread may still be present in the map.
+    if (tryIncRef(entry.second.first)) {
+      liveObjects.push_back(nb::steal(entry.second.first));
+    }
+  }
   return liveObjects;
 }
 
@@ -720,25 +795,26 @@ size_t PyMlirContext::clearLiveOperations() {
   {
     nb::ft_lock_guard lock(liveOperationsMutex);
     std::swap(operations, liveOperations);
+    for (auto &op : operations)
+      op.second.second->setInvalidLocked();
   }
-  for (auto &op : operations)
-    op.second.second->setInvalid();
   size_t numInvalidated = operations.size();
   return numInvalidated;
 }
 
-void PyMlirContext::clearOperation(MlirOperation op) {
-  PyOperation *py_op;
-  {
-    nb::ft_lock_guard lock(liveOperationsMutex);
-    auto it = liveOperations.find(op.ptr);
-    if (it == liveOperations.end()) {
-      return;
-    }
-    py_op = it->second.second;
-    liveOperations.erase(it);
+void PyMlirContext::clearOperationLocked(MlirOperation op) {
+  auto it = liveOperations.find(op.ptr);
+  if (it == liveOperations.end()) {
+    return;
   }
-  py_op->setInvalid();
+  PyOperation *py_op = it->second.second;
+  py_op->setInvalidLocked();
+  liveOperations.erase(it);
+}
+
+void PyMlirContext::clearOperation(MlirOperation op) {
+  nb::ft_lock_guard lock(liveOperationsMutex);
+  clearOperationLocked(op);
 }
 
 void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -770,7 +846,7 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
   MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
                                                       void *userData) {
     PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
-    contextRef->clearOperation(op);
+    contextRef->clearOperationLocked(op);
     return MlirWalkResult::MlirWalkResultAdvance;
   };
   mlirOperationWalk(op.getOperation(), invalidatingCallback,
@@ -1203,6 +1279,8 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
     : BaseContextObject(std::move(contextRef)), operation(operation) {}
 
 PyOperation::~PyOperation() {
+  PyMlirContextRef context = getContext();
+  nb::ft_lock_guard lock(context->liveOperationsMutex);
   // If the operation has already been invalidated there is nothing to do.
   if (!valid)
     return;
@@ -1210,12 +1288,14 @@ PyOperation::~PyOperation() {
   // Otherwise, invalidate the operation and remove it from live map when it is
   // attached.
   if (isAttached()) {
-    getContext()->clearOperation(*this);
+    // Since the operation was valid, we know that it is this object present
+    // in the map, not some other object.
+    context->liveOperations.erase(operation.ptr);
   } else {
     // And destroy it when it is detached, i.e. owned by Python, in which case
     // all nested operations must be invalidated at removed from the live map as
     // well.
-    erase();
+    eraseLocked();
   }
 }
 
@@ -1241,6 +1321,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
   // Create.
   PyOperationRef unownedOperation =
       makeObjectRef<PyOperation>(std::move(contextRef), operation);
+  enableTryIncRef(unownedOperation.getObject());
   unownedOperation->handle = unownedOperation.getObject();
   if (parentKeepAlive) {
     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
@@ -1254,18 +1335,26 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
   nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
   auto &liveOperations = contextRef->liveOperations;
   auto it = liveOperations.find(operation.ptr);
-  if (it == liveOperations.end()) {
-    // Create.
-    PyOperationRef result = createInstance(std::move(contextRef), operation,
-                                           std::move(parentKeepAlive));
-    liveOperations[operation.ptr] =
-        std::make_pair(result.getObject(), result.get());
-    return result;
+  if (it != liveOperations.end()) {
+    PyOperation *existing = it->second.second;
+    nb::handle pyRef = it->second.first;
+
+    // Try to increment the reference count of the existing entry. This can fail
+    // if the object is in the process of being destroyed by another thread.
+    if (tryIncRef(pyRef)) {
+      return PyOperationRef(existing, nb::steal<nb::object>(pyRef));
+    }
+
+    // Mark the existing entry as invalid, since we are about to replace it.
+    existing->valid = false;
   }
-  // Use existing.
-  PyOperation *existing = it->second.second;
-  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
-  return PyOperationRef(existing, std::move(pyRef));
+
+  // Create a new wrapper object.
+  PyOperationRef result = createInstance(std::move(contextRef), operation,
+                                         std::move(parentKeepAlive));
+  liveOperations[operation.ptr] =
+      std::make_pair(result.getObject(), result.get());
+  return result;
 }
 
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1297,6 +1386,7 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
 }
 
 void PyOperation::checkValid() const {
+  nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
   if (!valid) {
     throw std::runtime_error("the operation has been invalidated");
   }
@@ -1638,12 +1728,17 @@ nb::object PyOperation::createOpView() {
   return nb::cast(PyOpView(getRef().getObject()));
 }
 
-void PyOperation::erase() {
+void PyOperation::eraseLocked() {
   checkValid();
   getContext()->clearOperationAndInside(*this);
   mlirOperationDestroy(operation);
 }
 
+void PyOperation::erase() {
+  nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
+  eraseLocked();
+}
+
 namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2324,7 +2419,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
   // The operation is also erased, so we must invalidate it. There may be Python
   // references to this operation so we don't want to delete it from the list of
   // live operations here.
-  symbol.getOperation().valid = false;
+  symbol.getOperation().setInvalid();
 }
 
 void PySymbolTable::dunderDel(const std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9befcce725bb7..ba3ec423196df 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -83,7 +83,7 @@ class PyObjectRef {
   }
 
   T *get() { return referrent; }
-  T *operator->() {
+  T *operator->() const {
     assert(referrent && object);
     return referrent;
   }
@@ -229,7 +229,7 @@ class PyMlirContext {
   static size_t getLiveCount();
 
   /// Get a list of Python objects which are still in the live context map.
-  std::vector<PyOperation *> getLiveOperationObjects();
+  std::vector<nanobind::object> getLiveOperationObjects();
 
   /// Gets the count of live operations associated with this context.
   /// Used for testing.
@@ -254,8 +254,9 @@ class PyMlirContext {
   void clearOperationsInside(PyOperationBase &op);
   void clearOperationsInside(MlirOperation op);
 
-  /// Clears the operaiton _and_ all operations inside using
-  /// `clearOperation(MlirOperation)`.
+  /// Clears the operation _and_ all operations inside using
+  /// `clearOperation(MlirOperation)`. Requires that liveOperations mutex is
+  /// held.
   void clearOperationAndInside(PyOperationBase &op);
 
   /// Gets the count of live modules associated with this context.
@@ -278,6 +279,9 @@ class PyMlirContext {
   struct ErrorCapture;
 
 private:
+  // Similar to clearOperation, but requires the liveOperations mutex to be held
+  void clearOperationLocked(MlirOperation op);
+
   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
   // preserving the relationship that an MlirContext maps to a single
   // PyMlirContext wrapper. This could be replaced in the future with an
@@ -302,6 +306,9 @@ class PyMlirContext {
   // attempt to access it will raise an error.
   using LiveOperationMap =
       llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
+
+  // liveOperationsMutex guards both liveOperations and the valid field of
+  // PyOperation objects in free-threading mode.
   nanobind::ft_mutex liveOperationsMutex;
 
   // Guarded by liveOperationsMutex in free-threading mode.
@@ -336,6 +343,7 @@ class BaseContextObject {
   }
 
   /// Accesses the context reference.
+  const PyMlirContextRef &getContext() const { return contextRef; }
   PyMlirContextRef &getContext() { return contextRef; }
 
 private:
@@ -725,12 +733,19 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   /// parent context's live operations map, and sets the valid bit false.
   void erase();
 
-  /// Invalidate the operation.
-  void setInvalid() { valid = false; }
-
   /// Clones this operation.
   nanobind::object clone(const nanobind::object &ip);
 
+  /// Invalidate the operation.
+  void setInvalid() {
+    nanobind::ft_lock_guard lock(getContext()->liveOperationsMutex);
+    setInvalidLocked();
+  }
+  /// Like setInvalid(), but requires the liveOperations mutex to be held.
+  void setInvalidLocked() {
+    valid = false;
+  }
+
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
 
 private:
@@ -738,6 +753,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
                                        MlirOperation operation,
                                        nanobind::object parentKeepAlive);
 
+  // Like erase(), but requires the caller to hold the liveOperationsMutex.
+  void eraseLocked();
+
   MlirOperation operation;
   nanobind::handle handle;
   // Keeps the parent alive, regardless of whether it is an Operation or
@@ -748,6 +766,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   // ir_operation.py regarding testing corresponding lifetime guarantees.
   nanobind::object parentKeepAlive;
   bool attached = true;
+
+  // Guarded by 'context->liveOperationsMutex'. Valid objects must be present
+  // in context->liveOperations.
   bool valid = true;
 
   friend class PyOperationBase;
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index 6e1a668346872..51a09a6a496e1 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -512,6 +512,49 @@ def _original_test_create_module_with_consts(self):
                 arith.constant(dtype, py_values[2])
 
 
+    def test_check_pyoperation_race(self):
+        # Regression test for a race where:
+        # * one thread is in the process of destroying a PyOperation,
+        # * while simultaneously another thread looks up the PyOperation in
+        #   the liveOperations map and attempts to increase its reference count.
+        # It is illegal to attempt to revive an object that is in the process of
+        # being deleted, and this was producing races and heap use-after-frees.
+        num_workers = 40
+        num_runs = 20
+
+        barrier = threading.Barrier(num_workers)
+
+        def walk_operations(op):
+            _ = op.operation.name
+            for region in op.operation.regions:
+                for block in region:
+                    for op in block:
+                        walk_operations(op)
+
+        with Context():
+            mlir_module = Module.parse(
+                """
+    module @m {
+    func.func public @main(%arg0: tensor<f32>) -> (tensor<f32>) {
+        return %arg0 : tensor<f32>
+    }
+    }
+                """
+            )
+
+        def closure():
+            barrier.wait()
+
+            for _ in range(num_runs):
+                walk_operations(mlir_module)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = []
+            for i in range(num_workers):
+                futures.append(executor.submit(closure))
+            assert len(list(f.result() for f in futures)) == num_workers
+
+
 if __name__ == "__main__":
     # Do not run the tests on CPython with GIL
     if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled():


github-actions bot commented May 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@hawkinsp
Contributor Author

I proposed making nanobind's copy of tryIncRef public in wjakob/nanobind#1043.

Another possibility is that it could be added to the pythoncapi-compat project.

@makslevental
Contributor

liveOperations strikes again!

Regarding

and it's probably good for the world that we don't end up with many copies of the fallback implementation

Is it "n00b" to propose that rather than copy-pasting the impl from nanobind we just rely on the detail/impl via the same extern? I.e., couldn't we just copy-paste (roughly)

#if PY_VERSION_HEX >= 0x030E00A5
/// Sufficiently recent CPython versions provide an API for the following operations
inline void nb_enable_try_inc_ref(PyObject *obj) noexcept {
    PyUnstable_EnableTryIncRef(obj);
}
inline bool nb_try_inc_ref(PyObject *obj) noexcept {
    return PyUnstable_TryIncRef(obj);
}
#else
/// Otherwise, nanobind ships with a low-level implementation
extern void nb_enable_try_inc_ref(PyObject *) noexcept;
extern bool nb_try_inc_ref(PyObject *obj) noexcept;
#endif

@hawkinsp

liveOperations strikes again!

Regarding

and it's probably good for the world that we don't end up with many copies of the fallback implementation

Is it "n00b" to propose that rathe than copy-pasting the impl from nbbind we just rely on the detail/impl via the same extern? I.e., couldn't we just copy-paste (roughly)

#if PY_VERSION_HEX >= 0x030E00A5
/// Sufficiently recent CPython versions provide an API for the following operations
inline void nb_enable_try_inc_ref(PyObject *obj) noexcept {
    PyUnstable_EnableTryIncRef(obj);
}
inline bool nb_try_inc_ref(PyObject *obj) noexcept {
    return PyUnstable_TryIncRef(obj);
}
#else
/// Otherwise, nanobind ships with a low-level implementation
extern void nb_enable_try_inc_ref(PyObject *) noexcept;
extern bool nb_try_inc_ref(PyObject *obj) noexcept;
#endif

I'd worry about symbol visibility if you did that, given that nanobind didn't intend to make these public.

@makslevental

I'd worry about symbol visibility if you did that, given that nanobind didn't intend to make these public.

is there a compile/build scenario under which

extern void nb_enable_try_inc_ref(PyObject *) noexcept;
extern bool nb_try_inc_ref(PyObject *obj) noexcept;

are "visible" within the nbbind archive but not outside? Since (presumably) they "weakly" linked in the nbbind archive itself (since they're indeed used through the extern). Especially since in our current use we're statically linking nbbind?

Granted this is now "impl^2" territory (since if we were to switch to shared linking you would be right) but maybe we can guard against breakage (the corresponding shower thought was "is there a way to provide a custom error message in the case that those symbols do fail to resolve").

Ultimately I'm not opposed to copy-pasta - this is just ideation/brainstorming (I will closely review the changes to our code today).

@hawkinsp

Granted this is now "impl^2" territory (since if we were to switch to shared linking you would be right) but maybe we can guard against breakage (the corresponding shower thought was "is there a way to provide a custom error message in the case that those symbols do fail to resolve").

Ultimately I'm not opposed to copy-pasta - this is just ideation/brainstorming (I will closely review the changes to our code today).

I think that there are three possibilities here:

a) The nanobind authors agree to add it as a nanobind API. We upgrade to that version of nanobind and life is good. This may force us to wait until the next release, even if approved.

Failing that, it's something explicitly private. We shouldn't use other people's private APIs since they can and will change them. This leaves:

b) copy-paste, as this PR does now
c) send a PR adding PyUnstable_TryIncRef to pythoncapi-compat, which is a project that exists for this purpose, add it as an MLIR:Python build dependency, and then just write PyUnstable_TryIncRef in this code.

What do you think?

@makslevental

a) The nanobind authors agree to add it as a nanobind API. We upgrade to that version of nanobind and life is good. This may force us to wait until the next release, even if approved.

I've had poor luck getting something across the private->public line in nanobind before. No grudge against that guy (his baby, his rules) but I wouldn't be surprised if he just says no 🤷 so let's strike this option.

c) send a PR adding PyUnstable_TryIncRef to pythoncapi-compat, which is a project that exists for this purpose, add it as an MLIR:Python build dependency, and then just write PyUnstable_TryIncRef in this code.

pythoncapi-compat is a good project but I think it's overkill for just this use case (though maybe we have others lurking? don't remember).

b) copy-paste, as this PR does now

Door #3 it is then!

setInvalidLocked();
}
/// Like setInvalid(), but requires the liveOperations mutex to be held.
void setInvalidLocked() { valid = false; }

@makslevental makslevental May 13, 2025


dumb question: is there a way to put a runtime assert/check here that the mutex is actually held, so that there's some way for people who don't read the doc strings (...like me...) to save themselves by compiling with asserts? e.g. i'm wondering if nanobind::ft_mutex::lock() is a no-op if the mutex is already held/locked by the thread?

if that's too tedious/onerous a change then i propose we rename the method to something like setInvalidWhileLocked (although that implies the method will be a no-op "when unlocked", which is not true)

Comment on lines +811 to +812
py_op->setInvalidLocked();
liveOperations.erase(it);

wouldn't you rather reorder these? the py_op isn't actually "invalid" until it's erased from liveOperations right?

@@ -770,7 +846,7 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
void *userData) {
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
contextRef->clearOperation(op);
contextRef->clearOperationLocked(op);

maybe i'm missing something but i don't see an explicit nb::ft_lock_guard lock(liveOperationsMutex); anywhere around here?


@makslevental makslevental May 13, 2025


oh it's because the call chain is ~PyOperation (which takes the lock) -> eraseLocked -> clearOperationAndInside -> clearOperationLocked.

yea these are sharp edges in my opinion - there being no assertions that a lock is actually held. at minimum we should propagate the Locked convention throughout the call chain, i.e. this method should be renamed clearOperationAndInsideLocked (this is of course unfortunate/tedious but i think there's no way out of "coloring" the methods)

Comment on lines +1737 to +1741
void PyOperation::erase() {
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
eraseLocked();
}


is this called anywhere? don't think so


i guess it's here for completeness...

}

// Mark the existing entry as invalid, since we are about to replace it.
existing->valid = false;

nit: setInvalidLocked instead of direct access

Comment on lines +535 to +543
mlir_module = Module.parse(
"""
module @m {
func.func public @main(%arg0: tensor<f32>) -> (tensor<f32>) {
return %arg0 : tensor<f32>
}
}
"""
)

ultra-nit

Suggested change
mlir_module = Module.parse(
"""
module @m {
func.func public @main(%arg0: tensor<f32>) -> (tensor<f32>) {
return %arg0 : tensor<f32>
}
}
"""
)
mlir_module = Module.parse(
dedent(
"""
module @m {
func.func public @main(%arg0: tensor<f32>) -> (tensor<f32>) {
return %arg0 : tensor<f32>
}
}
"""
)
)

will need from textwrap import dedent at the top


@makslevental makslevental left a comment


low-level this looks good (modulo nits). high-level it's a little hard to keep the threads in mind (no pun intended) because the mutex is now pulling double-duty in guarding both valid and liveOperations (and also the relationship to incref). presumably you guys tested this fix in your CI and it worked (resolved the race) but a procedural question/reminder for myself: do we test FT mode here (i.e., in llvm-project CI)? also I believe we test ASAN (I think we do - there were definitely PRs a while back that enabled that for Python by @rkayaith) but do we test TSAN?


Successfully merging this pull request may close these issues.

Possible data race in PyOperation and ~PyOperation on cached Module