Skip to content

Commit 3d399da

Browse files
committed
Implement _async_receive_ssl
1 parent adb1670 commit 3d399da

File tree

1 file changed

+52
-4
lines changed

1 file changed

+52
-4
lines changed

pymongo/network_layer.py

+52-4
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,44 @@ def _is_ready(fut: Future) -> None:
121121
loop.add_reader(fd, _is_ready, fut)
122122
loop.add_writer(fd, _is_ready, fut)
123123
await fut
124+
125+
async def _async_receive_ssl(
126+
conn: _sslConn, length: int, loop: AbstractEventLoop
127+
) -> memoryview:
128+
mv = memoryview(bytearray(length))
129+
fd = conn.fileno()
130+
read = 0
131+
132+
def _is_ready(fut: Future) -> None:
133+
loop.remove_writer(fd)
134+
loop.remove_reader(fd)
135+
if fut.done():
136+
return
137+
fut.set_result(None)
138+
139+
while read < length:
140+
try:
141+
read += conn.recv_into(mv[read:])
142+
except BLOCKING_IO_ERRORS as exc:
143+
fd = conn.fileno()
144+
# Check for closed socket.
145+
if fd == -1:
146+
raise SSLError("Underlying socket has been closed") from None
147+
if isinstance(exc, BLOCKING_IO_READ_ERROR):
148+
fut = loop.create_future()
149+
loop.add_reader(fd, _is_ready, fut)
150+
await fut
151+
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
152+
fut = loop.create_future()
153+
loop.add_writer(fd, _is_ready, fut)
154+
await fut
155+
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
156+
fut = loop.create_future()
157+
loop.add_reader(fd, _is_ready, fut)
158+
loop.add_writer(fd, _is_ready, fut)
159+
await fut
160+
return mv
161+
124162
else:
125163
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
126164
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
@@ -138,6 +176,20 @@ async def _async_sendall_ssl(
138176
sent = 0
139177
total_sent += sent
140178

179+
async def _async_receive_ssl(
180+
conn: _sslConn, length: int, dummy: AbstractEventLoop
181+
) -> memoryview:
182+
mv = memoryview(bytearray(length))
183+
total_read = 0
184+
while total_read < length:
185+
try:
186+
read = conn.recv_into(mv[total_read:])
187+
except BLOCKING_IO_ERRORS:
188+
await asyncio.sleep(0.5)
189+
read = 0
190+
total_read += read
191+
return mv
192+
141193

142194
def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
143195
sock.sendall(buf)
@@ -181,10 +233,6 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo
181233
return mv
182234

183235

184-
async def _async_receive_ssl(conn: _sslConn, length: int, loop: AbstractEventLoop) -> memoryview: # noqa: ARG001
185-
return memoryview(b"")
186-
187-
188236
# Sync version:
189237
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
190238
"""Block until at least one byte is read, or a timeout, or a cancel."""

0 commit comments

Comments
 (0)