fix: addresses incorrect after server name change (#1154) · python-zeroconf/python-zeroconf@41ea06a
@@ -23,7 +23,7 @@
2323import ipaddress
2424import random
2525from functools import lru_cache
26-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
26+from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
27272828from .._dns import (
2929DNSAddress,
@@ -156,8 +156,8 @@ def __init__(
156156self.port = port
157157self.weight = weight
158158self.priority = priority
159-self.server = server if server else name
160-self.server_key = self.server.lower()
159+self.server = server if server else None
160+self.server_key = server.lower() if server else None
161161self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {}
162162if isinstance(properties, bytes):
163163self._set_text(properties)
@@ -205,7 +205,7 @@ def addresses(self, value: List[bytes]) -> None:
205205"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"
206206f" got {address!r}. Hint: convert string addresses with socket.inet_pton"
207207 )
208-if isinstance(addr, ipaddress.IPv4Address):
208+if addr.version == 4:
209209self._ipv4_addresses.append(addr)
210210else:
211211self._ipv6_addresses.append(addr)
@@ -339,6 +339,35 @@ def get_name(self) -> str:
339339"""Name accessor"""
340340return self.name[: len(self.name) - len(self.type) - 1]
341341342+def _get_ip_addresses_from_cache_lifo(
343+self, zc: 'Zeroconf', now: float, type: int
344+ ) -> List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]:
345+"""Set IPv6 addresses from the cache."""
346+address_list: List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]] = []
347+for record in self._get_address_records_from_cache_by_type(zc, type):
348+if record.is_expired(now):
349+continue
350+try:
351+ip_address = _cached_ip_addresses(record.address)
352+except ValueError:
353+continue
354+else:
355+address_list.append(ip_address)
356+address_list.reverse() # Reverse to get LIFO order
357+return address_list
358+359+def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
360+"""Set IPv6 addresses from the cache."""
361+self._ipv6_addresses = cast(
362+"List[ipaddress.IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
363+ )
364+365+def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
366+"""Set IPv4 addresses from the cache."""
367+self._ipv4_addresses = cast(
368+"List[ipaddress.IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
369+ )
370+342371def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
343372"""Updates service information from a DNS record.
344373@@ -348,7 +377,7 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord])
348377 This method will be run in the event loop.
349378 """
350379if record is not None:
351-self._process_records_threadsafe(zc, now, [RecordUpdate(record, None)])
380+self._process_record_threadsafe(zc, record, now)
352381353382def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
354383"""Updates service information from a DNS record.
@@ -357,55 +386,77 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
357386 """
358387self._process_records_threadsafe(zc, now, records)
359388360-def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
361-"""Thread safe record updating."""
362-seen_addresses: Set[bytes] = set()
389+def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
390+"""Thread safe record updating.
391+392+ Returns True if new records were added.
393+ """
394+updated: bool = False
363395for record_update in records:
364-record = record_update.new
365-if isinstance(record, DNSAddress):
366-seen_addresses.add(record.address)
367-self._process_record_threadsafe(record, now)
368-for record in self._get_address_records_from_cache(zc):
369-if record.address not in seen_addresses:
370-self._process_record_threadsafe(record, now)
371-372-def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None:
373-"""Thread safe record updating."""
396+updated |= self._process_record_threadsafe(zc, record_update.new, now)
397+return updated
398+399+def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool:
400+"""Thread safe record updating.
401+402+ Returns True if a new record was added.
403+ """
374404if record.is_expired(now):
375-return
405+return False
376406377-if isinstance(record, DNSAddress):
378-if record.key != self.server_key:
379-return
407+if record.key == self.server_key and isinstance(record, DNSAddress):
380408try:
381409ip_addr = _cached_ip_addresses(record.address)
382410except ValueError as ex:
383411log.warning("Encountered invalid address while processing %s: %s", record, ex)
384-return
385-if isinstance(ip_addr, ipaddress.IPv4Address):
412+return False
413+414+if ip_addr.version == 4:
415+if not self._ipv4_addresses:
416+self._set_ipv4_addresses_from_cache(zc, now)
417+386418if ip_addr not in self._ipv4_addresses:
387419self._ipv4_addresses.insert(0, ip_addr)
388-return
420+return True
421+elif ip_addr != self._ipv4_addresses[0]:
422+self._ipv4_addresses.remove(ip_addr)
423+self._ipv4_addresses.insert(0, ip_addr)
424+425+return False
426+427+if not self._ipv6_addresses:
428+self._set_ipv6_addresses_from_cache(zc, now)
429+389430if ip_addr not in self._ipv6_addresses:
390431self._ipv6_addresses.insert(0, ip_addr)
391-if ip_addr.is_link_local:
392-self.interface_index = record.scope_id
393-return
432+return True
433+elif ip_addr != self._ipv6_addresses[0]:
434+self._ipv6_addresses.remove(ip_addr)
435+self._ipv6_addresses.insert(0, ip_addr)
394436395-if isinstance(record, DNSText):
396-if record.key == self.key:
397-self._set_text(record.text)
398-return
437+return False
438+439+if record.key != self.key:
440+return False
441+442+if record.type == _TYPE_TXT and isinstance(record, DNSText):
443+self._set_text(record.text)
444+return True
399445400-if isinstance(record, DNSService):
401-if record.key != self.key:
402-return
446+if record.type == _TYPE_SRV and isinstance(record, DNSService):
447+old_server_key = self.server_key
403448self.name = record.name
404449self.server = record.server
405450self.server_key = record.server.lower()
406451self.port = record.port
407452self.weight = record.weight
408453self.priority = record.priority
454+if old_server_key != self.server_key:
455+self._set_ipv4_addresses_from_cache(zc, now)
456+self._set_ipv6_addresses_from_cache(zc, now)
457+return True
458+459+return False
409460410461def dns_addresses(
411462self,
@@ -416,7 +467,7 @@ def dns_addresses(
416467"""Return matching DNSAddress from ServiceInfo."""
417468return [
418469DNSAddress(
419-self.server,
470+self.server or self.name,
420471_TYPE_AAAA if address.version == 6 else _TYPE_A,
421472_CLASS_IN | _CLASS_UNIQUE,
422473override_ttl if override_ttl is not None else self.host_ttl,
@@ -447,7 +498,7 @@ def dns_service(self, override_ttl: Optional[int] = None, created: Optional[floa
447498self.priority,
448499self.weight,
449500cast(int, self.port),
450-self.server,
501+self.server or self.name,
451502created,
452503 )
453504@@ -462,35 +513,43 @@ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float]
462513created,
463514 )
464515465-def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSAddress]:
466-"""Get the address records from the cache."""
467-return cast(
468-"List[DNSAddress]",
469- [
470-*zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN),
471-*zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN),
472- ],
473- )
516+def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:
517+"""Get the addresses from the cache."""
518+if self.server_key is None:
519+return []
520+return cast("List[DNSAddress]", zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN))
521+522+def set_server_if_missing(self) -> None:
523+"""Set the server if it is missing.
524+525+ This function is for backwards compatibility.
526+ """
527+if self.server is None:
528+self.server = self.name
529+self.server_key = self.server.lower()
474530475531def load_from_cache(self, zc: 'Zeroconf') -> bool:
476532"""Populate the service info from the cache.
477533478534 This method is designed to be threadsafe.
479535 """
480536now = current_time_millis()
481-record_updates: List[RecordUpdate] = []
537+original_server_key = self.server_key
482538cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN)
483539if cached_srv_record:
484-# If there is a srv record, A and AAAA will already
485-# be called and we do not want to do it twice
486-record_updates.append(RecordUpdate(cached_srv_record, None))
487-else:
488-for record in self._get_address_records_from_cache(zc):
489-record_updates.append(RecordUpdate(record, None))
540+self._process_record_threadsafe(zc, cached_srv_record, now)
490541cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN)
491542if cached_txt_record:
492-record_updates.append(RecordUpdate(cached_txt_record, None))
493-self._process_records_threadsafe(zc, now, record_updates)
543+self._process_record_threadsafe(zc, cached_txt_record, now)
544+if original_server_key == self.server_key:
545+# If there is a srv which changes the server_key,
546+# A and AAAA will already be loaded from the cache
547+# and we do not want to do it twice
548+for record in [
549+*self._get_address_records_from_cache_by_type(zc, _TYPE_A),
550+*self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA),
551+ ]:
552+self._process_record_threadsafe(zc, record, now)
494553return self._is_complete
495554496555@property
@@ -560,8 +619,8 @@ def generate_request_query(
560619out = DNSOutgoing(_FLAGS_QR_QUERY)
561620out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN)
562621out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN)
563-out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN)
564-out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN)
622+out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_A, _CLASS_IN)
623+out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_AAAA, _CLASS_IN)
565624if question_type == DNSQuestionType.QU:
566625for question in out.questions:
567626question.unicast = True