fix: addresses incorrect after server name change (#1154) · python-zeroconf/python-zeroconf@41ea06a

@@ -23,7 +23,7 @@

2323

import ipaddress

2424

import random

2525

from functools import lru_cache

26-

from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast

26+

from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast

27272828

from .._dns import (

2929

DNSAddress,

@@ -156,8 +156,8 @@ def __init__(

156156

self.port = port

157157

self.weight = weight

158158

self.priority = priority

159-

self.server = server if server else name

160-

self.server_key = self.server.lower()

159+

self.server = server if server else None

160+

self.server_key = server.lower() if server else None

161161

self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {}

162162

if isinstance(properties, bytes):

163163

self._set_text(properties)

@@ -205,7 +205,7 @@ def addresses(self, value: List[bytes]) -> None:

205205

"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"

206206

f" got {address!r}. Hint: convert string addresses with socket.inet_pton"

207207

)

208-

if isinstance(addr, ipaddress.IPv4Address):

208+

if addr.version == 4:

209209

self._ipv4_addresses.append(addr)

210210

else:

211211

self._ipv6_addresses.append(addr)

@@ -339,6 +339,35 @@ def get_name(self) -> str:

339339

"""Name accessor"""

340340

return self.name[: len(self.name) - len(self.type) - 1]

341341342+

def _get_ip_addresses_from_cache_lifo(

343+

self, zc: 'Zeroconf', now: float, type: int

344+

) -> List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]:

345+

"""Set IPv6 addresses from the cache."""

346+

address_list: List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]] = []

347+

for record in self._get_address_records_from_cache_by_type(zc, type):

348+

if record.is_expired(now):

349+

continue

350+

try:

351+

ip_address = _cached_ip_addresses(record.address)

352+

except ValueError:

353+

continue

354+

else:

355+

address_list.append(ip_address)

356+

address_list.reverse() # Reverse to get LIFO order

357+

return address_list

358+359+

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:

360+

"""Set IPv6 addresses from the cache."""

361+

self._ipv6_addresses = cast(

362+

"List[ipaddress.IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)

363+

)

364+365+

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:

366+

"""Set IPv4 addresses from the cache."""

367+

self._ipv4_addresses = cast(

368+

"List[ipaddress.IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)

369+

)

370+342371

def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:

343372

"""Updates service information from a DNS record.

344373

@@ -348,7 +377,7 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord])

348377

This method will be run in the event loop.

349378

"""

350379

if record is not None:

351-

self._process_records_threadsafe(zc, now, [RecordUpdate(record, None)])

380+

self._process_record_threadsafe(zc, record, now)

352381353382

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:

354383

"""Updates service information from a DNS record.

@@ -357,55 +386,77 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU

357386

"""

358387

self._process_records_threadsafe(zc, now, records)

359388360-

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:

361-

"""Thread safe record updating."""

362-

seen_addresses: Set[bytes] = set()

389+

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:

390+

"""Thread safe record updating.

391+392+

Returns True if new records were added.

393+

"""

394+

updated: bool = False

363395

for record_update in records:

364-

record = record_update.new

365-

if isinstance(record, DNSAddress):

366-

seen_addresses.add(record.address)

367-

self._process_record_threadsafe(record, now)

368-

for record in self._get_address_records_from_cache(zc):

369-

if record.address not in seen_addresses:

370-

self._process_record_threadsafe(record, now)

371-372-

def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None:

373-

"""Thread safe record updating."""

396+

updated |= self._process_record_threadsafe(zc, record_update.new, now)

397+

return updated

398+399+

def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool:

400+

"""Thread safe record updating.

401+402+

Returns True if a new record was added.

403+

"""

374404

if record.is_expired(now):

375-

return

405+

return False

376406377-

if isinstance(record, DNSAddress):

378-

if record.key != self.server_key:

379-

return

407+

if record.key == self.server_key and isinstance(record, DNSAddress):

380408

try:

381409

ip_addr = _cached_ip_addresses(record.address)

382410

except ValueError as ex:

383411

log.warning("Encountered invalid address while processing %s: %s", record, ex)

384-

return

385-

if isinstance(ip_addr, ipaddress.IPv4Address):

412+

return False

413+414+

if ip_addr.version == 4:

415+

if not self._ipv4_addresses:

416+

self._set_ipv4_addresses_from_cache(zc, now)

417+386418

if ip_addr not in self._ipv4_addresses:

387419

self._ipv4_addresses.insert(0, ip_addr)

388-

return

420+

return True

421+

elif ip_addr != self._ipv4_addresses[0]:

422+

self._ipv4_addresses.remove(ip_addr)

423+

self._ipv4_addresses.insert(0, ip_addr)

424+425+

return False

426+427+

if not self._ipv6_addresses:

428+

self._set_ipv6_addresses_from_cache(zc, now)

429+389430

if ip_addr not in self._ipv6_addresses:

390431

self._ipv6_addresses.insert(0, ip_addr)

391-

if ip_addr.is_link_local:

392-

self.interface_index = record.scope_id

393-

return

432+

return True

433+

elif ip_addr != self._ipv6_addresses[0]:

434+

self._ipv6_addresses.remove(ip_addr)

435+

self._ipv6_addresses.insert(0, ip_addr)

394436395-

if isinstance(record, DNSText):

396-

if record.key == self.key:

397-

self._set_text(record.text)

398-

return

437+

return False

438+439+

if record.key != self.key:

440+

return False

441+442+

if record.type == _TYPE_TXT and isinstance(record, DNSText):

443+

self._set_text(record.text)

444+

return True

399445400-

if isinstance(record, DNSService):

401-

if record.key != self.key:

402-

return

446+

if record.type == _TYPE_SRV and isinstance(record, DNSService):

447+

old_server_key = self.server_key

403448

self.name = record.name

404449

self.server = record.server

405450

self.server_key = record.server.lower()

406451

self.port = record.port

407452

self.weight = record.weight

408453

self.priority = record.priority

454+

if old_server_key != self.server_key:

455+

self._set_ipv4_addresses_from_cache(zc, now)

456+

self._set_ipv6_addresses_from_cache(zc, now)

457+

return True

458+459+

return False

409460410461

def dns_addresses(

411462

self,

@@ -416,7 +467,7 @@ def dns_addresses(

416467

"""Return matching DNSAddress from ServiceInfo."""

417468

return [

418469

DNSAddress(

419-

self.server,

470+

self.server or self.name,

420471

_TYPE_AAAA if address.version == 6 else _TYPE_A,

421472

_CLASS_IN | _CLASS_UNIQUE,

422473

override_ttl if override_ttl is not None else self.host_ttl,

@@ -447,7 +498,7 @@ def dns_service(self, override_ttl: Optional[int] = None, created: Optional[floa

447498

self.priority,

448499

self.weight,

449500

cast(int, self.port),

450-

self.server,

501+

self.server or self.name,

451502

created,

452503

)

453504

@@ -462,35 +513,43 @@ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float]

462513

created,

463514

)

464515465-

def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSAddress]:

466-

"""Get the address records from the cache."""

467-

return cast(

468-

"List[DNSAddress]",

469-

[

470-

*zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN),

471-

*zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN),

472-

],

473-

)

516+

def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:

517+

"""Get the addresses from the cache."""

518+

if self.server_key is None:

519+

return []

520+

return cast("List[DNSAddress]", zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN))

521+522+

def set_server_if_missing(self) -> None:

523+

"""Set the server if it is missing.

524+525+

This function is for backwards compatibility.

526+

"""

527+

if self.server is None:

528+

self.server = self.name

529+

self.server_key = self.server.lower()

474530475531

def load_from_cache(self, zc: 'Zeroconf') -> bool:

476532

"""Populate the service info from the cache.

477533478534

This method is designed to be threadsafe.

479535

"""

480536

now = current_time_millis()

481-

record_updates: List[RecordUpdate] = []

537+

original_server_key = self.server_key

482538

cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN)

483539

if cached_srv_record:

484-

# If there is a srv record, A and AAAA will already

485-

# be called and we do not want to do it twice

486-

record_updates.append(RecordUpdate(cached_srv_record, None))

487-

else:

488-

for record in self._get_address_records_from_cache(zc):

489-

record_updates.append(RecordUpdate(record, None))

540+

self._process_record_threadsafe(zc, cached_srv_record, now)

490541

cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN)

491542

if cached_txt_record:

492-

record_updates.append(RecordUpdate(cached_txt_record, None))

493-

self._process_records_threadsafe(zc, now, record_updates)

543+

self._process_record_threadsafe(zc, cached_txt_record, now)

544+

if original_server_key == self.server_key:

545+

# If there is a srv which changes the server_key,

546+

# A and AAAA will already be loaded from the cache

547+

# and we do not want to do it twice

548+

for record in [

549+

*self._get_address_records_from_cache_by_type(zc, _TYPE_A),

550+

*self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA),

551+

]:

552+

self._process_record_threadsafe(zc, record, now)

494553

return self._is_complete

495554496555

@property

@@ -560,8 +619,8 @@ def generate_request_query(

560619

out = DNSOutgoing(_FLAGS_QR_QUERY)

561620

out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN)

562621

out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN)

563-

out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN)

564-

out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN)

622+

out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_A, _CLASS_IN)

623+

out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_AAAA, _CLASS_IN)

565624

if question_type == DNSQuestionType.QU:

566625

for question in out.questions:

567626

question.unicast = True