fix: correct handling of IPv6 addresses with scope_id in ServiceInfo by bdraco · Pull Request #1322 · python-zeroconf/python-zeroconf
Expand Up
@@ -22,6 +22,7 @@
import asyncio import random import sys from functools import lru_cache from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast Expand Down Expand Up @@ -78,12 +79,15 @@ # the A/AAAA/SRV records for a host. _AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)
bytes_ = bytes float_ = float int_ = int
DNS_QUESTION_TYPE_QU = DNSQuestionType.QU DNS_QUESTION_TYPE_QM = DNSQuestionType.QM
IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)
if TYPE_CHECKING: from .._core import Zeroconf
Expand All @@ -110,6 +114,29 @@ def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4 _cached_ip_addresses_wrapper = _cached_ip_addresses
def _get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]: """Get the IP address object from the record.""" if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None: return _ip_bytes_and_scope_to_address(record.address, record.scope_id) return _cached_ip_addresses_wrapper(record.address)
def _ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]: """Convert the bytes and scope to an IP address object.""" base_address = _cached_ip_addresses_wrapper(address) if base_address is not None and base_address.is_link_local: return _cached_ip_addresses_wrapper(f"{base_address}%{scope}") return base_address
def _str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str: """Return the string representation of the address without the scope id.""" if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6: address_str = str(addr) return address_str.partition('%')[0] return str(addr)
class ServiceInfo(RecordUpdateListener): """Service information.
Expand Down Expand Up @@ -177,6 +204,7 @@ def __init__( raise TypeError("addresses and parsed_addresses cannot be provided together") if not type_.endswith(service_type_name(name, strict=False)): raise BadTypeInNameException self.interface_index = interface_index self.text = b'' self.type = type_ self._name = name Expand All @@ -199,7 +227,6 @@ def __init__( self._set_properties(properties) self.host_ttl = host_ttl self.other_ttl = other_ttl self.interface_index = interface_index self._new_records_futures: Optional[Set[asyncio.Future]] = None self._dns_address_cache: Optional[List[DNSAddress]] = None self._dns_pointer_cache: Optional[DNSPointer] = None Expand Down Expand Up @@ -243,7 +270,10 @@ def addresses(self, value: List[bytes]) -> None: self._get_address_and_nsec_records_cache = None
for address in value: addr = _cached_ip_addresses_wrapper(address) if IPADDRESS_SUPPORTS_SCOPE_ID and len(address) == 16 and self.interface_index is not None: addr = _ip_bytes_and_scope_to_address(address, self.interface_index) else: addr = _cached_ip_addresses_wrapper(address) if addr is None: raise TypeError( "Addresses must either be IPv4 or IPv6 strings, bytes, or integers;" Expand Down Expand Up @@ -322,10 +352,10 @@ def ip_addresses_by_version(
def _ip_addresses_by_version_value( self, version_value: int_ ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]: ) -> Union[List[IPv4Address], List[IPv6Address]]: """Backend for addresses_by_version that uses the raw value.""" if version_value == _IPVersion_All_value: return [*self._ipv4_addresses, *self._ipv6_addresses] return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value] if version_value == _IPVersion_V4Only_value: return self._ipv4_addresses return self._ipv6_addresses Expand All @@ -339,7 +369,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: This means the first address will always be the most recently added address of the given IP version. """ return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)] return [_str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)]
def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: """Equivalent to parsed_addresses, with the exception that IPv6 Link-Local Expand All @@ -351,12 +381,7 @@ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[st This means the first address will always be the most recently added address of the given IP version. """ if self.interface_index is None: return self.parsed_addresses(version) return [ f"{addr}%{self.interface_index}" if addr.version == 6 and addr.is_link_local else str(addr) for addr in self._ip_addresses_by_version_value(version.value) ] return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]
def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None: """Sets properties and text of this info from a dictionary""" Expand Down Expand Up @@ -421,8 +446,8 @@ def _get_ip_addresses_from_cache_lifo( for record in self._get_address_records_from_cache_by_type(zc, type): if record.is_expired(now): continue ip_addr = _cached_ip_addresses_wrapper(record.address) if ip_addr is not None: ip_addr = _get_ip_address_object_from_record(record) if ip_addr is not None and ip_addr not in address_list: address_list.append(ip_addr) address_list.reverse() # Reverse to get LIFO order return address_list Expand Down Expand Up @@ -471,7 +496,7 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo dns_address_record = record if TYPE_CHECKING: assert isinstance(dns_address_record, DNSAddress) ip_addr = _cached_ip_addresses_wrapper(dns_address_record.address) ip_addr = _get_ip_address_object_from_record(dns_address_record) if ip_addr is None: log.warning( "Encountered invalid address while processing %s: %s", Expand Down
import asyncio import random import sys from functools import lru_cache from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast Expand Down Expand Up @@ -78,12 +79,15 @@ # the A/AAAA/SRV records for a host. _AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)
bytes_ = bytes float_ = float int_ = int
DNS_QUESTION_TYPE_QU = DNSQuestionType.QU DNS_QUESTION_TYPE_QM = DNSQuestionType.QM
IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)
if TYPE_CHECKING: from .._core import Zeroconf
Expand All @@ -110,6 +114,29 @@ def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4 _cached_ip_addresses_wrapper = _cached_ip_addresses
def _get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]: """Get the IP address object from the record.""" if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None: return _ip_bytes_and_scope_to_address(record.address, record.scope_id) return _cached_ip_addresses_wrapper(record.address)
def _ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]: """Convert the bytes and scope to an IP address object.""" base_address = _cached_ip_addresses_wrapper(address) if base_address is not None and base_address.is_link_local: return _cached_ip_addresses_wrapper(f"{base_address}%{scope}") return base_address
def _str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str: """Return the string representation of the address without the scope id.""" if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6: address_str = str(addr) return address_str.partition('%')[0] return str(addr)
class ServiceInfo(RecordUpdateListener): """Service information.
Expand Down Expand Up @@ -177,6 +204,7 @@ def __init__( raise TypeError("addresses and parsed_addresses cannot be provided together") if not type_.endswith(service_type_name(name, strict=False)): raise BadTypeInNameException self.interface_index = interface_index self.text = b'' self.type = type_ self._name = name Expand All @@ -199,7 +227,6 @@ def __init__( self._set_properties(properties) self.host_ttl = host_ttl self.other_ttl = other_ttl self.interface_index = interface_index self._new_records_futures: Optional[Set[asyncio.Future]] = None self._dns_address_cache: Optional[List[DNSAddress]] = None self._dns_pointer_cache: Optional[DNSPointer] = None Expand Down Expand Up @@ -243,7 +270,10 @@ def addresses(self, value: List[bytes]) -> None: self._get_address_and_nsec_records_cache = None
for address in value: addr = _cached_ip_addresses_wrapper(address) if IPADDRESS_SUPPORTS_SCOPE_ID and len(address) == 16 and self.interface_index is not None: addr = _ip_bytes_and_scope_to_address(address, self.interface_index) else: addr = _cached_ip_addresses_wrapper(address) if addr is None: raise TypeError( "Addresses must either be IPv4 or IPv6 strings, bytes, or integers;" Expand Down Expand Up @@ -322,10 +352,10 @@ def ip_addresses_by_version(
def _ip_addresses_by_version_value( self, version_value: int_ ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]: ) -> Union[List[IPv4Address], List[IPv6Address]]: """Backend for addresses_by_version that uses the raw value.""" if version_value == _IPVersion_All_value: return [*self._ipv4_addresses, *self._ipv6_addresses] return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value] if version_value == _IPVersion_V4Only_value: return self._ipv4_addresses return self._ipv6_addresses Expand All @@ -339,7 +369,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: This means the first address will always be the most recently added address of the given IP version. """ return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)] return [_str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)]
def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: """Equivalent to parsed_addresses, with the exception that IPv6 Link-Local Expand All @@ -351,12 +381,7 @@ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[st This means the first address will always be the most recently added address of the given IP version. """ if self.interface_index is None: return self.parsed_addresses(version) return [ f"{addr}%{self.interface_index}" if addr.version == 6 and addr.is_link_local else str(addr) for addr in self._ip_addresses_by_version_value(version.value) ] return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]
def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None: """Sets properties and text of this info from a dictionary""" Expand Down Expand Up @@ -421,8 +446,8 @@ def _get_ip_addresses_from_cache_lifo( for record in self._get_address_records_from_cache_by_type(zc, type): if record.is_expired(now): continue ip_addr = _cached_ip_addresses_wrapper(record.address) if ip_addr is not None: ip_addr = _get_ip_address_object_from_record(record) if ip_addr is not None and ip_addr not in address_list: address_list.append(ip_addr) address_list.reverse() # Reverse to get LIFO order return address_list Expand Down Expand Up @@ -471,7 +496,7 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo dns_address_record = record if TYPE_CHECKING: assert isinstance(dns_address_record, DNSAddress) ip_addr = _cached_ip_addresses_wrapper(dns_address_record.address) ip_addr = _get_ip_address_object_from_record(dns_address_record) if ip_addr is None: log.warning( "Encountered invalid address while processing %s: %s", Expand Down