feat: implement heapq for tracking cache expire times (#1465) · python-zeroconf/python-zeroconf@09db184
@@ -20,6 +20,7 @@
2020USA
2121"""
222223+from heapq import heapify, heappop, heappush
2324from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast
24252526from ._dns import (
@@ -43,6 +44,11 @@
4344_float = float
4445_int = int
454647+# The minimum number of scheduled record expirations before we start cleaning up
48+# the expiration heap. This is a performance optimization to avoid cleaning up the
49+# heap too often when there are only a few scheduled expirations.
50+_MIN_SCHEDULED_RECORD_EXPIRATION = 100
51+46524753def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:
4854"""Remove a key from a DNSRecord cache
@@ -60,6 +66,8 @@ class DNSCache:
60666167def __init__(self) -> None:
6268self.cache: _DNSRecordCacheType = {}
69+self._expire_heap: List[Tuple[float, DNSRecord]] = []
70+self._expirations: Dict[DNSRecord, float] = {}
6371self.service_cache: _DNSRecordCacheType = {}
64726573# Functions prefixed with async_ are NOT threadsafe and must
@@ -81,6 +89,12 @@ def _async_add(self, record: _DNSRecord) -> bool:
8189store = self.cache.setdefault(record.key, {})
8290new = record not in store and not isinstance(record, DNSNsec)
8391store[record] = record
92+when = record.created + (record.ttl * 1000)
93+if self._expirations.get(record) != when:
94+# Avoid adding duplicates to the heap
95+heappush(self._expire_heap, (when, record))
96+self._expirations[record] = when
97+8498if isinstance(record, DNSService):
8599service_record = record
86100self.service_cache.setdefault(record.server_key, {})[service_record] = service_record
@@ -108,6 +122,7 @@ def _async_remove(self, record: _DNSRecord) -> None:
108122service_record = record
109123_remove_key(self.service_cache, service_record.server_key, service_record)
110124_remove_key(self.cache, record.key, record)
125+self._expirations.pop(record, None)
111126112127def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
113128"""Remove multiple records.
def async_expire(self, now: _float) -> List[DNSRecord]:
    """Purge expired entries from the cache.

    This function must be run from the event loop.

    :param now: The current time in milliseconds.
    :return: The records that were removed, each listed once, in the
        order they came off the expiration heap.
    """
    if not (expire_heap_len := len(self._expire_heap)):
        return []

    expired: List[DNSRecord] = []
    # Records already claimed during this call. The same (when, record)
    # pair can be on the heap more than once (for example when a record is
    # removed and later re-added with an identical expiration time), and
    # without this guard it would be reported and removed twice.
    claimed: Set[DNSRecord] = set()
    # Pop every heap entry whose expiration time has been reached.
    while self._expire_heap:
        when, record = self._expire_heap[0]
        if when > now:
            break
        heappop(self._expire_heap)
        # Only entries that still match the scheduled expiration are live.
        # A mismatch means the record was removed or rescheduled; if it was
        # rescheduled, the newer entry will surface once its own time is
        # reached.
        if self._expirations.get(record) == when and record not in claimed:
            claimed.add(record)
            expired.append(record)

    # If the expiration heap has grown to more than twice the number of
    # scheduled expirations, rebuild it without the stale entries so they
    # do not keep consuming memory. Both sizes are the values on entry to
    # this call (before the pops above and the removals below). The
    # minimum threshold avoids rebuilding too often when only a few
    # expirations are scheduled.
    if (
        expire_heap_len > _MIN_SCHEDULED_RECORD_EXPIRATION
        and expire_heap_len > len(self._expirations) * 2
    ):
        # Keep only entries whose time matches the scheduled expiration
        # for their record, then restore the heap invariant.
        self._expire_heap = [
            entry for entry in self._expire_heap if self._expirations.get(entry[1]) == entry[0]
        ]
        heapify(self._expire_heap)

    self.async_remove_records(expired)
    return expired
128179@@ -256,4 +307,11 @@ def async_mark_unique_records_older_than_1s_to_expire(
256307created_double = record.created
257308if (now - created_double > _ONE_SECOND) and record not in answers_rrset:
258309# Expire in 1s
259-record.set_created_ttl(now, 1)
310+self._async_set_created_ttl(record, now, 1)
311+312+def _async_set_created_ttl(self, record: DNSRecord, now: _float, ttl: _float) -> None:
313+"""Set the created time and ttl of a record."""
314+# It would be better if we made a copy instead of mutating the record
315+# in place, but records currently don't have a copy method.
316+record._set_created_ttl(now, ttl)
317+self._async_add(record)