feat: speed up writing name compression for outgoing packets (#1312) · python-zeroconf/python-zeroconf@9caeabb
@@ -61,8 +61,8 @@ class State(enum.Enum):
6161finished = 1
6262636364-STATE_INIT = State.init
65-STATE_FINISHED = State.finished
64+STATE_INIT = State.init.value
65+STATE_FINISHED = State.finished.value
66666767LOGGING_IS_ENABLED_FOR = log.isEnabledFor
6868LOGGING_DEBUG = logging.DEBUG
@@ -277,30 +277,41 @@ def write_name(self, name: str_) -> None:
277277 """
278278279279# split name into each label
280-name_length = 0
281280if name.endswith('.'):
282-name = name[: len(name) - 1]
283-labels = name.split('.')
284-# Write each new label or a pointer to the existing
285-# on in the packet
281+name = name[:-1]
282+283+index = self.names.get(name, 0)
284+if index:
285+self._write_link_to_name(index)
286+return
287+286288start_size = self.size
287-for count in range(len(labels)):
288-label = name if count == 0 else '.'.join(labels[count:])
289-index = self.names.get(label, 0)
289+labels = name.split('.')
290+# Write each new label or a pointer to the existing one in the packet
291+self.names[name] = start_size
292+self._write_utf(labels[0])
293+294+name_length = 0
295+for count in range(1, len(labels)):
296+partial_name = '.'.join(labels[count:])
297+index = self.names.get(partial_name, 0)
290298if index:
291-# If part of the name already exists in the packet,
292-# create a pointer to it
293-self._write_byte((index >> 8) | 0xC0)
294-self._write_byte(index & 0xFF)
299+self._write_link_to_name(index)
295300return
296301if name_length == 0:
297302name_length = len(name.encode('utf-8'))
298-self.names[label] = start_size + name_length - len(label.encode('utf-8'))
303+self.names[partial_name] = start_size + name_length - len(partial_name.encode('utf-8'))
299304self._write_utf(labels[count])
300305301306# this is the end of a name
302307self._write_byte(0)
303308309+def _write_link_to_name(self, index: int_) -> None:
310+# If part of the name already exists in the packet,
311+# create a pointer to it
312+self._write_byte((index >> 8) | 0xC0)
313+self._write_byte(index & 0xFF)
314+304315def _write_question(self, question: DNSQuestion_) -> bool:
305316"""Writes a question to the packet"""
306317start_data_length = len(self.data)
@@ -406,9 +417,6 @@ def packets(self) -> List[bytes]:
406417 will be written out to a single oversized packet no more than
407418 _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP
408419 fragmentation potentially)."""
409-return self._packets()
410-411-def _packets(self) -> List[bytes]:
412420if self.state == STATE_FINISHED:
413421return self.packets_data
414422@@ -445,6 +453,8 @@ def _packets(self) -> List[bytes]:
445453authorities_written = self._write_records_from_offset(self.authorities, authority_offset)
446454additionals_written = self._write_records_from_offset(self.additionals, additional_offset)
447455456+made_progress = bool(self.data)
457+448458self._insert_short_at_start(additionals_written)
449459self._insert_short_at_start(authorities_written)
450460self._insert_short_at_start(answers_written)
@@ -479,16 +489,16 @@ def _packets(self) -> List[bytes]:
479489self._insert_short_at_start(self.id)
480490481491self.packets_data.append(b''.join(self.data))
482-self._reset_for_next_packet()
483492484-if (
485-not questions_written
486-and not answers_written
487-and not authorities_written
488-and not additionals_written
489-and (self.questions or self.answers or self.authorities or self.additionals)
490- ):
493+if not made_progress:
494+# Generating an empty packet is not a desirable outcome, but currently
495+# too many internals rely on this behavior. So, we'll just return an
496+# empty packet and log a warning until this can be refactored at a later
497+# date.
491498log.warning("packets() made no progress adding records; returning")
492499break
500+501+self._reset_for_next_packet()
502+493503self.state = STATE_FINISHED
494504return self.packets_data