fix(spanner): handle errors during stream restart in snapshot (#1471) · googleapis/python-spanner@c066873
@@ -405,6 +405,56 @@ def test_iteration_w_raw_raising_unavailable_after_token(self):
405405self.assertEqual(request.resume_token, RESUME_TOKEN)
406406self.assertNoSpans()
407407408+def test_iteration_w_raw_raising_unavailable_during_restart(self):
409+from google.api_core.exceptions import ServiceUnavailable
410+411+FIRST = (self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN))
412+LAST = (self._make_item(2),)
413+before = _MockIterator(
414+*FIRST, fail_after=True, error=ServiceUnavailable("testing")
415+ )
416+after = _MockIterator(*LAST)
417+request = mock.Mock(test="test", spec=["test", "resume_token"])
418+# The second call (the first retry) raises ServiceUnavailable immediately.
419+# The third call (the second retry) succeeds.
420+restart = mock.Mock(
421+spec=[],
422+side_effect=[before, ServiceUnavailable("retry failed"), after],
423+ )
424+database = _Database()
425+database.spanner_api = build_spanner_api()
426+session = _Session(database)
427+derived = _build_snapshot_derived(session)
428+resumable = self._call_fut(derived, restart, request, session=session)
429+self.assertEqual(list(resumable), list(FIRST + LAST))
430+self.assertEqual(len(restart.mock_calls), 3)
431+self.assertEqual(request.resume_token, RESUME_TOKEN)
432+self.assertNoSpans()
433+434+def test_iteration_w_raw_raising_resumable_internal_error_during_restart(self):
435+FIRST = (self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN))
436+LAST = (self._make_item(2),)
437+before = _MockIterator(
438+*FIRST,
439+fail_after=True,
440+error=INTERNAL_SERVER_ERROR_UNEXPECTED_EOS,
441+ )
442+after = _MockIterator(*LAST)
443+request = mock.Mock(test="test", spec=["test", "resume_token"])
444+restart = mock.Mock(
445+spec=[],
446+side_effect=[before, INTERNAL_SERVER_ERROR_UNEXPECTED_EOS, after],
447+ )
448+database = _Database()
449+database.spanner_api = build_spanner_api()
450+session = _Session(database)
451+derived = _build_snapshot_derived(session)
452+resumable = self._call_fut(derived, restart, request, session=session)
453+self.assertEqual(list(resumable), list(FIRST + LAST))
454+self.assertEqual(len(restart.mock_calls), 3)
455+self.assertEqual(request.resume_token, RESUME_TOKEN)
456+self.assertNoSpans()
457+408458def test_iteration_w_raw_w_multiuse(self):
409459from google.cloud.spanner_v1 import (
410460ReadRequest,