14
14
from __future__ import annotations
15
15
16
16
import asyncio
17
+ import collections
17
18
import os
18
19
import threading
19
20
import time
20
21
import weakref
21
- from typing import Any , Callable , Optional
22
+ from typing import Any , Callable , Optional , TypeVar
22
23
23
24
_HAS_REGISTER_AT_FORK = hasattr (os , "register_at_fork" )
24
25
25
26
# References to instances of _create_lock
26
27
_forkable_locks : weakref .WeakSet [threading .Lock ] = weakref .WeakSet ()
27
28
29
+ _T = TypeVar ("_T" )
30
+
28
31
29
32
def _create_lock () -> threading .Lock :
30
33
"""Represents a lock that is tracked upon instantiation using a WeakSet and
@@ -43,7 +46,14 @@ def _release_locks() -> None:
43
46
lock .release ()
44
47
45
48
49
+ # Needed only for synchro.py compat.
50
+ def _Lock (lock : threading .Lock ) -> threading .Lock :
51
+ return lock
52
+
53
+
46
54
class _ALock :
55
+ __slots__ = ("_lock" ,)
56
+
47
57
def __init__ (self , lock : threading .Lock ) -> None :
48
58
self ._lock = lock
49
59
@@ -81,9 +91,18 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
81
91
self .release ()
82
92
83
93
94
+ def _safe_set_result (fut : asyncio .Future ) -> None :
95
+ # Ensure the future hasn't been cancelled before calling set_result.
96
+ if not fut .done ():
97
+ fut .set_result (False )
98
+
99
+
84
100
class _ACondition :
101
+ __slots__ = ("_condition" , "_waiters" )
102
+
85
103
def __init__ (self , condition : threading .Condition ) -> None :
86
104
self ._condition = condition
105
+ self ._waiters : collections .deque = collections .deque ()
87
106
88
107
async def acquire (self , blocking : bool = True , timeout : float = - 1 ) -> bool :
89
108
if timeout > 0 :
@@ -99,30 +118,116 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
99
118
await asyncio .sleep (0 )
100
119
101
120
async def wait (self , timeout : Optional [float ] = None ) -> bool :
102
- if timeout is not None :
103
- tstart = time .monotonic ()
104
- while True :
105
- notified = self ._condition .wait (0.001 )
106
- if notified :
107
- return True
108
- if timeout is not None and (time .monotonic () - tstart ) > timeout :
109
- return False
110
-
111
- async def wait_for (self , predicate : Callable , timeout : Optional [float ] = None ) -> bool :
112
- if timeout is not None :
113
- tstart = time .monotonic ()
114
- while True :
115
- notified = self ._condition .wait_for (predicate , 0.001 )
116
- if notified :
117
- return True
118
- if timeout is not None and (time .monotonic () - tstart ) > timeout :
119
- return False
121
+ """Wait until notified.
122
+
123
+ If the calling task has not acquired the lock when this
124
+ method is called, a RuntimeError is raised.
125
+
126
+ This method releases the underlying lock, and then blocks
127
+ until it is awakened by a notify() or notify_all() call for
128
+ the same condition variable in another task. Once
129
+ awakened, it re-acquires the lock and returns True.
130
+
131
+ This method may return spuriously,
132
+ which is why the caller should always
133
+ re-check the state and be prepared to wait() again.
134
+ """
135
+ loop = asyncio .get_running_loop ()
136
+ fut = loop .create_future ()
137
+ self ._waiters .append ((loop , fut ))
138
+ self .release ()
139
+ try :
140
+ try :
141
+ try :
142
+ await asyncio .wait_for (fut , timeout )
143
+ return True
144
+ except asyncio .TimeoutError :
145
+ return False # Return false on timeout for sync pool compat.
146
+ finally :
147
+ # Must re-acquire lock even if wait is cancelled.
148
+ # We only catch CancelledError here, since we don't want any
149
+ # other (fatal) errors with the future to cause us to spin.
150
+ err = None
151
+ while True :
152
+ try :
153
+ await self .acquire ()
154
+ break
155
+ except asyncio .exceptions .CancelledError as e :
156
+ err = e
157
+
158
+ self ._waiters .remove ((loop , fut ))
159
+ if err is not None :
160
+ try :
161
+ raise err # Re-raise most recent exception instance.
162
+ finally :
163
+ err = None # Break reference cycles.
164
+ except BaseException :
165
+ # Any error raised out of here _may_ have occurred after this Task
166
+ # believed to have been successfully notified.
167
+ # Make sure to notify another Task instead. This may result
168
+ # in a "spurious wakeup", which is allowed as part of the
169
+ # Condition Variable protocol.
170
+ self .notify (1 )
171
+ raise
172
+
173
+ async def wait_for (self , predicate : Callable [[], _T ]) -> _T :
174
+ """Wait until a predicate becomes true.
175
+
176
+ The predicate should be a callable whose result will be
177
+ interpreted as a boolean value. The method will repeatedly
178
+ wait() until it evaluates to true. The final predicate value is
179
+ the return value.
180
+ """
181
+ result = predicate ()
182
+ while not result :
183
+ await self .wait ()
184
+ result = predicate ()
185
+ return result
120
186
121
187
def notify (self , n : int = 1 ) -> None :
122
- self ._condition .notify (n )
188
+ """By default, wake up one coroutine waiting on this condition, if any.
189
+ If the calling coroutine has not acquired the lock when this method
190
+ is called, a RuntimeError is raised.
191
+
192
+ This method wakes up at most n of the coroutines waiting for the
193
+ condition variable; it is a no-op if no coroutines are waiting.
194
+
195
+ Note: an awakened coroutine does not actually return from its
196
+ wait() call until it can reacquire the lock. Since notify() does
197
+ not release the lock, its caller should.
198
+ """
199
+ idx = 0
200
+ to_remove = []
201
+ for loop , fut in self ._waiters :
202
+ if idx >= n :
203
+ break
204
+
205
+ if fut .done ():
206
+ continue
207
+
208
+ try :
209
+ loop .call_soon_threadsafe (_safe_set_result , fut )
210
+ except RuntimeError :
211
+ # Loop was closed, ignore.
212
+ to_remove .append ((loop , fut ))
213
+ continue
214
+
215
+ idx += 1
216
+
217
+ for waiter in to_remove :
218
+ self ._waiters .remove (waiter )
123
219
124
220
def notify_all (self ) -> None :
125
- self ._condition .notify_all ()
221
+ """Wake up all threads waiting on this condition. This method acts
222
+ like notify(), but wakes up all waiting threads instead of one. If the
223
+ calling thread has not acquired the lock when this method is called,
224
+ a RuntimeError is raised.
225
+ """
226
+ self .notify (len (self ._waiters ))
227
+
228
+ def locked (self ) -> bool :
229
+ """Only needed for tests in test_locks."""
230
+ return self ._condition ._lock .locked () # type: ignore[attr-defined]
126
231
127
232
def release (self ) -> None :
128
233
self ._condition .release ()
0 commit comments