Strict typing for dhcp (#67361)

This commit is contained in:
J. Nick Koston 2022-02-28 18:49:44 -10:00 committed by GitHub
parent 21ce441a97
commit 076fe97110
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 73 deletions

View File

@ -68,6 +68,7 @@ homeassistant.components.device_automation.*
homeassistant.components.device_tracker.*
homeassistant.components.devolo_home_control.*
homeassistant.components.devolo_home_network.*
homeassistant.components.dhcp.*
homeassistant.components.dlna_dmr.*
homeassistant.components.dnsip.*
homeassistant.components.dsmr.*

View File

@ -2,6 +2,9 @@
from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Callable, Iterable
import contextlib
from dataclasses import dataclass
from datetime import timedelta
import fnmatch
@ -9,7 +12,7 @@ from ipaddress import ip_address as make_ip_address
import logging
import os
import threading
from typing import Any, Final
from typing import TYPE_CHECKING, Any, Final, cast
from aiodiscover import DiscoverHosts
from aiodiscover.discovery import (
@ -51,12 +54,16 @@ from homeassistant.helpers.event import (
)
from homeassistant.helpers.frame import report
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_dhcp
from homeassistant.loader import DHCPMatcher, async_get_dhcp
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.network import is_invalid, is_link_local, is_loopback
from .const import DOMAIN
if TYPE_CHECKING:
from scapy.packet import Packet
from scapy.sendrecv import AsyncSniffer
FILTER = "udp and (port 67 or 68)"
REQUESTED_ADDR = "requested_addr"
MESSAGE_TYPE = "message-type"
@ -115,7 +122,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
watchers: list[WatcherBase] = []
address_data: dict[str, dict[str, str]] = {}
integration_matchers = await async_get_dhcp(hass)
# For the passive classes we need to start listening
# for state changes and connect the dispatchers before
# everything else starts up or we will miss events
@ -124,13 +130,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await passive_watcher.async_start()
watchers.append(passive_watcher)
async def _initialize(_):
async def _initialize(event: Event) -> None:
for active_cls in (DHCPWatcher, NetworkWatcher):
active_watcher = active_cls(hass, address_data, integration_matchers)
await active_watcher.async_start()
watchers.append(active_watcher)
async def _async_stop(*_):
async def _async_stop(event: Event) -> None:
for watcher in watchers:
await watcher.async_stop()
@ -143,7 +149,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
class WatcherBase:
"""Base class for dhcp and device tracker watching."""
def __init__(self, hass, address_data, integration_matchers):
def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class."""
super().__init__()
@ -152,11 +163,11 @@ class WatcherBase:
self._address_data = address_data
@abstractmethod
async def async_stop(self):
async def async_stop(self) -> None:
"""Stop the watcher."""
@abstractmethod
async def async_start(self):
async def async_start(self) -> None:
"""Start the watcher."""
def process_client(self, ip_address: str, hostname: str, mac_address: str) -> None:
@ -197,8 +208,8 @@ class WatcherBase:
data = {MAC_ADDRESS: mac_address, HOSTNAME: hostname}
self._address_data[ip_address] = data
lowercase_hostname = data[HOSTNAME].lower()
uppercase_mac = data[MAC_ADDRESS].upper()
lowercase_hostname = hostname.lower()
uppercase_mac = mac_address.upper()
_LOGGER.debug(
"Processing updated address data for %s: mac=%s hostname=%s",
@ -218,22 +229,24 @@ class WatcherBase:
if entry := self.hass.config_entries.async_get_entry(entry_id):
device_domains.add(entry.domain)
for entry in self._integration_matchers:
if entry.get(REGISTERED_DEVICES) and not entry["domain"] in device_domains:
for matcher in self._integration_matchers:
domain = matcher["domain"]
if matcher.get(REGISTERED_DEVICES) and domain not in device_domains:
continue
if MAC_ADDRESS in entry and not fnmatch.fnmatch(
uppercase_mac, entry[MAC_ADDRESS]
):
if (
matcher_mac := matcher.get(MAC_ADDRESS)
) is not None and not fnmatch.fnmatch(uppercase_mac, matcher_mac):
continue
if HOSTNAME in entry and not fnmatch.fnmatch(
lowercase_hostname, entry[HOSTNAME]
):
if (
matcher_hostname := matcher.get(HOSTNAME)
) is not None and not fnmatch.fnmatch(lowercase_hostname, matcher_hostname):
continue
_LOGGER.debug("Matched %s against %s", data, entry)
matched_domains.add(entry["domain"])
_LOGGER.debug("Matched %s against %s", data, matcher)
matched_domains.add(domain)
for domain in matched_domains:
discovery_flow.async_create_flow(
@ -243,7 +256,7 @@ class WatcherBase:
DhcpServiceInfo(
ip=ip_address,
hostname=lowercase_hostname,
macaddress=data[MAC_ADDRESS],
macaddress=mac_address,
),
)
@ -251,14 +264,19 @@ class WatcherBase:
class NetworkWatcher(WatcherBase):
"""Class to query ptr records routers."""
def __init__(self, hass, address_data, integration_matchers):
def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class."""
super().__init__(hass, address_data, integration_matchers)
self._unsub = None
self._discover_hosts = None
self._discover_task = None
self._unsub: Callable[[], None] | None = None
self._discover_hosts: DiscoverHosts | None = None
self._discover_task: asyncio.Task | None = None
async def async_stop(self):
async def async_stop(self) -> None:
"""Stop scanning for new devices on the network."""
if self._unsub:
self._unsub()
@ -267,7 +285,7 @@ class NetworkWatcher(WatcherBase):
self._discover_task.cancel()
self._discover_task = None
async def async_start(self):
async def async_start(self) -> None:
"""Start scanning for new devices on the network."""
self._discover_hosts = DiscoverHosts()
self._unsub = async_track_time_interval(
@ -276,14 +294,15 @@ class NetworkWatcher(WatcherBase):
self.async_start_discover()
@callback
def async_start_discover(self, *_):
def async_start_discover(self, *_: Any) -> None:
"""Start a new discovery task if one is not running."""
if self._discover_task and not self._discover_task.done():
return
self._discover_task = self.hass.async_create_task(self.async_discover())
async def async_discover(self):
async def async_discover(self) -> None:
"""Process discovery."""
assert self._discover_hosts is not None
for host in await self._discover_hosts.async_discover():
self.async_process_client(
host[DISCOVERY_IP_ADDRESS],
@ -295,18 +314,23 @@ class NetworkWatcher(WatcherBase):
class DeviceTrackerWatcher(WatcherBase):
"""Class to watch dhcp data from routers."""
def __init__(self, hass, address_data, integration_matchers):
def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class."""
super().__init__(hass, address_data, integration_matchers)
self._unsub = None
self._unsub: Callable[[], None] | None = None
async def async_stop(self):
async def async_stop(self) -> None:
"""Stop watching for new device trackers."""
if self._unsub:
self._unsub()
self._unsub = None
async def async_start(self):
async def async_start(self) -> None:
"""Stop watching for new device trackers."""
self._unsub = async_track_state_added_domain(
self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
@ -315,12 +339,12 @@ class DeviceTrackerWatcher(WatcherBase):
self._async_process_device_state(state)
@callback
def _async_process_device_event(self, event: Event):
def _async_process_device_event(self, event: Event) -> None:
"""Process a device tracker state change event."""
self._async_process_device_state(event.data["new_state"])
@callback
def _async_process_device_state(self, state: State):
def _async_process_device_state(self, state: State) -> None:
"""Process a device tracker state."""
if state.state != STATE_HOME:
return
@ -343,18 +367,23 @@ class DeviceTrackerWatcher(WatcherBase):
class DeviceTrackerRegisteredWatcher(WatcherBase):
"""Class to watch data from device tracker registrations."""
def __init__(self, hass, address_data, integration_matchers):
def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class."""
super().__init__(hass, address_data, integration_matchers)
self._unsub = None
self._unsub: Callable[[], None] | None = None
async def async_stop(self):
async def async_stop(self) -> None:
"""Stop watching for device tracker registrations."""
if self._unsub:
self._unsub()
self._unsub = None
async def async_start(self):
async def async_start(self) -> None:
"""Stop watching for device tracker registrations."""
self._unsub = async_dispatcher_connect(
self.hass, CONNECTED_DEVICE_REGISTERED, self._async_process_device_data
@ -376,26 +405,32 @@ class DeviceTrackerRegisteredWatcher(WatcherBase):
class DHCPWatcher(WatcherBase):
"""Class to watch dhcp requests."""
def __init__(self, hass, address_data, integration_matchers):
def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class."""
super().__init__(hass, address_data, integration_matchers)
self._sniffer = None
self._sniffer: AsyncSniffer | None = None
self._started = threading.Event()
async def async_stop(self):
async def async_stop(self) -> None:
"""Stop watching for new device trackers."""
await self.hass.async_add_executor_job(self._stop)
def _stop(self):
def _stop(self) -> None:
"""Stop the thread."""
if self._started.is_set():
assert self._sniffer is not None
self._sniffer.stop()
async def async_start(self):
async def async_start(self) -> None:
"""Start watching for dhcp packets."""
await self.hass.async_add_executor_job(self._start)
def _start(self):
def _start(self) -> None:
"""Start watching for dhcp packets."""
# Local import because importing from scapy has side effects such as opening
# sockets
@ -417,20 +452,25 @@ class DHCPWatcher(WatcherBase):
AsyncSniffer,
)
def _handle_dhcp_packet(packet):
def _handle_dhcp_packet(packet: Packet) -> None:
"""Process a dhcp packet."""
if DHCP not in packet:
return
options = packet[DHCP].options
request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
if request_type != DHCP_REQUEST:
options_dict = _dhcp_options_as_dict(packet[DHCP].options)
if options_dict.get(MESSAGE_TYPE) != DHCP_REQUEST:
# Not a DHCP request
return
ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src
hostname = _decode_dhcp_option(options, HOSTNAME) or ""
mac_address = _format_mac(packet[Ether].src)
ip_address = options_dict.get(REQUESTED_ADDR) or cast(str, packet[IP].src)
assert isinstance(ip_address, str)
hostname = ""
if (hostname_bytes := options_dict.get(HOSTNAME)) and isinstance(
hostname_bytes, bytes
):
with contextlib.suppress(AttributeError, UnicodeDecodeError):
hostname = hostname_bytes.decode()
mac_address = _format_mac(cast(str, packet[Ether].src))
if ip_address is not None and mac_address is not None:
self.process_client(ip_address, hostname, mac_address)
@ -470,29 +510,19 @@ class DHCPWatcher(WatcherBase):
self._sniffer.thread.name = self.__class__.__name__
def _decode_dhcp_option(dhcp_options, key):
"""Extract and decode data from a packet option."""
for option in dhcp_options:
if len(option) < 2 or option[0] != key:
continue
value = option[1]
if value is None or key != HOSTNAME:
return value
# hostname is unicode
try:
return value.decode()
except (AttributeError, UnicodeDecodeError):
return None
def _dhcp_options_as_dict(
dhcp_options: Iterable[tuple[str, int | bytes | None]]
) -> dict[str, str | int | bytes | None]:
"""Extract data from packet options as a dict."""
return {option[0]: option[1] for option in dhcp_options if len(option) >= 2}
def _format_mac(mac_address):
def _format_mac(mac_address: str) -> str:
"""Format a mac address for matching."""
return format_mac(mac_address).replace(":", "")
def _verify_l2socket_setup(cap_filter):
def _verify_l2socket_setup(cap_filter: str) -> None:
"""Create a socket using the scapy configured l2socket.
Try to create the socket
@ -504,7 +534,7 @@ def _verify_l2socket_setup(cap_filter):
conf.L2socket(filter=cap_filter)
def _verify_working_pcap(cap_filter):
def _verify_working_pcap(cap_filter: str) -> None:
"""Verify we can create a packet filter.
If we cannot create a filter we will be listening for

View File

@ -60,6 +60,24 @@ MAX_LOAD_CONCURRENTLY = 4
MOVED_ZEROCONF_PROPS = ("macaddress", "model", "manufacturer")
class DHCPMatcherRequired(TypedDict, total=True):
"""Matcher for the dhcp integration for required fields."""
domain: str
class DHCPMatcherOptional(TypedDict, total=False):
"""Matcher for the dhcp integration for optional fields."""
macaddress: str
hostname: str
registered_devices: bool
class DHCPMatcher(DHCPMatcherRequired, DHCPMatcherOptional):
"""Matcher for the dhcp integration."""
class Manifest(TypedDict, total=False):
"""
Integration manifest.
@ -228,16 +246,16 @@ async def async_get_zeroconf(
return zeroconf
async def async_get_dhcp(hass: HomeAssistant) -> list[dict[str, str | bool]]:
async def async_get_dhcp(hass: HomeAssistant) -> list[DHCPMatcher]:
"""Return cached list of dhcp types."""
dhcp: list[dict[str, str | bool]] = DHCP.copy()
dhcp = cast(list[DHCPMatcher], DHCP.copy())
integrations = await async_get_custom_components(hass)
for integration in integrations.values():
if not integration.dhcp:
continue
for entry in integration.dhcp:
dhcp.append({"domain": integration.domain, **entry})
dhcp.append(cast(DHCPMatcher, {"domain": integration.domain, **entry}))
return dhcp

View File

@ -549,6 +549,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.dhcp.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.dlna_dmr.*]
check_untyped_defs = true
disallow_incomplete_defs = true