feat: speed up ServiceInfo with a cython pxd (#1264) · python-zeroconf/python-zeroconf@7ca690a

@@ -78,6 +78,12 @@

7878

# the A/AAAA/SRV records for a host.

7979

_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)

808081+

float_ = float

82+

int_ = int

83+84+

DNS_QUESTION_TYPE_QU = DNSQuestionType.QU

85+

DNS_QUESTION_TYPE_QM = DNSQuestionType.QM

86+8187

if TYPE_CHECKING:

8288

from .._core import Zeroconf

8389

@@ -281,10 +287,9 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]:

281287

"""

282288

version_value = version.value

283289

if version_value == _IPVersion_All_value:

284-

return [

285-

*(addr.packed for addr in self._ipv4_addresses),

286-

*(addr.packed for addr in self._ipv6_addresses),

287-

]

290+

ip_v4_packed = [addr.packed for addr in self._ipv4_addresses]

291+

ip_v6_packed = [addr.packed for addr in self._ipv6_addresses]

292+

return [*ip_v4_packed, *ip_v6_packed]

288293

if version_value == _IPVersion_V4Only_value:

289294

return [addr.packed for addr in self._ipv4_addresses]

290295

return [addr.packed for addr in self._ipv6_addresses]

@@ -303,7 +308,7 @@ def ip_addresses_by_version(

303308

return self._ip_addresses_by_version_value(version.value)

304309305310

def _ip_addresses_by_version_value(

306-

self, version_value: int

311+

self, version_value: int_

307312

) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:

308313

"""Backend for addresses_by_version that uses the raw value."""

309314

if version_value == _IPVersion_All_value:

@@ -397,7 +402,7 @@ def get_name(self) -> str:

397402

return self._name[: len(self._name) - len(self.type) - 1]

398403399404

def _get_ip_addresses_from_cache_lifo(

400-

self, zc: 'Zeroconf', now: float, type: int

405+

self, zc: 'Zeroconf', now: float_, type: int_

401406

) -> List[Union[IPv4Address, IPv6Address]]:

402407

"""Set IPv6 addresses from the cache."""

403408

address_list: List[Union[IPv4Address, IPv6Address]] = []

@@ -410,7 +415,7 @@ def _get_ip_addresses_from_cache_lifo(

410415

address_list.reverse() # Reverse to get LIFO order

411416

return address_list

412417413-

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:

418+

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None:

414419

"""Set IPv6 addresses from the cache."""

415420

if TYPE_CHECKING:

416421

self._ipv6_addresses = cast(

@@ -419,7 +424,7 @@ def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:

419424

else:

420425

self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)

421426422-

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:

427+

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None:

423428

"""Set IPv4 addresses from the cache."""

424429

if TYPE_CHECKING:

425430

self._ipv4_addresses = cast(

@@ -428,7 +433,7 @@ def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:

428433

else:

429434

self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)

430435431-

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:

436+

def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[RecordUpdate]) -> None:

432437

"""Updates service information from a DNS record.

433438434439

This method will be run in the event loop.

@@ -440,7 +445,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU

440445

if updated and new_records_futures:

441446

_resolve_all_futures_to_none(new_records_futures)

442447443-

def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool:

448+

def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float_) -> bool:

444449

"""Thread safe record updating.

445450446451

Returns True if a new record was added.

@@ -624,14 +629,15 @@ def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Se

624629

self._get_address_and_nsec_records_cache = records

625630

return records

626631627-

def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:

632+

def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int_) -> List[DNSAddress]:

628633

"""Get the addresses from the cache."""

629634

if self.server_key is None:

630635

return []

636+

cache = zc.cache

631637

if TYPE_CHECKING:

632-

records = cast("List[DNSAddress]", zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN))

638+

records = cast("List[DNSAddress]", cache.get_all_by_details(self.server_key, _type, _CLASS_IN))

633639

else:

634-

records = zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN)

640+

records = cache.get_all_by_details(self.server_key, _type, _CLASS_IN)

635641

return records

636642637643

def set_server_if_missing(self) -> None:

@@ -643,28 +649,33 @@ def set_server_if_missing(self) -> None:

643649

self.server = self._name

644650

self.server_key = self.key

645651646-

def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool:

652+

def load_from_cache(self, zc: 'Zeroconf', now: Optional[float_] = None) -> bool:

653+

"""Populate the service info from the cache.

654+655+

This method is designed to be threadsafe.

656+

"""

657+

return self._load_from_cache(zc, now or current_time_millis())

658+659+

def _load_from_cache(self, zc: 'Zeroconf', now: float_) -> bool:

647660

"""Populate the service info from the cache.

648661649662

This method is designed to be threadsafe.

650663

"""

651-

if not now:

652-

now = current_time_millis()

664+

cache = zc.cache

653665

original_server_key = self.server_key

654-

cached_srv_record = zc.cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN)

666+

cached_srv_record = cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN)

655667

if cached_srv_record:

656668

self._process_record_threadsafe(zc, cached_srv_record, now)

657-

cached_txt_record = zc.cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN)

669+

cached_txt_record = cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN)

658670

if cached_txt_record:

659671

self._process_record_threadsafe(zc, cached_txt_record, now)

660672

if original_server_key == self.server_key:

661673

# If there is a srv which changes the server_key,

662674

# A and AAAA will already be loaded from the cache

663675

# and we do not want to do it twice

664-

for record in [

665-

*self._get_address_records_from_cache_by_type(zc, _TYPE_A),

666-

*self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA),

667-

]:

676+

for record in self._get_address_records_from_cache_by_type(zc, _TYPE_A):

677+

self._process_record_threadsafe(zc, record, now)

678+

for record in self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA):

668679

self._process_record_threadsafe(zc, record, now)

669680

return self._is_complete

670681

@@ -720,7 +731,7 @@ async def async_request(

720731721732

now = current_time_millis()

722733723-

if self.load_from_cache(zc, now):

734+

if self._load_from_cache(zc, now):

724735

return True

725736726737

if TYPE_CHECKING:

@@ -737,11 +748,13 @@ async def async_request(

737748

return False

738749

if next_ <= now:

739750

out = self.generate_request_query(

740-

zc, now, question_type or DNSQuestionType.QU if first_request else DNSQuestionType.QM

751+

zc,

752+

now,

753+

question_type or DNS_QUESTION_TYPE_QU if first_request else DNS_QUESTION_TYPE_QM,

741754

)

742755

first_request = False

743756

if not out.questions:

744-

return self.load_from_cache(zc, now)

757+

return self._load_from_cache(zc, now)

745758

zc.async_send(out, addr, port)

746759

next_ = now + delay

747760

delay *= 2

@@ -755,7 +768,7 @@ async def async_request(

755768

return True

756769757770

def generate_request_query(

758-

self, zc: 'Zeroconf', now: float, question_type: Optional[DNSQuestionType] = None

771+

self, zc: 'Zeroconf', now: float_, question_type: Optional[DNSQuestionType] = None

759772

) -> DNSOutgoing:

760773

"""Generate the request query."""

761774

out = DNSOutgoing(_FLAGS_QR_QUERY)

@@ -766,7 +779,7 @@ def generate_request_query(

766779

out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN)

767780

out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN)

768781

out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN)

769-

if question_type == DNSQuestionType.QU:

782+

if question_type == DNS_QUESTION_TYPE_QU:

770783

for question in out.questions:

771784

question.unicast = True

772785

return out