feat: improve performance of constructing outgoing queries (#1267) · python-zeroconf/python-zeroconf@00c439a
@@ -53,12 +53,21 @@
5353PACK_SHORT = Struct('>H').pack
5454PACK_LONG = Struct('>L').pack
555556+BYTE_TABLE = tuple(PACK_BYTE(i) for i in range(256))
57+56585759class State(enum.Enum):
5860init = 0
5961finished = 1
6062616364+STATE_INIT = State.init
65+STATE_FINISHED = State.finished
66+67+LOGGING_IS_ENABLED_FOR = log.isEnabledFor
68+LOGGING_DEBUG = logging.DEBUG
69+70+6271class DNSOutgoing:
63726473"""Object representation of an outgoing packet"""
@@ -93,7 +102,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
93102self.size: int = _DNS_PACKET_HEADER_LEN
94103self.allow_long: bool = True
9510496-self.state = State.init
105+self.state = STATE_INIT
9710698107self.questions: List[DNSQuestion] = []
99108self.answers: List[Tuple[DNSRecord, float]] = []
@@ -137,7 +146,8 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None:
137146138147def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None:
139148"""Adds an answer if it does not expire by a certain time"""
140-if record is not None and (now == 0 or not record.is_expired(now)):
149+now_float = now
150+if record is not None and (now_float == 0 or not record.is_expired(now_float)):
141151self.answers.append((record, now))
142152143153def add_authorative_answer(self, record: DNSPointer) -> None:
@@ -207,7 +217,7 @@ def add_question_or_all_cache(
207217208218def _write_byte(self, value: int_) -> None:
209219"""Writes a single byte to the packet"""
210-self.data.append(PACK_BYTE(value))
220+self.data.append(BYTE_TABLE[value])
211221self.size += 1
212222213223def _insert_short_at_start(self, value: int_) -> None:
@@ -267,7 +277,7 @@ def write_name(self, name: str_) -> None:
267277 """
268278269279# split name into each label
270-name_length = None
280+name_length = 0
271281if name.endswith('.'):
272282name = name[: len(name) - 1]
273283labels = name.split('.')
@@ -276,14 +286,14 @@ def write_name(self, name: str_) -> None:
276286start_size = self.size
277287for count in range(len(labels)):
278288label = name if count == 0 else '.'.join(labels[count:])
279-index = self.names.get(label)
289+index = self.names.get(label, 0)
280290if index:
281291# If part of the name already exists in the packet,
282292# create a pointer to it
283293self._write_byte((index >> 8) | 0xC0)
284294self._write_byte(index & 0xFF)
285295return
286-if name_length is None:
296+if name_length == 0:
287297name_length = len(name.encode('utf-8'))
288298self.names[label] = start_size + name_length - len(label.encode('utf-8'))
289299self._write_utf(labels[count])
@@ -293,7 +303,8 @@ def write_name(self, name: str_) -> None:
293303294304def _write_question(self, question: DNSQuestion_) -> bool:
295305"""Writes a question to the packet"""
296-start_data_length, start_size = len(self.data), self.size
306+start_data_length = len(self.data)
307+start_size = self.size
297308self.write_name(question.name)
298309self.write_short(question.type)
299310self._write_record_class(question)
@@ -314,7 +325,8 @@ def _write_record(self, record: DNSRecord_, now: float_) -> bool:
314325"""Writes a record (answer, authoritative answer, additional) to
315326 the packet. Returns True on success, or False if we did not
316327 because the packet because the record does not fit."""
317-start_data_length, start_size = len(self.data), self.size
328+start_data_length = len(self.data)
329+start_size = self.size
318330self.write_name(record.name)
319331self.write_short(record.type)
320332self._write_record_class(record)
@@ -339,11 +351,13 @@ def _check_data_limit_or_rollback(self, start_data_length: int_, start_size: int
339351if self.size <= len_limit:
340352return True
341353342-log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
354+if LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG): # pragma: no branch
355+log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
343356del self.data[start_data_length:]
344357self.size = start_size
345358346-rollback_names = [name for name, idx in self.names.items() if idx >= start_size]
359+start_size_int = start_size
360+rollback_names = [name for name, idx in self.names.items() if idx >= start_size_int]
347361for name in rollback_names:
348362del self.names[name]
349363return False
@@ -395,7 +409,7 @@ def packets(self) -> List[bytes]:
395409return self._packets()
396410397411def _packets(self) -> List[bytes]:
398-if self.state == State.finished:
412+if self.state == STATE_FINISHED:
399413return self.packets_data
400414401415questions_offset = 0
@@ -404,7 +418,7 @@ def _packets(self) -> List[bytes]:
404418additional_offset = 0
405419# we have to at least write out the question
406420first_time = True
407-debug_enable = log.isEnabledFor(logging.DEBUG)
421+debug_enable = LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG)
408422409423while first_time or self._has_more_to_add(
410424questions_offset, answer_offset, authority_offset, additional_offset
@@ -476,5 +490,5 @@ def _packets(self) -> List[bytes]:
476490 ):
477491log.warning("packets() made no progress adding records; returning")
478492break
479-self.state = State.finished
493+self.state = STATE_FINISHED
480494return self.packets_data