@@ -139,6 +139,8 @@ def _is_ready(fut: Future) -> None:
139
139
while read < length :
140
140
try :
141
141
read += conn .recv_into (mv [read :])
142
+ if read == 0 :
143
+ raise OSError ("connection closed" )
142
144
except BLOCKING_IO_ERRORS as exc :
143
145
fd = conn .fileno ()
144
146
# Check for closed socket.
@@ -195,11 +197,20 @@ def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
195
197
sock .sendall (buf )
196
198
197
199
200
+ async def _poll_cancellation (conn : AsyncConnection ) -> None :
201
+ while True :
202
+ if conn .cancel_context .cancelled :
203
+ return
204
+
205
+ await asyncio .sleep (_POLL_TIMEOUT )
206
+
207
+
198
208
async def async_receive_data (
199
209
conn : AsyncConnection , length : int , deadline : Optional [float ]
200
210
) -> memoryview :
201
211
sock = conn .conn
202
212
sock_timeout = sock .gettimeout ()
213
+ timeout : Optional [Union [float , int ]]
203
214
if deadline :
204
215
# When the timeout has expired perform one final check to
205
216
# see if the socket is readable. This helps avoid spurious
@@ -210,14 +221,22 @@ async def async_receive_data(
210
221
211
222
sock .settimeout (0.0 )
212
223
loop = asyncio .get_event_loop ()
224
+ cancellation_task = asyncio .create_task (_poll_cancellation (conn ))
213
225
try :
214
226
if _HAVE_SSL and isinstance (sock , (SSLSocket , _sslConn )):
215
- return await asyncio .wait_for (_async_receive_ssl (sock , length , loop ), timeout = timeout )
227
+ read_task = asyncio .create_task (_async_receive_ssl (sock , length , loop )) # type: ignore[arg-type]
216
228
else :
217
- return await asyncio .wait_for (_async_receive (sock , length , loop ), timeout = timeout ) # type: ignore[arg-type]
218
- except asyncio .TimeoutError as exc :
219
- # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
220
- raise socket .timeout ("timed out" ) from exc
229
+ read_task = asyncio .create_task (_async_receive (sock , length , loop )) # type: ignore[arg-type]
230
+ tasks = [read_task , cancellation_task ]
231
+ result = await asyncio .wait (tasks , timeout = timeout , return_when = asyncio .FIRST_COMPLETED )
232
+ if len (result [1 ]) == 2 :
233
+ raise socket .timeout ("timed out" )
234
+ finished = next (iter (result [0 ]))
235
+ next (iter (result [1 ])).cancel ()
236
+ if finished == read_task :
237
+ return finished .result () # type: ignore[return-value]
238
+ else :
239
+ raise _OperationCancelled ("operation cancelled" )
221
240
finally :
222
241
sock .settimeout (sock_timeout )
223
242
0 commit comments