feat: optimize the dns cache (#1119) · python-zeroconf/python-zeroconf@e80fcef

@@ -32,22 +32,23 @@

3232

DNSRecord,

3333

DNSService,

3434

DNSText,

35-

dns_entry_matches,

3635

)

3736

from ._utils.time import current_time_millis

3837

from .const import _TYPE_PTR

39384039

_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)

4140

_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]

4241

_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]

42+

_DNSRecord = DNSRecord

43+

_str = str

4344444545-

def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None:

46+

def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:

4647

"""Remove a key from a DNSRecord cache

47484849

This function must be run in from event loop.

4950

"""

50-

del cache[key][entry]

51+

del cache[key][record]

5152

if not cache[key]:

5253

del cache[key]

5354

@@ -62,7 +63,7 @@ def __init__(self) -> None:

6263

# Functions prefixed with async_ are NOT threadsafe and must

6364

# be run in the event loop.

646565-

def _async_add(self, entry: DNSRecord) -> bool:

66+

def _async_add(self, record: _DNSRecord) -> bool:

6667

"""Adds an entry.

67686869

Returns true if the entry was not already in the cache.

@@ -75,11 +76,11 @@ def _async_add(self, entry: DNSRecord) -> bool:

7576

# replaces any existing records that are __eq__ to each other which

7677

# removes the risk that accessing the cache from the wrong

7778

# direction would return the old incorrect entry.

78-

store = self.cache.setdefault(entry.key, {})

79-

new = entry not in store and not isinstance(entry, DNSNsec)

80-

store[entry] = entry

81-

if isinstance(entry, DNSService):

82-

self.service_cache.setdefault(entry.server_key, {})[entry] = entry

79+

store = self.cache.setdefault(record.key, {})

80+

new = record not in store and not isinstance(record, DNSNsec)

81+

store[record] = record

82+

if isinstance(record, DNSService):

83+

self.service_cache.setdefault(record.server_key, {})[record] = record

8384

return new

84858586

def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:

@@ -95,14 +96,14 @@ def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:

9596

new = True

9697

return new

979898-

def _async_remove(self, entry: DNSRecord) -> None:

99+

def _async_remove(self, record: _DNSRecord) -> None:

99100

"""Removes an entry.

100101101102

This function must be run in from event loop.

102103

"""

103-

if isinstance(entry, DNSService):

104-

_remove_key(self.service_cache, entry.server_key, entry)

105-

_remove_key(self.cache, entry.key, entry)

104+

if isinstance(record, DNSService):

105+

_remove_key(self.service_cache, record.server_key, record)

106+

_remove_key(self.cache, record.key, record)

106107107108

def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:

108109

"""Remove multiple records.

@@ -128,7 +129,10 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:

128129

This function is not threadsafe and must be called from

129130

the event loop.

130131

"""

131-

return self.cache.get(entry.key, {}).get(entry)

132+

store = self.cache.get(entry.key)

133+

if store is None:

134+

return None

135+

return store.get(entry)

132136133137

def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]:

134138

"""Gets all matching entries by details.

@@ -138,7 +142,7 @@ def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[D

138142

"""

139143

key = name.lower()

140144

for entry in self.cache.get(key, []):

141-

if dns_entry_matches(entry, key, type_, class_):

145+

if _dns_record_matches(entry, key, type_, class_):

142146

yield entry

143147144148

def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:

@@ -185,15 +189,15 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco

185189

"""

186190

key = name.lower()

187191

for cached_entry in reversed(list(self.cache.get(key, []))):

188-

if dns_entry_matches(cached_entry, key, type_, class_):

192+

if _dns_record_matches(cached_entry, key, type_, class_):

189193

return cached_entry

190194

return None

191195192196

def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]:

193197

"""Gets all matching entries by details."""

194198

key = name.lower()

195199

return [

196-

entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_)

200+

entry for entry in list(self.cache.get(key, [])) if _dns_record_matches(entry, key, type_, class_)

197201

]

198202199203

def entries_with_server(self, server: str) -> List[DNSRecord]:

@@ -218,3 +222,7 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D

218222

def names(self) -> List[str]:

219223

"""Return a copy of the list of current cache names."""

220224

return list(self.cache)

225+226+227+

def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool:

228+

return key == record.key and type_ == record.type and class_ == record.class_