feat: speed up processing incoming records (#1179) · python-zeroconf/python-zeroconf@d919316
@@ -21,7 +21,7 @@
2121"""
22222323import itertools
24-from typing import Dict, Iterable, Iterator, List, Optional, Union, cast
24+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast
25252626from ._dns import (
2727DNSAddress,
@@ -34,13 +34,15 @@
3434DNSText,
3535)
3636from ._utils.time import current_time_millis
37-from .const import _TYPE_PTR
37+from .const import _ONE_SECOND, _TYPE_PTR
38383939_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
4040_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]
4141_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]
4242_DNSRecord = DNSRecord
4343_str = str
44+_float = float
45+_int = int
444645474648def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:
@@ -134,19 +136,29 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
134136return None
135137return store.get(entry)
136138137-def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterator[DNSRecord]:
139+def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterable[DNSRecord]:
138140"""Gets all matching entries by details.
139141140- This function is not threadsafe and must be called from
142+ This function is not thread-safe and must be called from
143+ the event loop.
144+ """
145+return self._async_all_by_details(name, type_, class_)
146+147+def _async_all_by_details(self, name: _str, type_: int, class_: int) -> List[DNSRecord]:
148+"""Gets all matching entries by details.
149+150+ This function is not thread-safe and must be called from
141151 the event loop.
142152 """
143153key = name.lower()
144154records = self.cache.get(key)
155+matches: List[DNSRecord] = []
145156if records is None:
146-return
147-for entry in records:
148-if _dns_record_matches(entry, key, type_, class_):
149-yield entry
157+return matches
158+for record in records:
159+if _dns_record_matches(record, key, type_, class_):
160+matches.append(record)
161+return matches
150162151163def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
152164"""Returns a dict of entries whose key matches the name.
@@ -226,6 +238,25 @@ def names(self) -> List[str]:
226238"""Return a copy of the list of current cache names."""
227239return list(self.cache)
228240241+def async_mark_unique_records_older_than_1s_to_expire(
242+self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
243+ ) -> None:
244+self._async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)
245+246+def _async_mark_unique_records_older_than_1s_to_expire(
247+self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
248+ ) -> None:
249+# rfc6762#section-10.2 para 2
250+# Since unique is set, all old records with that name, rrtype,
251+# and rrclass that were received more than one second ago are declared
252+# invalid, and marked to expire from the cache in one second.
253+answers_rrset = set(answers)
254+for name, type_, class_ in unique_types:
255+for record in self._async_all_by_details(name, type_, class_):
256+if (now - record.created > _ONE_SECOND) and record not in answers_rrset:
257+# Expire in 1s
258+record.set_created_ttl(now, 1)
259+229260230261def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool:
231262return key == record.key and type_ == record.type and class_ == record.class_