feat: implement heapq for tracking cache expire times (#1465) · python-zeroconf/python-zeroconf@09db184

@@ -20,6 +20,7 @@

2020

USA

2121

"""

222223+

from heapq import heapify, heappop, heappush

2324

from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast

24252526

from ._dns import (

@@ -43,6 +44,11 @@

4344

_float = float

4445

_int = int

454647+

# The minimum number of scheduled record expirations before we start cleaning up

48+

# the expiration heap. This is a performance optimization to avoid cleaning up the

49+

# heap too often when there are only a few scheduled expirations.

50+

_MIN_SCHEDULED_RECORD_EXPIRATION = 100

51+46524753

def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:

4854

"""Remove a key from a DNSRecord cache

@@ -60,6 +66,8 @@ class DNSCache:

60666167

def __init__(self) -> None:

6268

self.cache: _DNSRecordCacheType = {}

69+

self._expire_heap: List[Tuple[float, DNSRecord]] = []

70+

self._expirations: Dict[DNSRecord, float] = {}

6371

self.service_cache: _DNSRecordCacheType = {}

64726573

# Functions prefixed with async_ are NOT threadsafe and must

@@ -81,6 +89,12 @@ def _async_add(self, record: _DNSRecord) -> bool:

8189

store = self.cache.setdefault(record.key, {})

8290

new = record not in store and not isinstance(record, DNSNsec)

8391

store[record] = record

92+

when = record.created + (record.ttl * 1000)

93+

if self._expirations.get(record) != when:

94+

# Avoid adding duplicates to the heap

95+

heappush(self._expire_heap, (when, record))

96+

self._expirations[record] = when

97+8498

if isinstance(record, DNSService):

8599

service_record = record

86100

self.service_cache.setdefault(record.server_key, {})[service_record] = service_record

@@ -108,6 +122,7 @@ def _async_remove(self, record: _DNSRecord) -> None:

108122

service_record = record

109123

_remove_key(self.service_cache, service_record.server_key, service_record)

110124

_remove_key(self.cache, record.key, record)

125+

self._expirations.pop(record, None)

111126112127

def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:

113128

"""Remove multiple records.

@@ -121,8 +136,44 @@ def async_expire(self, now: _float) -> List[DNSRecord]:

121136

"""Purge expired entries from the cache.

122137123138

This function must be run in from event loop.

139+140+

:param now: The current time in milliseconds.

124141

"""

125-

expired = [record for records in self.cache.values() for record in records if record.is_expired(now)]

142+

if not (expire_heap_len := len(self._expire_heap)):

143+

return []

144+145+

expired: List[DNSRecord] = []

146+

# Find any expired records and add them to the to-delete list

147+

while self._expire_heap:

148+

when, record = self._expire_heap[0]

149+

if when > now:

150+

break

151+

heappop(self._expire_heap)

152+

# Check if the record hasn't been re-added to the heap

153+

# with a different expiration time as it will be removed

154+

# later when it reaches the top of the heap and its

155+

# expiration time is met.

156+

if self._expirations.get(record) == when:

157+

expired.append(record)

158+159+

# If the expiration heap grows larger than the number expirations

160+

# times two, we clean it up to avoid keeping expired entries in

161+

# the heap and consuming memory. We guard this with a minimum

162+

# threshold to avoid cleaning up the heap too often when there are

163+

# only a few scheduled expirations.

164+

if (

165+

expire_heap_len > _MIN_SCHEDULED_RECORD_EXPIRATION

166+

and expire_heap_len > len(self._expirations) * 2

167+

):

168+

# Remove any expired entries from the expiration heap

169+

# that do not match the expiration time in the expirations

170+

# as it means the record has been re-added to the heap

171+

# with a different expiration time.

172+

self._expire_heap = [

173+

entry for entry in self._expire_heap if self._expirations.get(entry[1]) == entry[0]

174+

]

175+

heapify(self._expire_heap)

176+126177

self.async_remove_records(expired)

127178

return expired

128179

@@ -256,4 +307,11 @@ def async_mark_unique_records_older_than_1s_to_expire(

256307

created_double = record.created

257308

if (now - created_double > _ONE_SECOND) and record not in answers_rrset:

258309

# Expire in 1s

259-

record.set_created_ttl(now, 1)

310+

self._async_set_created_ttl(record, now, 1)

311+312+

def _async_set_created_ttl(self, record: DNSRecord, now: _float, ttl: _float) -> None:

313+

"""Set the created time and ttl of a record."""

314+

# It would be better if we made a copy instead of mutating the record

315+

# in place, but records currently don't have a copy method.

316+

record._set_created_ttl(now, ttl)

317+

self._async_add(record)