feat: speed up the query handler (#1350) · python-zeroconf/python-zeroconf@9eac0a1

@@ -20,19 +20,19 @@

2020

USA

2121

"""

222223-

from typing import TYPE_CHECKING, List, Optional, Set, cast

23+

from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast

24242525

from .._cache import DNSCache, _UniqueRecordsType

2626

from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet

27-

from .._history import QuestionHistory

2827

from .._protocol.incoming import DNSIncoming

2928

from .._services.info import ServiceInfo

30-

from .._services.registry import ServiceRegistry

29+

from .._transport import _WrappedTransport

3130

from .._utils.net import IPVersion

3231

from ..const import (

3332

_ADDRESS_RECORD_TYPES,

3433

_CLASS_IN,

3534

_DNS_OTHER_TTL,

35+

_MDNS_PORT,

3636

_ONE_SECOND,

3737

_SERVICE_TYPE_ENUMERATION_NAME,

3838

_TYPE_A,

@@ -43,7 +43,12 @@

4343

_TYPE_SRV,

4444

_TYPE_TXT,

4545

)

46-

from .answers import QuestionAnswers, _AnswerWithAdditionalsType

46+

from .answers import (

47+

QuestionAnswers,

48+

_AnswerWithAdditionalsType,

49+

construct_outgoing_multicast_answers,

50+

construct_outgoing_unicast_answers,

51+

)

47524853

_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES}

4954

@@ -53,14 +58,17 @@

5358

_IPVersion_ALL = IPVersion.All

54595560

_int = int

56-61+

_str = str

57625863

_ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION = 0

5964

_ANSWER_STRATEGY_POINTER = 1

6065

_ANSWER_STRATEGY_ADDRESS = 2

6166

_ANSWER_STRATEGY_SERVICE = 3

6267

_ANSWER_STRATEGY_TEXT = 4

636869+

if TYPE_CHECKING:

70+

from .._core import Zeroconf

71+64726573

class _AnswerStrategy:

6674

@@ -183,13 +191,14 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:

183191

class QueryHandler:

184192

"""Query the ServiceRegistry."""

185193186-

__slots__ = ("registry", "cache", "question_history")

194+

__slots__ = ("zc", "registry", "cache", "question_history")

187195188-

def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None:

196+

def __init__(self, zc: 'Zeroconf') -> None:

189197

"""Init the query handler."""

190-

self.registry = registry

191-

self.cache = cache

192-

self.question_history = question_history

198+

self.zc = zc

199+

self.registry = zc.registry

200+

self.cache = zc.cache

201+

self.question_history = zc.question_history

193202194203

def _add_service_type_enumeration_query_answers(

195204

self, types: List[str], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet

@@ -385,3 +394,45 @@ def _get_answer_strategies(

385394

)

386395387396

return strategies

397+398+

def handle_assembled_query(

399+

self,

400+

packets: List[DNSIncoming],

401+

addr: _str,

402+

port: _int,

403+

transport: _WrappedTransport,

404+

v6_flow_scope: Union[Tuple[()], Tuple[int, int]],

405+

) -> None:

406+

"""Respond to a (re)assembled query.

407+408+

If the protocol recieved packets with the TC bit set, it will

409+

wait a bit for the rest of the packets and only call

410+

handle_assembled_query once it has a complete set of packets

411+

or the timer expires. If the TC bit is not set, a single

412+

packet will be in packets.

413+

"""

414+

first_packet = packets[0]

415+

now = first_packet.now

416+

ucast_source = port != _MDNS_PORT

417+

question_answers = self.async_response(packets, ucast_source)

418+

if not question_answers:

419+

return

420+

if question_answers.ucast:

421+

questions = first_packet.questions

422+

id_ = first_packet.id

423+

out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)

424+

# When sending unicast, only send back the reply

425+

# via the same socket that it was recieved from

426+

# as we know its reachable from that socket

427+

self.zc.async_send(out, addr, port, v6_flow_scope, transport)

428+

if question_answers.mcast_now:

429+

self.zc.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))

430+

if question_answers.mcast_aggregate:

431+

out_queue = self.zc.out_queue

432+

out_queue.async_add(now, question_answers.mcast_aggregate)

433+

if question_answers.mcast_aggregate_last_second:

434+

# https://datatracker.ietf.org/doc/html/rfc6762#section-14

435+

# If we broadcast it in the last second, we have to delay

436+

# at least a second before we send it again

437+

out_delay_queue = self.zc.out_delay_queue

438+

out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second)