Skip to content

Commit 6a7fae1

Browse files
ShaneHarveyblink1073
authored andcommitted
PYTHON-4782 Fix deadlock and blocking behavior in _ACondition.wait (#1875)
(cherry picked from commit 821811e)
1 parent d712bc1 commit 6a7fae1

14 files changed

+693
-48
lines changed

pymongo/asynchronous/pool.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,8 @@ def __init__(
992992
# from the right side.
993993
self.conns: collections.deque = collections.deque()
994994
self.active_contexts: set[_CancellationContext] = set()
995-
self.lock = _ALock(_create_lock())
995+
_lock = _create_lock()
996+
self.lock = _ALock(_lock)
996997
self.active_sockets = 0
997998
# Monotonically increasing connection ID required for CMAP Events.
998999
self.next_connection_id = 1
@@ -1018,15 +1019,15 @@ def __init__(
10181019
# The first portion of the wait queue.
10191020
# Enforces: maxPoolSize
10201021
# Also used for: clearing the wait queue
1021-
self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
1022+
self.size_cond = _ACondition(threading.Condition(_lock))
10221023
self.requests = 0
10231024
self.max_pool_size = self.opts.max_pool_size
10241025
if not self.max_pool_size:
10251026
self.max_pool_size = float("inf")
10261027
# The second portion of the wait queue.
10271028
# Enforces: maxConnecting
10281029
# Also used for: clearing the wait queue
1029-
self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
1030+
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
10301031
self._max_connecting = self.opts.max_connecting
10311032
self._pending = 0
10321033
self._client_id = client_id

pymongo/asynchronous/topology.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings):
170170
self._seed_addresses = list(topology_description.server_descriptions())
171171
self._opened = False
172172
self._closed = False
173-
self._lock = _ALock(_create_lock())
174-
self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type]
173+
_lock = _create_lock()
174+
self._lock = _ALock(_lock)
175+
self._condition = _ACondition(self._settings.condition_class(_lock))
175176
self._servers: dict[_Address, Server] = {}
176177
self._pid: Optional[int] = None
177178
self._max_cluster_time: Optional[ClusterTime] = None

pymongo/lock.py

+126-21
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,20 @@
1414
from __future__ import annotations
1515

1616
import asyncio
17+
import collections
1718
import os
1819
import threading
1920
import time
2021
import weakref
21-
from typing import Any, Callable, Optional
22+
from typing import Any, Callable, Optional, TypeVar
2223

2324
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
2425

2526
# References to instances of _create_lock
2627
_forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
2728

29+
_T = TypeVar("_T")
30+
2831

2932
def _create_lock() -> threading.Lock:
3033
"""Represents a lock that is tracked upon instantiation using a WeakSet and
@@ -43,7 +46,14 @@ def _release_locks() -> None:
4346
lock.release()
4447

4548

49+
# Needed only for synchro.py compat.
50+
def _Lock(lock: threading.Lock) -> threading.Lock:
51+
return lock
52+
53+
4654
class _ALock:
55+
__slots__ = ("_lock",)
56+
4757
def __init__(self, lock: threading.Lock) -> None:
4858
self._lock = lock
4959

@@ -81,9 +91,18 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
8191
self.release()
8292

8393

94+
def _safe_set_result(fut: asyncio.Future) -> None:
95+
# Ensure the future hasn't been cancelled before calling set_result.
96+
if not fut.done():
97+
fut.set_result(False)
98+
99+
84100
class _ACondition:
101+
__slots__ = ("_condition", "_waiters")
102+
85103
def __init__(self, condition: threading.Condition) -> None:
86104
self._condition = condition
105+
self._waiters: collections.deque = collections.deque()
87106

88107
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
89108
if timeout > 0:
@@ -99,30 +118,116 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
99118
await asyncio.sleep(0)
100119

101120
async def wait(self, timeout: Optional[float] = None) -> bool:
102-
if timeout is not None:
103-
tstart = time.monotonic()
104-
while True:
105-
notified = self._condition.wait(0.001)
106-
if notified:
107-
return True
108-
if timeout is not None and (time.monotonic() - tstart) > timeout:
109-
return False
110-
111-
async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
112-
if timeout is not None:
113-
tstart = time.monotonic()
114-
while True:
115-
notified = self._condition.wait_for(predicate, 0.001)
116-
if notified:
117-
return True
118-
if timeout is not None and (time.monotonic() - tstart) > timeout:
119-
return False
121+
"""Wait until notified.
122+
123+
If the calling task has not acquired the lock when this
124+
method is called, a RuntimeError is raised.
125+
126+
This method releases the underlying lock, and then blocks
127+
until it is awakened by a notify() or notify_all() call for
128+
the same condition variable in another task. Once
129+
awakened, it re-acquires the lock and returns True.
130+
131+
This method may return spuriously,
132+
which is why the caller should always
133+
re-check the state and be prepared to wait() again.
134+
"""
135+
loop = asyncio.get_running_loop()
136+
fut = loop.create_future()
137+
self._waiters.append((loop, fut))
138+
self.release()
139+
try:
140+
try:
141+
try:
142+
await asyncio.wait_for(fut, timeout)
143+
return True
144+
except asyncio.TimeoutError:
145+
return False # Return false on timeout for sync pool compat.
146+
finally:
147+
# Must re-acquire lock even if wait is cancelled.
148+
# We only catch CancelledError here, since we don't want any
149+
# other (fatal) errors with the future to cause us to spin.
150+
err = None
151+
while True:
152+
try:
153+
await self.acquire()
154+
break
155+
except asyncio.exceptions.CancelledError as e:
156+
err = e
157+
158+
self._waiters.remove((loop, fut))
159+
if err is not None:
160+
try:
161+
raise err # Re-raise most recent exception instance.
162+
finally:
163+
err = None # Break reference cycles.
164+
except BaseException:
165+
# Any error raised out of here _may_ have occurred after this Task
166+
# believed to have been successfully notified.
167+
# Make sure to notify another Task instead. This may result
168+
# in a "spurious wakeup", which is allowed as part of the
169+
# Condition Variable protocol.
170+
self.notify(1)
171+
raise
172+
173+
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
174+
"""Wait until a predicate becomes true.
175+
176+
The predicate should be a callable whose result will be
177+
interpreted as a boolean value. The method will repeatedly
178+
wait() until it evaluates to true. The final predicate value is
179+
the return value.
180+
"""
181+
result = predicate()
182+
while not result:
183+
await self.wait()
184+
result = predicate()
185+
return result
120186

121187
def notify(self, n: int = 1) -> None:
122-
self._condition.notify(n)
188+
"""By default, wake up one coroutine waiting on this condition, if any.
189+
If the calling coroutine has not acquired the lock when this method
190+
is called, a RuntimeError is raised.
191+
192+
This method wakes up at most n of the coroutines waiting for the
193+
condition variable; it is a no-op if no coroutines are waiting.
194+
195+
Note: an awakened coroutine does not actually return from its
196+
wait() call until it can reacquire the lock. Since notify() does
197+
not release the lock, its caller should.
198+
"""
199+
idx = 0
200+
to_remove = []
201+
for loop, fut in self._waiters:
202+
if idx >= n:
203+
break
204+
205+
if fut.done():
206+
continue
207+
208+
try:
209+
loop.call_soon_threadsafe(_safe_set_result, fut)
210+
except RuntimeError:
211+
# Loop was closed, ignore.
212+
to_remove.append((loop, fut))
213+
continue
214+
215+
idx += 1
216+
217+
for waiter in to_remove:
218+
self._waiters.remove(waiter)
123219

124220
def notify_all(self) -> None:
125-
self._condition.notify_all()
221+
"""Wake up all threads waiting on this condition. This method acts
222+
like notify(), but wakes up all waiting threads instead of one. If the
223+
calling thread has not acquired the lock when this method is called,
224+
a RuntimeError is raised.
225+
"""
226+
self.notify(len(self._waiters))
227+
228+
def locked(self) -> bool:
229+
"""Only needed for tests in test_locks."""
230+
return self._condition._lock.locked() # type: ignore[attr-defined]
126231

127232
def release(self) -> None:
128233
self._condition.release()

pymongo/synchronous/pool.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
_CertificateError,
6363
)
6464
from pymongo.hello import Hello, HelloCompat
65-
from pymongo.lock import _create_lock
65+
from pymongo.lock import _create_lock, _Lock
6666
from pymongo.logger import (
6767
_CONNECTION_LOGGER,
6868
_ConnectionStatusMessage,
@@ -988,7 +988,8 @@ def __init__(
988988
# from the right side.
989989
self.conns: collections.deque = collections.deque()
990990
self.active_contexts: set[_CancellationContext] = set()
991-
self.lock = _create_lock()
991+
_lock = _create_lock()
992+
self.lock = _Lock(_lock)
992993
self.active_sockets = 0
993994
# Monotonically increasing connection ID required for CMAP Events.
994995
self.next_connection_id = 1
@@ -1014,15 +1015,15 @@ def __init__(
10141015
# The first portion of the wait queue.
10151016
# Enforces: maxPoolSize
10161017
# Also used for: clearing the wait queue
1017-
self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type]
1018+
self.size_cond = threading.Condition(_lock)
10181019
self.requests = 0
10191020
self.max_pool_size = self.opts.max_pool_size
10201021
if not self.max_pool_size:
10211022
self.max_pool_size = float("inf")
10221023
# The second portion of the wait queue.
10231024
# Enforces: maxConnecting
10241025
# Also used for: clearing the wait queue
1025-
self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type]
1026+
self._max_connecting_cond = threading.Condition(_lock)
10261027
self._max_connecting = self.opts.max_connecting
10271028
self._pending = 0
10281029
self._client_id = client_id

pymongo/synchronous/topology.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
WriteError,
4040
)
4141
from pymongo.hello import Hello
42-
from pymongo.lock import _create_lock
42+
from pymongo.lock import _create_lock, _Lock
4343
from pymongo.logger import (
4444
_SDAM_LOGGER,
4545
_SERVER_SELECTION_LOGGER,
@@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings):
170170
self._seed_addresses = list(topology_description.server_descriptions())
171171
self._opened = False
172172
self._closed = False
173-
self._lock = _create_lock()
174-
self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type]
173+
_lock = _create_lock()
174+
self._lock = _Lock(_lock)
175+
self._condition = self._settings.condition_class(_lock)
175176
self._servers: dict[_Address, Server] = {}
176177
self._pid: Optional[int] = None
177178
self._max_cluster_time: Optional[ClusterTime] = None

test/asynchronous/test_client.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2351,7 +2351,9 @@ async def test_reconnect(self):
23512351

23522352
# But it can reconnect.
23532353
c.revive_host("a:1")
2354-
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
2354+
await (await c._get_topology()).select_servers(
2355+
writable_server_selector, _Op.TEST, server_selection_timeout=10
2356+
)
23552357
self.assertEqual(await c.address, ("a", 1))
23562358

23572359
async def _test_network_error(self, operation_callback):

test/asynchronous/test_client_bulk_write.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from unittest.mock import patch
3232

33+
import pymongo
3334
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
3435
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
3536
from pymongo.errors import (
@@ -597,7 +598,9 @@ async def test_timeout_in_multi_batch_bulk_write(self):
597598
timeoutMS=2000,
598599
w="majority",
599600
)
600-
await client.admin.command("ping") # Init the client first.
601+
# Initialize the client with a larger timeout to help make test less flakey
602+
with pymongo.timeout(10):
603+
await client.admin.command("ping")
601604
with self.assertRaises(ClientBulkWriteException) as context:
602605
await client.bulk_write(models=models)
603606
self.assertIsInstance(context.exception.error, NetworkTimeout)

test/asynchronous/test_cursor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1414,7 +1414,7 @@ async def test_to_list_length(self):
14141414
async def test_to_list_csot_applied(self):
14151415
client = await self.async_single_client(timeoutMS=500)
14161416
# Initialize the client with a larger timeout to help make test less flakey
1417-
with pymongo.timeout(2):
1417+
with pymongo.timeout(10):
14181418
await client.admin.command("ping")
14191419
coll = client.pymongo.test
14201420
await coll.insert_many([{} for _ in range(5)])
@@ -1456,7 +1456,7 @@ async def test_command_cursor_to_list_length(self):
14561456
async def test_command_cursor_to_list_csot_applied(self):
14571457
client = await self.async_single_client(timeoutMS=500)
14581458
# Initialize the client with a larger timeout to help make test less flakey
1459-
with pymongo.timeout(2):
1459+
with pymongo.timeout(10):
14601460
await client.admin.command("ping")
14611461
coll = client.pymongo.test
14621462
await coll.insert_many([{} for _ in range(5)])

0 commit comments

Comments
 (0)