feat: speed up decoding incoming packets by bdraco · Pull Request #1256 · python-zeroconf/python-zeroconf
Expand Up
@@ -89,6 +89,7 @@ class DNSIncoming:
'num_additionals',
'valid',
'now',
'_now_float',
'scope_id',
'source',
)
Expand Down
Expand Up
@@ -116,6 +117,7 @@ def __init__(
self.valid = False
self._did_read_others = False
self.now = now or current_time_millis()
self._now_float = self.now
self.source = source
self.scope_id = scope_id
try:
Expand Down
Expand Up
@@ -226,11 +228,13 @@ def _read_questions(self) -> None:
question = DNSQuestion(name, type_, class_)
self.questions.append(question)
def _read_character_string(self) -> bytes: def _read_character_string(self) -> str: """Reads a character string from the packet""" length = self.data[self.offset] self.offset += 1 return self._read_string(length) info = self.data[self.offset : self.offset + length].decode('utf-8', 'replace') self.offset += length return info
def _read_string(self, length: _int) -> bytes: """Reads a string of a given length from the packet""" Expand Down Expand Up @@ -273,7 +277,7 @@ def _read_record( """Read known records types and skip unknown ones.""" if type_ == _TYPE_A: dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4)) dns_address.created = self.now dns_address.created = self._now_float return dns_address if type_ in (_TYPE_CNAME, _TYPE_PTR): return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now) Expand All @@ -299,13 +303,13 @@ def _read_record( type_, class_, ttl, self._read_character_string().decode('utf-8', 'replace'), self._read_character_string().decode('utf-8', 'replace'), self._read_character_string(), self._read_character_string(), self.now, ) if type_ == _TYPE_AAAA: dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16)) dns_address.created = self.now dns_address.created = self._now_float dns_address.scope_id = self.scope_id return dns_address if type_ == _TYPE_NSEC: Expand Down Expand Up @@ -377,7 +381,7 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: # We have a DNS compression pointer link_data = self.data[off + 1] link = (length & 0x3F) * 256 + link_data lint_int = int(link) link_py_int = link if link > self._data_len: raise IncomingDecodeError( f"DNS compression pointer at {off} points to {link} beyond packet from {self.source}" Expand All @@ -386,16 +390,16 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: raise IncomingDecodeError( f"DNS compression pointer at {off} points to itself from {self.source}" ) if lint_int in seen_pointers: if link_py_int in seen_pointers: raise IncomingDecodeError( f"DNS compression pointer at {off} was seen again from {self.source}" ) linked_labels = self.name_cache.get(lint_int) linked_labels = self.name_cache.get(link_py_int) if not linked_labels: linked_labels = [] seen_pointers.add(lint_int) seen_pointers.add(link_py_int) self._decode_labels_at_offset(link, linked_labels, seen_pointers) self.name_cache[lint_int] = linked_labels self.name_cache[link_py_int] = linked_labels labels.extend(linked_labels) if len(labels) > MAX_DNS_LABELS: raise IncomingDecodeError( Expand Down
def _read_character_string(self) -> bytes: def _read_character_string(self) -> str: """Reads a character string from the packet""" length = self.data[self.offset] self.offset += 1 return self._read_string(length) info = self.data[self.offset : self.offset + length].decode('utf-8', 'replace') self.offset += length return info
def _read_string(self, length: _int) -> bytes: """Reads a string of a given length from the packet""" Expand Down Expand Up @@ -273,7 +277,7 @@ def _read_record( """Read known records types and skip unknown ones.""" if type_ == _TYPE_A: dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4)) dns_address.created = self.now dns_address.created = self._now_float return dns_address if type_ in (_TYPE_CNAME, _TYPE_PTR): return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now) Expand All @@ -299,13 +303,13 @@ def _read_record( type_, class_, ttl, self._read_character_string().decode('utf-8', 'replace'), self._read_character_string().decode('utf-8', 'replace'), self._read_character_string(), self._read_character_string(), self.now, ) if type_ == _TYPE_AAAA: dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16)) dns_address.created = self.now dns_address.created = self._now_float dns_address.scope_id = self.scope_id return dns_address if type_ == _TYPE_NSEC: Expand Down Expand Up @@ -377,7 +381,7 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: # We have a DNS compression pointer link_data = self.data[off + 1] link = (length & 0x3F) * 256 + link_data lint_int = int(link) link_py_int = link if link > self._data_len: raise IncomingDecodeError( f"DNS compression pointer at {off} points to {link} beyond packet from {self.source}" Expand All @@ -386,16 +390,16 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: raise IncomingDecodeError( f"DNS compression pointer at {off} points to itself from {self.source}" ) if lint_int in seen_pointers: if link_py_int in seen_pointers: raise IncomingDecodeError( f"DNS compression pointer at {off} was seen again from {self.source}" ) linked_labels = self.name_cache.get(lint_int) linked_labels = self.name_cache.get(link_py_int) if not linked_labels: linked_labels = [] seen_pointers.add(lint_int) seen_pointers.add(link_py_int) self._decode_labels_at_offset(link, linked_labels, seen_pointers) self.name_cache[lint_int] = linked_labels self.name_cache[link_py_int] = linked_labels labels.extend(linked_labels) if len(labels) > MAX_DNS_LABELS: raise IncomingDecodeError( Expand Down