[3.7] bpo-35998: Avoid TimeoutError in test_asyncio: test_start_tls_server_1() (GH-14080) by vstinner · Pull Request #14086 · python/cpython
Expand Up
@@ -491,17 +491,14 @@ async def client(addr):
def test_start_tls_server_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE ANSWER = b'answer'
server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() if sys.platform.startswith('freebsd') or sys.platform.startswith('win'): # bpo-35031: Some FreeBSD and Windows buildbots fail to run this test # as the eof was not being received by the server if the payload # size is not big enough. This behaviour only appears if the # client is using TLS1.3. client_context.options |= ssl.OP_NO_TLSv1_3 answer = None
def client(sock, addr): nonlocal answer sock.settimeout(self.TIMEOUT)
sock.connect(addr) Expand All @@ -510,33 +507,36 @@ def client(sock, addr):
sock.start_tls(client_context) sock.sendall(HELLO_MSG)
sock.shutdown(socket.SHUT_RDWR) answer = sock.recv_all(len(ANSWER)) sock.close()
class ServerProto(asyncio.Protocol): def __init__(self, on_con, on_eof, on_con_lost): def __init__(self, on_con, on_con_lost): self.on_con = on_con self.on_eof = on_eof self.on_con_lost = on_con_lost self.data = b'' self.transport = None
def connection_made(self, tr): self.transport = tr self.on_con.set_result(tr)
def replace_transport(self, tr): self.transport = tr
def data_received(self, data): self.data += data
def eof_received(self): self.on_eof.set_result(1) if len(self.data) >= len(HELLO_MSG): self.transport.write(ANSWER)
def connection_lost(self, exc): self.transport = None if exc is None: self.on_con_lost.set_result(None) else: self.on_con_lost.set_exception(exc)
async def main(proto, on_con, on_eof, on_con_lost): async def main(proto, on_con, on_con_lost): tr = await on_con tr.write(HELLO_MSG)
Expand All @@ -547,16 +547,16 @@ async def main(proto, on_con, on_eof, on_con_lost): server_side=True, ssl_handshake_timeout=self.TIMEOUT)
await on_eof proto.replace_transport(new_tr)
await on_con_lost self.assertEqual(proto.data, HELLO_MSG) new_tr.close()
async def run_main(): on_con = self.loop.create_future() on_eof = self.loop.create_future() on_con_lost = self.loop.create_future() proto = ServerProto(on_con, on_eof, on_con_lost) proto = ServerProto(on_con, on_con_lost)
server = await self.loop.create_server( lambda: proto, '127.0.0.1', 0) Expand All @@ -565,11 +565,12 @@ async def run_main(): with self.tcp_client(lambda sock: client(sock, addr), timeout=self.TIMEOUT): await asyncio.wait_for( main(proto, on_con, on_eof, on_con_lost), main(proto, on_con, on_con_lost), loop=self.loop, timeout=self.TIMEOUT)
server.close() await server.wait_closed() self.assertEqual(answer, ANSWER)
self.loop.run_until_complete(run_main())
Expand Down
def test_start_tls_server_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE ANSWER = b'answer'
server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() if sys.platform.startswith('freebsd') or sys.platform.startswith('win'): # bpo-35031: Some FreeBSD and Windows buildbots fail to run this test # as the eof was not being received by the server if the payload # size is not big enough. This behaviour only appears if the # client is using TLS1.3. client_context.options |= ssl.OP_NO_TLSv1_3 answer = None
def client(sock, addr): nonlocal answer sock.settimeout(self.TIMEOUT)
sock.connect(addr) Expand All @@ -510,33 +507,36 @@ def client(sock, addr):
sock.start_tls(client_context) sock.sendall(HELLO_MSG)
sock.shutdown(socket.SHUT_RDWR) answer = sock.recv_all(len(ANSWER)) sock.close()
class ServerProto(asyncio.Protocol): def __init__(self, on_con, on_eof, on_con_lost): def __init__(self, on_con, on_con_lost): self.on_con = on_con self.on_eof = on_eof self.on_con_lost = on_con_lost self.data = b'' self.transport = None
def connection_made(self, tr): self.transport = tr self.on_con.set_result(tr)
def replace_transport(self, tr): self.transport = tr
def data_received(self, data): self.data += data
def eof_received(self): self.on_eof.set_result(1) if len(self.data) >= len(HELLO_MSG): self.transport.write(ANSWER)
def connection_lost(self, exc): self.transport = None if exc is None: self.on_con_lost.set_result(None) else: self.on_con_lost.set_exception(exc)
async def main(proto, on_con, on_eof, on_con_lost): async def main(proto, on_con, on_con_lost): tr = await on_con tr.write(HELLO_MSG)
Expand All @@ -547,16 +547,16 @@ async def main(proto, on_con, on_eof, on_con_lost): server_side=True, ssl_handshake_timeout=self.TIMEOUT)
await on_eof proto.replace_transport(new_tr)
await on_con_lost self.assertEqual(proto.data, HELLO_MSG) new_tr.close()
async def run_main(): on_con = self.loop.create_future() on_eof = self.loop.create_future() on_con_lost = self.loop.create_future() proto = ServerProto(on_con, on_eof, on_con_lost) proto = ServerProto(on_con, on_con_lost)
server = await self.loop.create_server( lambda: proto, '127.0.0.1', 0) Expand All @@ -565,11 +565,12 @@ async def run_main(): with self.tcp_client(lambda sock: client(sock, addr), timeout=self.TIMEOUT): await asyncio.wait_for( main(proto, on_con, on_eof, on_con_lost), main(proto, on_con, on_con_lost), loop=self.loop, timeout=self.TIMEOUT)
server.close() await server.wait_closed() self.assertEqual(answer, ANSWER)
self.loop.run_until_complete(run_main())
Expand Down