bpo-31386: Custom wrap_bio and wrap_socket type (#3426) · python/cpython@4df60f1

@@ -383,10 +383,11 @@ class Purpose(_ASN1Object, _Enum):

383383

class SSLContext(_SSLContext):

384384

"""An SSLContext holds various SSL-related configuration options and

385385

data, such as certificates and possibly a private key."""

386-387-

__slots__ = ('protocol', '__weakref__')

388386

_windows_cert_stores = ("CA", "ROOT")

389387388+

sslsocket_class = None # SSLSocket is assigned later.

389+

sslobject_class = None # SSLObject is assigned later.

390+390391

def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs):

391392

self = _SSLContext.__new__(cls, protocol)

392393

if protocol != _SSLv2_IF_EXISTS:

@@ -400,17 +401,21 @@ def wrap_socket(self, sock, server_side=False,

400401

do_handshake_on_connect=True,

401402

suppress_ragged_eofs=True,

402403

server_hostname=None, session=None):

403-

return SSLSocket(sock=sock, server_side=server_side,

404-

do_handshake_on_connect=do_handshake_on_connect,

405-

suppress_ragged_eofs=suppress_ragged_eofs,

406-

server_hostname=server_hostname,

407-

_context=self, _session=session)

404+

return self.sslsocket_class(

405+

sock=sock,

406+

server_side=server_side,

407+

do_handshake_on_connect=do_handshake_on_connect,

408+

suppress_ragged_eofs=suppress_ragged_eofs,

409+

server_hostname=server_hostname,

410+

_context=self,

411+

_session=session

412+

)

408413409414

def wrap_bio(self, incoming, outgoing, server_side=False,

410415

server_hostname=None, session=None):

411416

sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,

412417

server_hostname=server_hostname)

413-

return SSLObject(sslobj, session=session)

418+

return self.sslobject_class(sslobj, session=session)

414419415420

def set_npn_protocols(self, npn_protocols):

416421

protos = bytearray()

@@ -1135,6 +1140,11 @@ def version(self):

11351140

return self._sslobj.version()

11361141113711421143+

# Python does not support forward declaration of types.

1144+

SSLContext.sslsocket_class = SSLSocket

1145+

SSLContext.sslobject_class = SSLObject

1146+1147+11381148

def wrap_socket(sock, keyfile=None, certfile=None,

11391149

server_side=False, cert_reqs=CERT_NONE,

11401150

ssl_version=PROTOCOL_TLS, ca_certs=None,