feat: speed up writing name compression for outgoing packets by bdraco · Pull Request #1312 · python-zeroconf/python-zeroconf

Expand Up @@ -61,8 +61,8 @@ class State(enum.Enum): finished = 1

STATE_INIT = State.init STATE_FINISHED = State.finished STATE_INIT = State.init.value STATE_FINISHED = State.finished.value
LOGGING_IS_ENABLED_FOR = log.isEnabledFor LOGGING_DEBUG = logging.DEBUG Expand Down Expand Up @@ -277,30 +277,41 @@ def write_name(self, name: str_) -> None: """
# split name into each label name_length = 0 if name.endswith('.'): name = name[: len(name) - 1] labels = name.split('.') # Write each new label or a pointer to the existing # on in the packet name = name[:-1]
index = self.names.get(name, 0) if index: self._write_link_to_name(index) return
start_size = self.size for count in range(len(labels)): label = name if count == 0 else '.'.join(labels[count:]) index = self.names.get(label, 0) labels = name.split('.') # Write each new label or a pointer to the existing one in the packet self.names[name] = start_size self._write_utf(labels[0])
name_length = 0 for count in range(1, len(labels)): partial_name = '.'.join(labels[count:]) index = self.names.get(partial_name, 0) if index: # If part of the name already exists in the packet, # create a pointer to it self._write_byte((index >> 8) | 0xC0) self._write_byte(index & 0xFF) self._write_link_to_name(index) return if name_length == 0: name_length = len(name.encode('utf-8')) self.names[label] = start_size + name_length - len(label.encode('utf-8')) self.names[partial_name] = start_size + name_length - len(partial_name.encode('utf-8')) self._write_utf(labels[count])
# this is the end of a name self._write_byte(0)
def _write_link_to_name(self, index: int_) -> None: # If part of the name already exists in the packet, # create a pointer to it self._write_byte((index >> 8) | 0xC0) self._write_byte(index & 0xFF)
def _write_question(self, question: DNSQuestion_) -> bool: """Writes a question to the packet""" start_data_length = len(self.data) Expand Down Expand Up @@ -406,9 +417,6 @@ def packets(self) -> List[bytes]: will be written out to a single oversized packet no more than _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP fragmentation potentially).""" return self._packets()
def _packets(self) -> List[bytes]: if self.state == STATE_FINISHED: return self.packets_data
Expand Down Expand Up @@ -445,6 +453,8 @@ def _packets(self) -> List[bytes]: authorities_written = self._write_records_from_offset(self.authorities, authority_offset) additionals_written = self._write_records_from_offset(self.additionals, additional_offset)
made_progress = bool(self.data)
self._insert_short_at_start(additionals_written) self._insert_short_at_start(authorities_written) self._insert_short_at_start(answers_written) Expand Down Expand Up @@ -479,16 +489,16 @@ def _packets(self) -> List[bytes]: self._insert_short_at_start(self.id)
self.packets_data.append(b''.join(self.data)) self._reset_for_next_packet()
if ( not questions_written and not answers_written and not authorities_written and not additionals_written and (self.questions or self.answers or self.authorities or self.additionals) ): if not made_progress: # Generating an empty packet is not a desirable outcome, but currently # too many internals rely on this behavior. So, we'll just return an # empty packet and log a warning until this can be refactored at a later # date. log.warning("packets() made no progress adding records; returning") break
self._reset_for_next_packet()
self.state = STATE_FINISHED return self.packets_data