feat: improve performance of constructing outgoing queries by bdraco · Pull Request #1267 · python-zeroconf/python-zeroconf

Expand Up @@ -53,12 +53,21 @@ PACK_SHORT = Struct('>H').pack PACK_LONG = Struct('>L').pack
BYTE_TABLE = tuple(PACK_BYTE(i) for i in range(256))

class State(enum.Enum): init = 0 finished = 1

STATE_INIT = State.init STATE_FINISHED = State.finished
LOGGING_IS_ENABLED_FOR = log.isEnabledFor LOGGING_DEBUG = logging.DEBUG

class DNSOutgoing:
"""Object representation of an outgoing packet""" Expand Down Expand Up @@ -93,7 +102,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: self.size: int = _DNS_PACKET_HEADER_LEN self.allow_long: bool = True
self.state = State.init self.state = STATE_INIT
self.questions: List[DNSQuestion] = [] self.answers: List[Tuple[DNSRecord, float]] = [] Expand Down Expand Up @@ -137,7 +146,8 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None:
def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: """Adds an answer if it does not expire by a certain time""" if record is not None and (now == 0 or not record.is_expired(now)): now_float = now if record is not None and (now_float == 0 or not record.is_expired(now_float)): self.answers.append((record, now))
def add_authorative_answer(self, record: DNSPointer) -> None: Expand Down Expand Up @@ -207,7 +217,7 @@ def add_question_or_all_cache(
def _write_byte(self, value: int_) -> None: """Writes a single byte to the packet""" self.data.append(PACK_BYTE(value)) self.data.append(BYTE_TABLE[value]) self.size += 1
def _insert_short_at_start(self, value: int_) -> None: Expand Down Expand Up @@ -267,7 +277,7 @@ def write_name(self, name: str_) -> None: """
# split name into each label name_length = None name_length = 0 if name.endswith('.'): name = name[: len(name) - 1] labels = name.split('.') Expand All @@ -276,14 +286,14 @@ def write_name(self, name: str_) -> None: start_size = self.size for count in range(len(labels)): label = name if count == 0 else '.'.join(labels[count:]) index = self.names.get(label) index = self.names.get(label, 0) if index: # If part of the name already exists in the packet, # create a pointer to it self._write_byte((index >> 8) | 0xC0) self._write_byte(index & 0xFF) return if name_length is None: if name_length == 0: name_length = len(name.encode('utf-8')) self.names[label] = start_size + name_length - len(label.encode('utf-8')) self._write_utf(labels[count]) Expand All @@ -293,7 +303,8 @@ def write_name(self, name: str_) -> None:
def _write_question(self, question: DNSQuestion_) -> bool: """Writes a question to the packet""" start_data_length, start_size = len(self.data), self.size start_data_length = len(self.data) start_size = self.size self.write_name(question.name) self.write_short(question.type) self._write_record_class(question) Expand All @@ -314,7 +325,8 @@ def _write_record(self, record: DNSRecord_, now: float_) -> bool: """Writes a record (answer, authoritative answer, additional) to the packet. Returns True on success, or False if we did not because the packet because the record does not fit.""" start_data_length, start_size = len(self.data), self.size start_data_length = len(self.data) start_size = self.size self.write_name(record.name) self.write_short(record.type) self._write_record_class(record) Expand All @@ -339,11 +351,13 @@ def _check_data_limit_or_rollback(self, start_data_length: int_, start_size: int if self.size <= len_limit: return True
log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit) if LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG): # pragma: no branch log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit) del self.data[start_data_length:] self.size = start_size
rollback_names = [name for name, idx in self.names.items() if idx >= start_size] start_size_int = start_size rollback_names = [name for name, idx in self.names.items() if idx >= start_size_int] for name in rollback_names: del self.names[name] return False Expand Down Expand Up @@ -395,7 +409,7 @@ def packets(self) -> List[bytes]: return self._packets()
def _packets(self) -> List[bytes]: if self.state == State.finished: if self.state == STATE_FINISHED: return self.packets_data
questions_offset = 0 Expand All @@ -404,7 +418,7 @@ def _packets(self) -> List[bytes]: additional_offset = 0 # we have to at least write out the question first_time = True debug_enable = log.isEnabledFor(logging.DEBUG) debug_enable = LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG)
while first_time or self._has_more_to_add( questions_offset, answer_offset, authority_offset, additional_offset Expand Down Expand Up @@ -476,5 +490,5 @@ def _packets(self) -> List[bytes]: ): log.warning("packets() made no progress adding records; returning") break self.state = State.finished self.state = STATE_FINISHED return self.packets_data