Fix handling of IPv6 addresses with scope_id in ServiceInfo · python-zeroconf/python-zeroconf@ae9b7c2
@@ -22,6 +22,7 @@
22222323import asyncio
2424import random
25+import sys
2526from functools import lru_cache
2627from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address
2728from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
7879# the A/AAAA/SRV records for a host.
7980_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)
808182+bytes_ = bytes
8183float_ = float
8284int_ = int
83858486DNS_QUESTION_TYPE_QU = DNSQuestionType.QU
8587DNS_QUESTION_TYPE_QM = DNSQuestionType.QM
868889+IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 8, 0)
90+8791if TYPE_CHECKING:
8892from .._core import Zeroconf
8993@@ -110,6 +114,29 @@ def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4
110114_cached_ip_addresses_wrapper = _cached_ip_addresses
111115112116117+def _get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]:
118+"""Get the IP address object from the record."""
119+if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None:
120+return _ip_bytes_and_scope_to_address(record.address, record.scope_id)
121+return _cached_ip_addresses_wrapper(record.address)
122+123+124+def _ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]:
125+"""Convert the bytes and scope to an IP address object."""
126+base_address = _cached_ip_addresses_wrapper(address)
127+if base_address is not None and base_address.is_link_local:
128+return _cached_ip_addresses_wrapper(f"{base_address}%{scope}")
129+return base_address
130+131+132+def _str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str:
133+"""Return the string representation of the address without the scope id."""
134+if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6:
135+address_str = str(addr)
136+return address_str.partition('%')[0]
137+return str(addr)
138+139+113140class ServiceInfo(RecordUpdateListener):
114141"""Service information.
115142@@ -177,6 +204,7 @@ def __init__(
177204raise TypeError("addresses and parsed_addresses cannot be provided together")
178205if not type_.endswith(service_type_name(name, strict=False)):
179206raise BadTypeInNameException
207+self.interface_index = interface_index
180208self.text = b''
181209self.type = type_
182210self._name = name
@@ -199,7 +227,6 @@ def __init__(
199227self._set_properties(properties)
200228self.host_ttl = host_ttl
201229self.other_ttl = other_ttl
202-self.interface_index = interface_index
203230self._new_records_futures: Optional[Set[asyncio.Future]] = None
204231self._dns_address_cache: Optional[List[DNSAddress]] = None
205232self._dns_pointer_cache: Optional[DNSPointer] = None
@@ -243,7 +270,10 @@ def addresses(self, value: List[bytes]) -> None:
243270self._get_address_and_nsec_records_cache = None
244271245272for address in value:
246-addr = _cached_ip_addresses_wrapper(address)
273+if IPADDRESS_SUPPORTS_SCOPE_ID and len(address) == 16 and self.interface_index is not None:
274+addr = _ip_bytes_and_scope_to_address(address, self.interface_index)
275+else:
276+addr = _cached_ip_addresses_wrapper(address)
247277if addr is None:
248278raise TypeError(
249279"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"
@@ -322,10 +352,10 @@ def ip_addresses_by_version(
322352323353def _ip_addresses_by_version_value(
324354self, version_value: int_
325- ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:
355+ ) -> Union[List[IPv4Address], List[IPv6Address]]:
326356"""Backend for addresses_by_version that uses the raw value."""
327357if version_value == _IPVersion_All_value:
328-return [*self._ipv4_addresses, *self._ipv6_addresses]
358+return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value]
329359if version_value == _IPVersion_V4Only_value:
330360return self._ipv4_addresses
331361return self._ipv6_addresses
@@ -339,7 +369,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
339369 This means the first address will always be the most recently added
340370 address of the given IP version.
341371 """
342-return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]
372+return [_str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)]
343373344374def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
345375"""Equivalent to parsed_addresses, with the exception that IPv6 Link-Local
@@ -351,12 +381,7 @@ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[st
351381 This means the first address will always be the most recently added
352382 address of the given IP version.
353383 """
354-if self.interface_index is None:
355-return self.parsed_addresses(version)
356-return [
357-f"{addr}%{self.interface_index}" if addr.version == 6 and addr.is_link_local else str(addr)
358-for addr in self._ip_addresses_by_version_value(version.value)
359- ]
384+return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]
360385361386def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None:
362387"""Sets properties and text of this info from a dictionary"""
@@ -421,8 +446,8 @@ def _get_ip_addresses_from_cache_lifo(
421446for record in self._get_address_records_from_cache_by_type(zc, type):
422447if record.is_expired(now):
423448continue
424-ip_addr = _cached_ip_addresses_wrapper(record.address)
425-if ip_addr is not None:
449+ip_addr = _get_ip_address_object_from_record(record)
450+if ip_addr is not None and ip_addr not in address_list:
426451address_list.append(ip_addr)
427452address_list.reverse() # Reverse to get LIFO order
428453return address_list
@@ -471,7 +496,7 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo
471496dns_address_record = record
472497if TYPE_CHECKING:
473498assert isinstance(dns_address_record, DNSAddress)
474-ip_addr = _cached_ip_addresses_wrapper(dns_address_record.address)
499+ip_addr = _get_ip_address_object_from_record(dns_address_record)
475500if ip_addr is None:
476501log.warning(
477502"Encountered invalid address while processing %s: %s",