bpo-23749: Implement loop.start_tls() (#5039) · python/cpython@f111b3d

@@ -29,9 +29,15 @@

2929

import warnings

3030

import weakref

313132+

try:

33+

import ssl

34+

except ImportError: # pragma: no cover

35+

ssl = None

36+3237

from . import coroutines

3338

from . import events

3439

from . import futures

40+

from . import sslproto

3541

from . import tasks

3642

from .log import logger

3743

@@ -279,7 +285,8 @@ def _make_ssl_transport(

279285

self, rawsock, protocol, sslcontext, waiter=None,

280286

*, server_side=False, server_hostname=None,

281287

extra=None, server=None,

282-

ssl_handshake_timeout=None):

288+

ssl_handshake_timeout=None,

289+

call_connection_made=True):

283290

"""Create SSL transport."""

284291

raise NotImplementedError

285292

@@ -795,6 +802,42 @@ async def _create_connection_transport(

795802796803

return transport, protocol

797804805+

async def start_tls(self, transport, protocol, sslcontext, *,

806+

server_side=False,

807+

server_hostname=None,

808+

ssl_handshake_timeout=None):

809+

"""Upgrade transport to TLS.

810+811+

Return a new transport that *protocol* should start using

812+

immediately.

813+

"""

814+

if ssl is None:

815+

raise RuntimeError('Python ssl module is not available')

816+817+

if not isinstance(sslcontext, ssl.SSLContext):

818+

raise TypeError(

819+

f'sslcontext is expected to be an instance of ssl.SSLContext, '

820+

f'got {sslcontext!r}')

821+822+

if not getattr(transport, '_start_tls_compatible', False):

823+

raise TypeError(

824+

f'transport {self!r} is not supported by start_tls()')

825+826+

waiter = self.create_future()

827+

ssl_protocol = sslproto.SSLProtocol(

828+

self, protocol, sslcontext, waiter,

829+

server_side, server_hostname,

830+

ssl_handshake_timeout=ssl_handshake_timeout,

831+

call_connection_made=False)

832+833+

transport.set_protocol(ssl_protocol)

834+

self.call_soon(ssl_protocol.connection_made, transport)

835+

if not transport.is_reading():

836+

self.call_soon(transport.resume_reading)

837+838+

await waiter

839+

return ssl_protocol._app_transport

840+798841

async def create_datagram_endpoint(self, protocol_factory,

799842

local_addr=None, remote_addr=None, *,

800843

family=0, proto=0, flags=0,