1-
21import asyncio
32import functools
4- import pycares
53import socket
64import sys
5+ from collections .abc import Iterable , Sequence
6+ from typing import Any , Literal , Optional , TypeVar , Union , overload
77
8+ import pycares
89from typing import (
910 Any ,
1011 Callable ,
2223
2324__all__ = ('DNSResolver' , 'error' )
2425
26+ _T = TypeVar ("_T" )
27+
2528WINDOWS_SELECTOR_ERR_MSG = (
2629 "aiodns needs a SelectorEventLoop on Windows. See more: "
2730 "https://github.com/aio-libs/aiodns#note-for-windows-users"
5558class DNSResolver :
5659 def __init__ (self , nameservers : Optional [Sequence [str ]] = None ,
5760 loop : Optional [asyncio .AbstractEventLoop ] = None ,
58- ** kwargs : Any ) -> None :
61+ ** kwargs : Any ) -> None : # TODO(PY311): Use Unpack for kwargs.
5962 self .loop = loop or asyncio .get_event_loop ()
6063 assert self .loop is not None
6164 kwargs .pop ('sock_state_cb' , None )
@@ -80,31 +83,33 @@ def __init__(self, nameservers: Optional[Sequence[str]] = None,
8083 ** kwargs )
8184 if nameservers :
8285 self .nameservers = nameservers
83- self ._read_fds = set () # type: Set[int]
84- self ._write_fds = set () # type: Set[int]
85- self ._timer = None # type : Optional[asyncio.TimerHandle]
86+ self ._read_fds : set [ int ] = set ()
87+ self ._write_fds : set [ int ] = set ()
88+ self ._timer : Optional [asyncio .TimerHandle ] = None
8689
8790 @property
8891 def nameservers (self ) -> Sequence [str ]:
8992 return self ._channel .servers
9093
9194 @nameservers .setter
92- def nameservers (self , value : Sequence [str ]) -> None :
93- self ._channel .servers = value if isinstance (value , list ) else list (value )
95+ def nameservers (self , value : Iterable [Union [str , bytes ]]) -> None :
96+ # Remove type ignore after mypy 1.16.0
97+ # https://github.com/python/mypy/issues/12892
98+ self ._channel .servers = value # type: ignore[assignment]
9499
95100 @staticmethod
96- def _callback (fut : asyncio .Future , result : Any , errorno : int ) -> None :
101+ def _callback (fut : asyncio .Future [ _T ] , result : _T , errorno : Optional [ int ] ) -> None :
97102 if fut .cancelled ():
98103 return
99104 if errorno is not None :
100105 fut .set_exception (error .DNSError (errorno , pycares .errno .strerror (errorno )))
101106 else :
102107 fut .set_result (result )
103108
104- def _get_future_callback (self ) -> Tuple ["asyncio.Future[Any ]" , Callable [[Any , int ], None ]]:
109+ def _get_future_callback (self ) -> Tuple ["asyncio.Future[_T ]" , Callable [[_T , int ], None ]]:
105110 """Return a future and a callback to set the result of the future."""
106- cb : Callable [[Any , int ], None ]
107- future : "asyncio.Future[Any ]" = self .loop .create_future ()
111+ cb : Callable [[_T , int ], None ]
112+ future : "asyncio.Future[_T ]" = self .loop .create_future ()
108113 if self ._event_thread :
109114 cb = functools .partial ( # type: ignore[assignment]
110115 self .loop .call_soon_threadsafe ,
@@ -115,7 +120,41 @@ def _get_future_callback(self) -> Tuple["asyncio.Future[Any]", Callable[[Any, in
115120 cb = functools .partial (self ._callback , future )
116121 return future , cb
117122
118- def query (self , host : str , qtype : str , qclass : Optional [str ]= None ) -> asyncio .Future :
123+ @overload
124+ def query (self , host : str , qtype : Literal ["A" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_a_result ]]:
125+ ...
126+ @overload
127+ def query (self , host : str , qtype : Literal ["AAAA" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_aaaa_result ]]:
128+ ...
129+ @overload
130+ def query (self , host : str , qtype : Literal ["CAA" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_caa_result ]]:
131+ ...
132+ @overload
133+ def query (self , host : str , qtype : Literal ["CNAME" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_cname_result ]]:
134+ ...
135+ @overload
136+ def query (self , host : str , qtype : Literal ["MX" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_mx_result ]]:
137+ ...
138+ @overload
139+ def query (self , host : str , qtype : Literal ["NAPTR" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_naptr_result ]]:
140+ ...
141+ @overload
142+ def query (self , host : str , qtype : Literal ["NS" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_ns_result ]]:
143+ ...
144+ @overload
145+ def query (self , host : str , qtype : Literal ["PTR" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_ptr_result ]]:
146+ ...
147+ @overload
148+ def query (self , host : str , qtype : Literal ["SOA" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_soa_result ]]:
149+ ...
150+ @overload
151+ def query (self , host : str , qtype : Literal ["SRV" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_srv_result ]]:
152+ ...
153+ @overload
154+ def query (self , host : str , qtype : Literal ["TXT" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_txt_result ]]:
155+ ...
156+
157+ def query (self , host : str , qtype : str , qclass : Optional [str ]= None ) -> asyncio .Future [list [Any ]]:
119158 try :
120159 qtype = query_type_map [qtype ]
121160 except KeyError :
@@ -126,30 +165,35 @@ def query(self, host: str, qtype: str, qclass: Optional[str]=None) -> asyncio.Fu
126165 except KeyError :
127166 raise ValueError ('invalid query class: {}' .format (qclass ))
128167
168+ fut : asyncio .Future [list [Any ]]
129169 fut , cb = self ._get_future_callback ()
130170 self ._channel .query (host , qtype , cb , query_class = qclass )
131171 return fut
132172
133- def gethostbyname (self , host : str , family : socket .AddressFamily ) -> asyncio .Future :
173+ def gethostbyname (self , host : str , family : socket .AddressFamily ) -> asyncio .Future [pycares .ares_host_result ]:
174+ fut : asyncio .Future [pycares .ares_host_result ]
134175 fut , cb = self ._get_future_callback ()
135176 self ._channel .gethostbyname (host , family , cb )
136177 return fut
137178
138- def getaddrinfo (self , host : str , family : socket .AddressFamily = socket .AF_UNSPEC , port : Optional [int ] = None , proto : int = 0 , type : int = 0 , flags : int = 0 ) -> asyncio .Future :
179+ def getaddrinfo (self , host : str , family : socket .AddressFamily = socket .AF_UNSPEC , port : Optional [int ] = None , proto : int = 0 , type : int = 0 , flags : int = 0 ) -> asyncio .Future [pycares .ares_addrinfo_result ]:
180+ fut : asyncio .Future [pycares .ares_addrinfo_result ]
139181 fut , cb = self ._get_future_callback ()
140182 self ._channel .getaddrinfo (host , port , cb , family = family , type = type , proto = proto , flags = flags )
141183 return fut
142184
143- def getnameinfo (self , sockaddr : Union [Tuple [str , int ], Tuple [str , int , int , int ]], flags : int = 0 ) -> asyncio .Future :
185+ def getnameinfo (self , sockaddr : Union [tuple [str , int ], tuple [str , int , int , int ]], flags : int = 0 ) -> asyncio .Future [pycares .ares_nameinfo_result ]:
186+ fut : asyncio .Future [pycares .ares_nameinfo_result ]
144187 fut , cb = self ._get_future_callback ()
145188 self ._channel .getnameinfo (sockaddr , flags , cb )
146189 return fut
147190
148- def gethostbyaddr (self , name : str ) -> asyncio .Future :
191+ def gethostbyaddr (self , name : str ) -> asyncio .Future [pycares .ares_host_result ]:
192+ fut : asyncio .Future [pycares .ares_host_result ]
149193 fut , cb = self ._get_future_callback ()
150194 self ._channel .gethostbyaddr (name , cb )
151195 return fut
152-
196+
153197 def cancel (self ) -> None :
154198 self ._channel .cancel ()
155199
@@ -177,7 +221,7 @@ def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
177221 self ._timer .cancel ()
178222 self ._timer = None
179223
180- def _handle_event (self , fd : int , event : Any ) -> None :
224+ def _handle_event (self , fd : int , event : int ) -> None :
181225 read_fd = pycares .ARES_SOCKET_BAD
182226 write_fd = pycares .ARES_SOCKET_BAD
183227 if event == READ :
@@ -193,7 +237,7 @@ def _timer_cb(self) -> None:
193237 else :
194238 self ._timer = None
195239
196- def _start_timer (self ):
240+ def _start_timer (self ) -> None :
197241 timeout = self ._timeout
198242 if timeout is None or timeout < 0 or timeout > 1 :
199243 timeout = 1
0 commit comments