feat: speed up ServiceInfo with a cython pxd by bdraco · Pull Request #1264 · python-zeroconf/python-zeroconf

Expand Up @@ -78,6 +78,12 @@ # the A/AAAA/SRV records for a host. _AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)
float_ = float int_ = int
DNS_QUESTION_TYPE_QU = DNSQuestionType.QU DNS_QUESTION_TYPE_QM = DNSQuestionType.QM
if TYPE_CHECKING: from .._core import Zeroconf
Expand Down Expand Up @@ -281,10 +287,9 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]: """ version_value = version.value if version_value == _IPVersion_All_value: return [ *(addr.packed for addr in self._ipv4_addresses), *(addr.packed for addr in self._ipv6_addresses), ] ip_v4_packed = [addr.packed for addr in self._ipv4_addresses] ip_v6_packed = [addr.packed for addr in self._ipv6_addresses] return [*ip_v4_packed, *ip_v6_packed] if version_value == _IPVersion_V4Only_value: return [addr.packed for addr in self._ipv4_addresses] return [addr.packed for addr in self._ipv6_addresses] Expand All @@ -303,7 +308,7 @@ def ip_addresses_by_version( return self._ip_addresses_by_version_value(version.value)
def _ip_addresses_by_version_value( self, version_value: int self, version_value: int_ ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]: """Backend for addresses_by_version that uses the raw value.""" if version_value == _IPVersion_All_value: Expand Down Expand Up @@ -397,7 +402,7 @@ def get_name(self) -> str: return self._name[: len(self._name) - len(self.type) - 1]
def _get_ip_addresses_from_cache_lifo( self, zc: 'Zeroconf', now: float, type: int self, zc: 'Zeroconf', now: float_, type: int_ ) -> List[Union[IPv4Address, IPv6Address]]: """Set IPv6 addresses from the cache.""" address_list: List[Union[IPv4Address, IPv6Address]] = [] Expand All @@ -410,7 +415,7 @@ def _get_ip_addresses_from_cache_lifo( address_list.reverse() # Reverse to get LIFO order return address_list
def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None: """Set IPv6 addresses from the cache.""" if TYPE_CHECKING: self._ipv6_addresses = cast( Expand All @@ -419,7 +424,7 @@ def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: else: self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None: """Set IPv4 addresses from the cache.""" if TYPE_CHECKING: self._ipv4_addresses = cast( Expand All @@ -428,7 +433,7 @@ def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: else: self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[RecordUpdate]) -> None: """Updates service information from a DNS record.
This method will be run in the event loop. Expand All @@ -440,7 +445,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU if updated and new_records_futures: _resolve_all_futures_to_none(new_records_futures)
def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool: def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float_) -> bool: """Thread safe record updating.
Returns True if a new record was added. Expand Down Expand Up @@ -624,14 +629,15 @@ def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Se self._get_address_and_nsec_records_cache = records return records
def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]: def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int_) -> List[DNSAddress]: """Get the addresses from the cache.""" if self.server_key is None: return [] cache = zc.cache if TYPE_CHECKING: records = cast("List[DNSAddress]", zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN)) records = cast("List[DNSAddress]", cache.get_all_by_details(self.server_key, _type, _CLASS_IN)) else: records = zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN) records = cache.get_all_by_details(self.server_key, _type, _CLASS_IN) return records
def set_server_if_missing(self) -> None: Expand All @@ -643,28 +649,33 @@ def set_server_if_missing(self) -> None: self.server = self._name self.server_key = self.key
def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool: def load_from_cache(self, zc: 'Zeroconf', now: Optional[float_] = None) -> bool: """Populate the service info from the cache.
This method is designed to be threadsafe. """ return self._load_from_cache(zc, now or current_time_millis())
def _load_from_cache(self, zc: 'Zeroconf', now: float_) -> bool: """Populate the service info from the cache.
This method is designed to be threadsafe. """ if not now: now = current_time_millis() cache = zc.cache original_server_key = self.server_key cached_srv_record = zc.cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN) cached_srv_record = cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN) if cached_srv_record: self._process_record_threadsafe(zc, cached_srv_record, now) cached_txt_record = zc.cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN) cached_txt_record = cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN) if cached_txt_record: self._process_record_threadsafe(zc, cached_txt_record, now) if original_server_key == self.server_key: # If there is a srv which changes the server_key, # A and AAAA will already be loaded from the cache # and we do not want to do it twice for record in [ *self._get_address_records_from_cache_by_type(zc, _TYPE_A), *self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA), ]: for record in self._get_address_records_from_cache_by_type(zc, _TYPE_A): self._process_record_threadsafe(zc, record, now) for record in self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA): self._process_record_threadsafe(zc, record, now) return self._is_complete
Expand Down Expand Up @@ -720,7 +731,7 @@ async def async_request(
now = current_time_millis()
if self.load_from_cache(zc, now): if self._load_from_cache(zc, now): return True
if TYPE_CHECKING: Expand All @@ -737,11 +748,13 @@ async def async_request( return False if next_ <= now: out = self.generate_request_query( zc, now, question_type or DNSQuestionType.QU if first_request else DNSQuestionType.QM zc, now, question_type or DNS_QUESTION_TYPE_QU if first_request else DNS_QUESTION_TYPE_QM, ) first_request = False if not out.questions: return self.load_from_cache(zc, now) return self._load_from_cache(zc, now) zc.async_send(out, addr, port) next_ = now + delay delay *= 2 Expand All @@ -755,7 +768,7 @@ async def async_request( return True
def generate_request_query( self, zc: 'Zeroconf', now: float, question_type: Optional[DNSQuestionType] = None self, zc: 'Zeroconf', now: float_, question_type: Optional[DNSQuestionType] = None ) -> DNSOutgoing: """Generate the request query.""" out = DNSOutgoing(_FLAGS_QR_QUERY) Expand All @@ -766,7 +779,7 @@ def generate_request_query( out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN) out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN) out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN) if question_type == DNSQuestionType.QU: if question_type == DNS_QUESTION_TYPE_QU: for question in out.questions: question.unicast = True return out Expand Down