bpo-33734: asyncio/ssl: a bunch of bugfixes (#7321) · python/cpython@9602643

@@ -214,13 +214,14 @@ def feed_ssldata(self, data, only_handshake=False):

214214

# Drain possible plaintext data after close_notify.

215215

appdata.append(self._incoming.read())

216216

except (ssl.SSLError, ssl.CertificateError) as exc:

217-

if getattr(exc, 'errno', None) not in (

217+

exc_errno = getattr(exc, 'errno', None)

218+

if exc_errno not in (

218219

ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,

219220

ssl.SSL_ERROR_SYSCALL):

220221

if self._state == _DO_HANDSHAKE and self._handshake_cb:

221222

self._handshake_cb(exc)

222223

raise

223-

self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)

224+

self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

224225225226

# Check for record level data that needs to be sent back.

226227

# Happens for the initial handshake and renegotiations.

@@ -263,13 +264,14 @@ def feed_appdata(self, data, offset=0):

263264

# It is not allowed to call write() after unwrap() until the

264265

# close_notify is acknowledged. We return the condition to the

265266

# caller as a short write.

267+

exc_errno = getattr(exc, 'errno', None)

266268

if exc.reason == 'PROTOCOL_IS_SHUTDOWN':

267-

exc.errno = ssl.SSL_ERROR_WANT_READ

268-

if exc.errno not in (ssl.SSL_ERROR_WANT_READ,

269+

exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ

270+

if exc_errno not in (ssl.SSL_ERROR_WANT_READ,

269271

ssl.SSL_ERROR_WANT_WRITE,

270272

ssl.SSL_ERROR_SYSCALL):

271273

raise

272-

self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)

274+

self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

273275274276

# See if there's any record level data back for us.

275277

if self._outgoing.pending:

@@ -488,6 +490,12 @@ def connection_lost(self, exc):

488490

if self._session_established:

489491

self._session_established = False

490492

self._loop.call_soon(self._app_protocol.connection_lost, exc)

493+

else:

494+

# Most likely an exception occurred while in SSL handshake.

495+

# Just mark the app transport as closed so that its __del__

496+

# doesn't complain.

497+

if self._app_transport is not None:

498+

self._app_transport._closed = True

491499

self._transport = None

492500

self._app_transport = None

493501

self._wakeup_waiter(exc)

@@ -515,11 +523,8 @@ def data_received(self, data):

515523516524

try:

517525

ssldata, appdata = self._sslpipe.feed_ssldata(data)

518-

except ssl.SSLError as e:

519-

if self._loop.get_debug():

520-

logger.warning('%r: SSL error %s (reason %s)',

521-

self, e.errno, e.reason)

522-

self._abort()

526+

except Exception as e:

527+

self._fatal_error(e, 'SSL error in data received')

523528

return

524529525530

for chunk in ssldata:

@@ -602,8 +607,12 @@ def _start_handshake(self):

602607603608

def _check_handshake_timeout(self):

604609

if self._in_handshake is True:

605-

logger.warning("%r stalled during handshake", self)

606-

self._abort()

610+

msg = (

611+

f"SSL handshake is taking longer than "

612+

f"{self._ssl_handshake_timeout} seconds: "

613+

f"aborting the connection"

614+

)

615+

self._fatal_error(ConnectionAbortedError(msg))

607616608617

def _on_handshake_complete(self, handshake_exc):

609618

self._in_handshake = False

@@ -615,21 +624,13 @@ def _on_handshake_complete(self, handshake_exc):

615624

raise handshake_exc

616625617626

peercert = sslobj.getpeercert()

618-

except BaseException as exc:

619-

if self._loop.get_debug():

620-

if isinstance(exc, ssl.CertificateError):

621-

logger.warning("%r: SSL handshake failed "

622-

"on verifying the certificate",

623-

self, exc_info=True)

624-

else:

625-

logger.warning("%r: SSL handshake failed",

626-

self, exc_info=True)

627-

self._transport.close()

628-

if isinstance(exc, Exception):

629-

self._wakeup_waiter(exc)

630-

return

627+

except Exception as exc:

628+

if isinstance(exc, ssl.CertificateError):

629+

msg = 'SSL handshake failed on verifying the certificate'

631630

else:

632-

raise

631+

msg = 'SSL handshake failed'

632+

self._fatal_error(exc, msg)

633+

return

633634634635

if self._loop.get_debug():

635636

dt = self._loop.time() - self._handshake_start_time

@@ -686,18 +687,14 @@ def _process_write_backlog(self):

686687

# delete it and reduce the outstanding buffer size.

687688

del self._write_backlog[0]

688689

self._write_buffer_size -= len(data)

689-

except BaseException as exc:

690+

except Exception as exc:

690691

if self._in_handshake:

691-

# BaseExceptions will be re-raised in _on_handshake_complete.

692+

# Exceptions will be re-raised in _on_handshake_complete.

692693

self._on_handshake_complete(exc)

693694

else:

694695

self._fatal_error(exc, 'Fatal error on SSL transport')

695-

if not isinstance(exc, Exception):

696-

# BaseException

697-

raise

698696699697

def _fatal_error(self, exc, message='Fatal error on transport'):

700-

# Should be called from exception handler only.

701698

if isinstance(exc, base_events._FATAL_ERROR_IGNORE):

702699

if self._loop.get_debug():

703700

logger.debug("%r: %s", self, message, exc_info=True)