Skip to content

Commit 7f71430

Browse files
committed
Add async support for cancellation
1 parent 3d399da commit 7f71430

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

pymongo/network_layer.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ def _is_ready(fut: Future) -> None:
139139
while read < length:
140140
try:
141141
read += conn.recv_into(mv[read:])
142+
if read == 0:
143+
raise OSError("connection closed")
142144
except BLOCKING_IO_ERRORS as exc:
143145
fd = conn.fileno()
144146
# Check for closed socket.
@@ -195,11 +197,20 @@ def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
195197
sock.sendall(buf)
196198

197199

200+
async def _poll_cancellation(conn: AsyncConnection) -> None:
201+
while True:
202+
if conn.cancel_context.cancelled:
203+
return
204+
205+
await asyncio.sleep(_POLL_TIMEOUT)
206+
207+
198208
async def async_receive_data(
199209
conn: AsyncConnection, length: int, deadline: Optional[float]
200210
) -> memoryview:
201211
sock = conn.conn
202212
sock_timeout = sock.gettimeout()
213+
timeout: Optional[Union[float, int]]
203214
if deadline:
204215
# When the timeout has expired perform one final check to
205216
# see if the socket is readable. This helps avoid spurious
@@ -210,14 +221,22 @@ async def async_receive_data(
210221

211222
sock.settimeout(0.0)
212223
loop = asyncio.get_event_loop()
224+
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
213225
try:
214226
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
215-
return await asyncio.wait_for(_async_receive_ssl(sock, length, loop), timeout=timeout)
227+
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
216228
else:
217-
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
218-
except asyncio.TimeoutError as exc:
219-
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
220-
raise socket.timeout("timed out") from exc
229+
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
230+
tasks = [read_task, cancellation_task]
231+
result = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
232+
if len(result[1]) == 2:
233+
raise socket.timeout("timed out")
234+
finished = next(iter(result[0]))
235+
next(iter(result[1])).cancel()
236+
if finished == read_task:
237+
return finished.result() # type: ignore[return-value]
238+
else:
239+
raise _OperationCancelled("operation cancelled")
221240
finally:
222241
sock.settimeout(sock_timeout)
223242

0 commit comments

Comments
 (0)