Fix handling of IPv6 addresses with scope_id in ServiceInfo · python-zeroconf/python-zeroconf@ae9b7c2

@@ -22,6 +22,7 @@

22222323

import asyncio

2424

import random

25+

import sys

2526

from functools import lru_cache

2627

from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address

2728

from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast

7879

# the A/AAAA/SRV records for a host.

7980

_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)

808182+

bytes_ = bytes

8183

float_ = float

8284

int_ = int

83858486

DNS_QUESTION_TYPE_QU = DNSQuestionType.QU

8587

DNS_QUESTION_TYPE_QM = DNSQuestionType.QM

868889+

IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 8, 0)

90+8791

if TYPE_CHECKING:

8892

from .._core import Zeroconf

8993

@@ -110,6 +114,29 @@ def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4

110114

_cached_ip_addresses_wrapper = _cached_ip_addresses

111115112116117+

def _get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]:

118+

"""Get the IP address object from the record."""

119+

if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None:

120+

return _ip_bytes_and_scope_to_address(record.address, record.scope_id)

121+

return _cached_ip_addresses_wrapper(record.address)

122+123+124+

def _ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]:

125+

"""Convert the bytes and scope to an IP address object."""

126+

base_address = _cached_ip_addresses_wrapper(address)

127+

if base_address is not None and base_address.is_link_local:

128+

return _cached_ip_addresses_wrapper(f"{base_address}%{scope}")

129+

return base_address

130+131+132+

def _str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str:

133+

"""Return the string representation of the address without the scope id."""

134+

if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6:

135+

address_str = str(addr)

136+

return address_str.partition('%')[0]

137+

return str(addr)

138+139+113140

class ServiceInfo(RecordUpdateListener):

114141

"""Service information.

115142

@@ -177,6 +204,7 @@ def __init__(

177204

raise TypeError("addresses and parsed_addresses cannot be provided together")

178205

if not type_.endswith(service_type_name(name, strict=False)):

179206

raise BadTypeInNameException

207+

self.interface_index = interface_index

180208

self.text = b''

181209

self.type = type_

182210

self._name = name

@@ -199,7 +227,6 @@ def __init__(

199227

self._set_properties(properties)

200228

self.host_ttl = host_ttl

201229

self.other_ttl = other_ttl

202-

self.interface_index = interface_index

203230

self._new_records_futures: Optional[Set[asyncio.Future]] = None

204231

self._dns_address_cache: Optional[List[DNSAddress]] = None

205232

self._dns_pointer_cache: Optional[DNSPointer] = None

@@ -243,7 +270,10 @@ def addresses(self, value: List[bytes]) -> None:

243270

self._get_address_and_nsec_records_cache = None

244271245272

for address in value:

246-

addr = _cached_ip_addresses_wrapper(address)

273+

if IPADDRESS_SUPPORTS_SCOPE_ID and len(address) == 16 and self.interface_index is not None:

274+

addr = _ip_bytes_and_scope_to_address(address, self.interface_index)

275+

else:

276+

addr = _cached_ip_addresses_wrapper(address)

247277

if addr is None:

248278

raise TypeError(

249279

"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"

@@ -322,10 +352,10 @@ def ip_addresses_by_version(

322352323353

def _ip_addresses_by_version_value(

324354

self, version_value: int_

325-

) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:

355+

) -> Union[List[IPv4Address], List[IPv6Address]]:

326356

"""Backend for addresses_by_version that uses the raw value."""

327357

if version_value == _IPVersion_All_value:

328-

return [*self._ipv4_addresses, *self._ipv6_addresses]

358+

return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value]

329359

if version_value == _IPVersion_V4Only_value:

330360

return self._ipv4_addresses

331361

return self._ipv6_addresses

@@ -339,7 +369,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:

339369

This means the first address will always be the most recently added

340370

address of the given IP version.

341371

"""

342-

return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]

372+

return [_str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)]

343373344374

def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:

345375

"""Equivalent to parsed_addresses, with the exception that IPv6 Link-Local

@@ -351,12 +381,7 @@ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[st

351381

This means the first address will always be the most recently added

352382

address of the given IP version.

353383

"""

354-

if self.interface_index is None:

355-

return self.parsed_addresses(version)

356-

return [

357-

f"{addr}%{self.interface_index}" if addr.version == 6 and addr.is_link_local else str(addr)

358-

for addr in self._ip_addresses_by_version_value(version.value)

359-

]

384+

return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]

360385361386

def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None:

362387

"""Sets properties and text of this info from a dictionary"""

@@ -421,8 +446,8 @@ def _get_ip_addresses_from_cache_lifo(

421446

for record in self._get_address_records_from_cache_by_type(zc, type):

422447

if record.is_expired(now):

423448

continue

424-

ip_addr = _cached_ip_addresses_wrapper(record.address)

425-

if ip_addr is not None:

449+

ip_addr = _get_ip_address_object_from_record(record)

450+

if ip_addr is not None and ip_addr not in address_list:

426451

address_list.append(ip_addr)

427452

address_list.reverse() # Reverse to get LIFO order

428453

return address_list

@@ -471,7 +496,7 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo

471496

dns_address_record = record

472497

if TYPE_CHECKING:

473498

assert isinstance(dns_address_record, DNSAddress)

474-

ip_addr = _cached_ip_addresses_wrapper(dns_address_record.address)

499+

ip_addr = _get_ip_address_object_from_record(dns_address_record)

475500

if ip_addr is None:

476501

log.warning(

477502

"Encountered invalid address while processing %s: %s",