22 'StreamReader' , 'StreamWriter' , 'StreamReaderProtocol' ,
33 'open_connection' , 'start_server' )
44
5+ import collections
56import socket
67import sys
78import warnings
@@ -129,7 +130,7 @@ def __init__(self, loop=None):
129130 else :
130131 self ._loop = loop
131132 self ._paused = False
132- self ._drain_waiter = None
133+ self ._drain_waiters = collections . deque ()
133134 self ._connection_lost = False
134135
135136 def pause_writing (self ):
@@ -144,38 +145,34 @@ def resume_writing(self):
144145 if self ._loop .get_debug ():
145146 logger .debug ("%r resumes writing" , self )
146147
147- waiter = self ._drain_waiter
148- if waiter is not None :
149- self ._drain_waiter = None
148+ for waiter in self ._drain_waiters :
150149 if not waiter .done ():
151150 waiter .set_result (None )
152151
153152 def connection_lost (self , exc ):
154153 self ._connection_lost = True
155- # Wake up the writer if currently paused.
154+ # Wake up the writer(s) if currently paused.
156155 if not self ._paused :
157156 return
158- waiter = self ._drain_waiter
159- if waiter is None :
160- return
161- self ._drain_waiter = None
162- if waiter .done ():
163- return
164- if exc is None :
165- waiter .set_result (None )
166- else :
167- waiter .set_exception (exc )
157+
158+ for waiter in self ._drain_waiters :
159+ if not waiter .done ():
160+ if exc is None :
161+ waiter .set_result (None )
162+ else :
163+ waiter .set_exception (exc )
168164
169165 async def _drain_helper (self ):
170166 if self ._connection_lost :
171167 raise ConnectionResetError ('Connection lost' )
172168 if not self ._paused :
173169 return
174- waiter = self ._drain_waiter
175- assert waiter is None or waiter .cancelled ()
176170 waiter = self ._loop .create_future ()
177- self ._drain_waiter = waiter
178- await waiter
171+ self ._drain_waiters .append (waiter )
172+ try :
173+ await waiter
174+ finally :
175+ self ._drain_waiters .remove (waiter )
179176
180177 def _get_close_waiter (self , stream ):
181178 raise NotImplementedError
@@ -206,6 +203,7 @@ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
206203 self ._strong_reader = stream_reader
207204 self ._reject_connection = False
208205 self ._stream_writer = None
206+ self ._task = None
209207 self ._transport = None
210208 self ._client_connected_cb = client_connected_cb
211209 self ._over_ssl = False
@@ -241,7 +239,7 @@ def connection_made(self, transport):
241239 res = self ._client_connected_cb (reader ,
242240 self ._stream_writer )
243241 if coroutines .iscoroutine (res ):
244- self ._loop .create_task (res )
242+ self ._task = self . _loop .create_task (res )
245243 self ._strong_reader = None
246244
247245 def connection_lost (self , exc ):
@@ -259,6 +257,7 @@ def connection_lost(self, exc):
259257 super ().connection_lost (exc )
260258 self ._stream_reader_wr = None
261259 self ._stream_writer = None
260+ self ._task = None
262261 self ._transport = None
263262
264263 def data_received (self , data ):
0 commit comments