feat: optimize the dns cache (#1119) · python-zeroconf/python-zeroconf@e80fcef
@@ -32,22 +32,23 @@
3232DNSRecord,
3333DNSService,
3434DNSText,
35-dns_entry_matches,
3635)
3736from ._utils.time import current_time_millis
3837from .const import _TYPE_PTR
39384039_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
4140_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]
4241_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]
42+_DNSRecord = DNSRecord
43+_str = str
4344444545-def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None:
46+def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:
4647"""Remove a key from a DNSRecord cache
47484849 This function must be run in from event loop.
4950 """
50-del cache[key][entry]
51+del cache[key][record]
5152if not cache[key]:
5253del cache[key]
5354@@ -62,7 +63,7 @@ def __init__(self) -> None:
6263# Functions prefixed with async_ are NOT threadsafe and must
6364# be run in the event loop.
646565-def _async_add(self, entry: DNSRecord) -> bool:
66+def _async_add(self, record: _DNSRecord) -> bool:
6667"""Adds an entry.
67686869 Returns true if the entry was not already in the cache.
@@ -75,11 +76,11 @@ def _async_add(self, entry: DNSRecord) -> bool:
7576# replaces any existing records that are __eq__ to each other which
7677# removes the risk that accessing the cache from the wrong
7778# direction would return the old incorrect entry.
78-store = self.cache.setdefault(entry.key, {})
79-new = entry not in store and not isinstance(entry, DNSNsec)
80-store[entry] = entry
81-if isinstance(entry, DNSService):
82-self.service_cache.setdefault(entry.server_key, {})[entry] = entry
79+store = self.cache.setdefault(record.key, {})
80+new = record not in store and not isinstance(record, DNSNsec)
81+store[record] = record
82+if isinstance(record, DNSService):
83+self.service_cache.setdefault(record.server_key, {})[record] = record
8384return new
84858586def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:
@@ -95,14 +96,14 @@ def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:
9596new = True
9697return new
979898-def _async_remove(self, entry: DNSRecord) -> None:
99+def _async_remove(self, record: _DNSRecord) -> None:
99100"""Removes an entry.
100101101102 This function must be run in from event loop.
102103 """
103-if isinstance(entry, DNSService):
104-_remove_key(self.service_cache, entry.server_key, entry)
105-_remove_key(self.cache, entry.key, entry)
104+if isinstance(record, DNSService):
105+_remove_key(self.service_cache, record.server_key, record)
106+_remove_key(self.cache, record.key, record)
106107107108def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
108109"""Remove multiple records.
@@ -128,7 +129,10 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
128129 This function is not threadsafe and must be called from
129130 the event loop.
130131 """
131-return self.cache.get(entry.key, {}).get(entry)
132+store = self.cache.get(entry.key)
133+if store is None:
134+return None
135+return store.get(entry)
132136133137def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]:
134138"""Gets all matching entries by details.
@@ -138,7 +142,7 @@ def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[D
138142 """
139143key = name.lower()
140144for entry in self.cache.get(key, []):
141-if dns_entry_matches(entry, key, type_, class_):
145+if _dns_record_matches(entry, key, type_, class_):
142146yield entry
143147144148def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
@@ -185,15 +189,15 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco
185189 """
186190key = name.lower()
187191for cached_entry in reversed(list(self.cache.get(key, []))):
188-if dns_entry_matches(cached_entry, key, type_, class_):
192+if _dns_record_matches(cached_entry, key, type_, class_):
189193return cached_entry
190194return None
191195192196def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]:
193197"""Gets all matching entries by details."""
194198key = name.lower()
195199return [
196-entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_)
200+entry for entry in list(self.cache.get(key, [])) if _dns_record_matches(entry, key, type_, class_)
197201 ]
198202199203def entries_with_server(self, server: str) -> List[DNSRecord]:
@@ -218,3 +222,7 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D
218222def names(self) -> List[str]:
219223"""Return a copy of the list of current cache names."""
220224return list(self.cache)
225+226+227+def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool:
228+return key == record.key and type_ == record.type and class_ == record.class_