feat: speed up processing incoming records (#1179) · python-zeroconf/python-zeroconf@d919316

@@ -21,7 +21,7 @@

2121

"""

22222323

import itertools

24-

from typing import Dict, Iterable, Iterator, List, Optional, Union, cast

24+

from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast

25252626

from ._dns import (

2727

DNSAddress,

@@ -34,13 +34,15 @@

3434

DNSText,

3535

)

3636

from ._utils.time import current_time_millis

37-

from .const import _TYPE_PTR

37+

from .const import _ONE_SECOND, _TYPE_PTR

38383939

_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)

4040

_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]

4141

_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]

4242

_DNSRecord = DNSRecord

4343

_str = str

44+

_float = float

45+

_int = int

444645474648

def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:

@@ -134,19 +136,29 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:

134136

return None

135137

return store.get(entry)

136138137-

def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterator[DNSRecord]:

139+

def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterable[DNSRecord]:

138140

"""Gets all matching entries by details.

139141140-

This function is not threadsafe and must be called from

142+

This function is not thread-safe and must be called from

143+

the event loop.

144+

"""

145+

return self._async_all_by_details(name, type_, class_)

146+147+

def _async_all_by_details(self, name: _str, type_: int, class_: int) -> List[DNSRecord]:

148+

"""Gets all matching entries by details.

149+150+

This function is not thread-safe and must be called from

141151

the event loop.

142152

"""

143153

key = name.lower()

144154

records = self.cache.get(key)

155+

matches: List[DNSRecord] = []

145156

if records is None:

146-

return

147-

for entry in records:

148-

if _dns_record_matches(entry, key, type_, class_):

149-

yield entry

157+

return matches

158+

for record in records:

159+

if _dns_record_matches(record, key, type_, class_):

160+

matches.append(record)

161+

return matches

150162151163

def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:

152164

"""Returns a dict of entries whose key matches the name.

@@ -226,6 +238,25 @@ def names(self) -> List[str]:

226238

"""Return a copy of the list of current cache names."""

227239

return list(self.cache)

228240241+

def async_mark_unique_records_older_than_1s_to_expire(

242+

self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float

243+

) -> None:

244+

self._async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)

245+246+

def _async_mark_unique_records_older_than_1s_to_expire(

247+

self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float

248+

) -> None:

249+

# rfc6762#section-10.2 para 2

250+

# Since unique is set, all old records with that name, rrtype,

251+

# and rrclass that were received more than one second ago are declared

252+

# invalid, and marked to expire from the cache in one second.

253+

answers_rrset = set(answers)

254+

for name, type_, class_ in unique_types:

255+

for record in self._async_all_by_details(name, type_, class_):

256+

if (now - record.created > _ONE_SECOND) and record not in answers_rrset:

257+

# Expire in 1s

258+

record.set_created_ttl(now, 1)

259+229260230261

def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool:

231262

return key == record.key and type_ == record.type and class_ == record.class_