1
mirror of https://github.com/home-assistant/core synced 2024-09-12 15:16:21 +02:00

Discover devices from device_trackers with router sources (#45160)

* Discover devices from device_trackers with router sources

* Update homeassistant/components/dhcp/__init__.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* fix stop being called on the wrong context

* clean

* move it to base

* cleanup was too agressive

* Update homeassistant/components/dhcp/__init__.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* coverage

* revert legacy changes

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
J. Nick Koston 2021-01-15 14:01:37 -10:00 committed by GitHub
parent 598a0d19b1
commit 5e01b828af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 393 additions and 68 deletions

View File

@ -13,7 +13,7 @@ from homeassistant.const import (
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from .const import ATTR_SOURCE_TYPE, DOMAIN, LOGGER
from .const import ATTR_HOST_NAME, ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, DOMAIN, LOGGER
async def async_setup_entry(hass, entry):
@ -47,6 +47,21 @@ class BaseTrackerEntity(Entity):
"""Return the source type, eg gps or router, of the device."""
raise NotImplementedError
@property
def ip_address(self) -> str:
"""Return the primary ip address of the device."""
return None
@property
def mac_address(self) -> str:
"""Return the mac address of the device."""
return None
@property
def hostname(self) -> str:
"""Return hostname of the device."""
return None
@property
def state_attributes(self):
"""Return the device state attributes."""
@ -54,6 +69,12 @@ class BaseTrackerEntity(Entity):
if self.battery_level:
attr[ATTR_BATTERY_LEVEL] = self.battery_level
if self.ip_address is not None:
attr[ATTR_IP] = self.ip_address
if self.ip_address is not None:
attr[ATTR_MAC] = self.mac_address
if self.hostname is not None:
attr[ATTR_HOST_NAME] = self.hostname
return attr

View File

@ -34,3 +34,4 @@ ATTR_LOCATION_NAME = "location_name"
ATTR_MAC = "mac"
ATTR_SOURCE_TYPE = "source_type"
ATTR_CONSIDER_HOME = "consider_home"
ATTR_IP = "ip"

View File

@ -1,18 +1,32 @@
"""The dhcp integration."""
from abc import abstractmethod
import fnmatch
import logging
import os
from threading import Event, Thread
import threading
from scapy.error import Scapy_Exception
from scapy.layers.dhcp import DHCP
from scapy.layers.l2 import Ether
from scapy.sendrecv import sniff
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant
from homeassistant.components.device_tracker.const import (
ATTR_HOST_NAME,
ATTR_IP,
ATTR_MAC,
ATTR_SOURCE_TYPE,
DOMAIN as DEVICE_TRACKER_DOMAIN,
SOURCE_TYPE_ROUTER,
)
from homeassistant.const import (
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
STATE_HOME,
)
from homeassistant.core import Event, HomeAssistant, State, callback
from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.loader import async_get_dhcp
from .const import DOMAIN
@ -32,35 +46,162 @@ async def async_setup(hass: HomeAssistant, config: dict) -> bool:
"""Set up the dhcp component."""
async def _initialize(_):
dhcp_watcher = DHCPWatcher(hass, await async_get_dhcp(hass))
dhcp_watcher.start()
address_data = {}
integration_matchers = await async_get_dhcp(hass)
watchers = []
def _stop(*_):
dhcp_watcher.stop()
dhcp_watcher.join()
for cls in (DHCPWatcher, DeviceTrackerWatcher):
watcher = cls(hass, address_data, integration_matchers)
watcher.async_start()
watchers.append(watcher)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _stop)
async def _async_stop(*_):
for watcher in watchers:
if hasattr(watcher, "async_stop"):
watcher.async_stop()
else:
await hass.async_add_executor_job(watcher.stop)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, _initialize)
return True
class DHCPWatcher(Thread):
"""Class to watch dhcp requests."""
class WatcherBase:
"""Base class for dhcp and device tracker watching."""
def __init__(self, hass, integration_matchers):
def __init__(self, hass, address_data, integration_matchers):
"""Initialize class."""
super().__init__()
self.hass = hass
self.name = "dhcp-discovery"
self._integration_matchers = integration_matchers
self._address_data = {}
self._stop_event = Event()
self._address_data = address_data
def process_client(self, ip_address, hostname, mac_address):
"""Process a client."""
data = self._address_data.get(ip_address)
if data and data[MAC_ADDRESS] == mac_address and data[HOSTNAME] == hostname:
# If the address data is the same no need
# to process it
return
self._address_data[ip_address] = {MAC_ADDRESS: mac_address, HOSTNAME: hostname}
self.process_updated_address_data(ip_address, self._address_data[ip_address])
def process_updated_address_data(self, ip_address, data):
"""Process the address data update."""
lowercase_hostname = data[HOSTNAME].lower()
uppercase_mac = data[MAC_ADDRESS].upper()
_LOGGER.debug(
"Processing updated address data for %s: mac=%s hostname=%s",
ip_address,
uppercase_mac,
lowercase_hostname,
)
for entry in self._integration_matchers:
if MAC_ADDRESS in entry and not fnmatch.fnmatch(
uppercase_mac, entry[MAC_ADDRESS]
):
continue
if HOSTNAME in entry and not fnmatch.fnmatch(
lowercase_hostname, entry[HOSTNAME]
):
continue
_LOGGER.debug("Matched %s against %s", data, entry)
self.create_task(
self.hass.config_entries.flow.async_init(
entry["domain"],
context={"source": DOMAIN},
data={IP_ADDRESS: ip_address, **data},
)
)
@abstractmethod
def create_task(self, task):
"""Pass a task to async_add_task based on which context we are in."""
class DeviceTrackerWatcher(WatcherBase):
"""Class to watch dhcp data from routers."""
def __init__(self, hass, address_data, integration_matchers):
"""Initialize class."""
super().__init__(hass, address_data, integration_matchers)
self._unsub = None
@callback
def async_stop(self):
"""Stop watching for new device trackers."""
if self._unsub:
self._unsub()
self._unsub = None
@callback
def async_start(self):
"""Stop watching for new device trackers."""
self._unsub = async_track_state_added_domain(
self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
)
for state in self.hass.states.async_all(DEVICE_TRACKER_DOMAIN):
self._async_process_device_state(state)
@callback
def _async_process_device_event(self, event: Event):
"""Process a device tracker state change event."""
self._async_process_device_state(event.data.get("new_state"))
@callback
def _async_process_device_state(self, state: State):
"""Process a device tracker state."""
if state.state != STATE_HOME:
return
attributes = state.attributes
if attributes.get(ATTR_SOURCE_TYPE) != SOURCE_TYPE_ROUTER:
return
ip_address = attributes.get(ATTR_IP)
hostname = attributes.get(ATTR_HOST_NAME)
mac_address = attributes.get(ATTR_MAC)
if ip_address is None or hostname is None or mac_address is None:
return
self.process_client(ip_address, hostname, _format_mac(mac_address))
def create_task(self, task):
"""Pass a task to async_create_task since we are in async context."""
self.hass.async_create_task(task)
class DHCPWatcher(WatcherBase, threading.Thread):
"""Class to watch dhcp requests."""
def __init__(self, hass, address_data, integration_matchers):
"""Initialize class."""
super().__init__(hass, address_data, integration_matchers)
self.name = "dhcp-discovery"
self._stop_event = threading.Event()
def stop(self):
"""Stop the thread."""
self._stop_event.set()
self.join()
@callback
def async_start(self):
"""Start the thread."""
self.start()
def run(self):
"""Start watching for dhcp packets."""
@ -98,49 +239,11 @@ class DHCPWatcher(Thread):
if ip_address is None or hostname is None or mac_address is None:
return
data = self._address_data.get(ip_address)
self.process_client(ip_address, hostname, mac_address)
if data and data[MAC_ADDRESS] == mac_address and data[HOSTNAME] == hostname:
# If the address data is the same no need
# to process it
return
self._address_data[ip_address] = {MAC_ADDRESS: mac_address, HOSTNAME: hostname}
self.process_updated_address_data(ip_address, self._address_data[ip_address])
def process_updated_address_data(self, ip_address, data):
"""Process the address data update."""
lowercase_hostname = data[HOSTNAME].lower()
uppercase_mac = data[MAC_ADDRESS].upper()
_LOGGER.debug(
"Processing updated address data for %s: mac=%s hostname=%s",
ip_address,
uppercase_mac,
lowercase_hostname,
)
for entry in self._integration_matchers:
if MAC_ADDRESS in entry and not fnmatch.fnmatch(
uppercase_mac, entry[MAC_ADDRESS]
):
continue
if HOSTNAME in entry and not fnmatch.fnmatch(
lowercase_hostname, entry[HOSTNAME]
):
continue
_LOGGER.debug("Matched %s against %s", data, entry)
self.hass.add_job(
self.hass.config_entries.flow.async_init(
entry["domain"],
context={"source": DOMAIN},
data={IP_ADDRESS: ip_address, **data},
)
)
def create_task(self, task):
"""Pass a task to hass.add_job since we are in a thread."""
self.hass.add_job(task)
def _decode_dhcp_option(dhcp_options, key):

View File

@ -52,6 +52,7 @@ CLIENT_STATIC_ATTRIBUTES = [
"oui",
]
CLIENT_CONNECTED_ALL_ATTRIBUTES = CLIENT_CONNECTED_ATTRIBUTES + CLIENT_STATIC_ATTRIBUTES
DEVICE_UPGRADED = (ACCESS_POINT_UPGRADED, GATEWAY_UPGRADED, SWITCH_UPGRADED)
@ -239,6 +240,21 @@ class UniFiClientTracker(UniFiClient, ScannerEntity):
return attributes
@property
def ip_address(self) -> str:
"""Return the primary ip address of the device."""
return self.client.raw.get("ip")
@property
def mac_address(self) -> str:
"""Return the mac address of the device."""
return self.client.raw.get("mac")
@property
def hostname(self) -> str:
"""Return hostname of the device."""
return self.client.raw.get("hostname")
async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled."""
if not self.controller.option_track_clients:

View File

@ -59,3 +59,6 @@ def test_base_tracker_entity():
assert entity.battery_level is None
with pytest.raises(NotImplementedError):
assert entity.state_attributes is None
assert entity.ip_address is None
assert entity.mac_address is None
assert entity.hostname is None

View File

@ -7,7 +7,19 @@ from scapy.layers.dhcp import DHCP
from scapy.layers.l2 import Ether
from homeassistant.components import dhcp
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.components.device_tracker.const import (
ATTR_HOST_NAME,
ATTR_IP,
ATTR_MAC,
ATTR_SOURCE_TYPE,
SOURCE_TYPE_ROUTER,
)
from homeassistant.const import (
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
STATE_HOME,
STATE_NOT_HOME,
)
from homeassistant.setup import async_setup_component
from tests.common import mock_coro
@ -41,6 +53,7 @@ async def test_dhcp_match_hostname_and_macaddress(hass):
"""Test matching based on hostname and macaddress."""
dhcp_watcher = dhcp.DHCPWatcher(
hass,
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
@ -66,7 +79,7 @@ async def test_dhcp_match_hostname_and_macaddress(hass):
async def test_dhcp_match_hostname(hass):
"""Test matching based on hostname only."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, [{"domain": "mock-domain", "hostname": "connect"}]
hass, {}, [{"domain": "mock-domain", "hostname": "connect"}]
)
packet = Ether(RAW_DHCP_REQUEST)
@ -89,7 +102,7 @@ async def test_dhcp_match_hostname(hass):
async def test_dhcp_match_macaddress(hass):
"""Test matching based on macaddress only."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, [{"domain": "mock-domain", "macaddress": "B8B7F1*"}]
hass, {}, [{"domain": "mock-domain", "macaddress": "B8B7F1*"}]
)
packet = Ether(RAW_DHCP_REQUEST)
@ -112,7 +125,7 @@ async def test_dhcp_match_macaddress(hass):
async def test_dhcp_nomatch(hass):
"""Test not matching based on macaddress only."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, [{"domain": "mock-domain", "macaddress": "ABC123*"}]
hass, {}, [{"domain": "mock-domain", "macaddress": "ABC123*"}]
)
packet = Ether(RAW_DHCP_REQUEST)
@ -128,7 +141,7 @@ async def test_dhcp_nomatch(hass):
async def test_dhcp_nomatch_hostname(hass):
"""Test not matching based on hostname only."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST)
@ -144,7 +157,7 @@ async def test_dhcp_nomatch_hostname(hass):
async def test_dhcp_nomatch_non_dhcp_packet(hass):
"""Test matching does not throw on a non-dhcp packet."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(b"")
@ -160,7 +173,7 @@ async def test_dhcp_nomatch_non_dhcp_packet(hass):
async def test_dhcp_nomatch_non_dhcp_request_packet(hass):
"""Test nothing happens with the wrong message-type."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST)
@ -185,7 +198,7 @@ async def test_dhcp_nomatch_non_dhcp_request_packet(hass):
async def test_dhcp_invalid_hostname(hass):
"""Test we ignore invalid hostnames."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST)
@ -210,7 +223,7 @@ async def test_dhcp_invalid_hostname(hass):
async def test_dhcp_missing_hostname(hass):
"""Test we ignore missing hostnames."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST)
@ -235,7 +248,7 @@ async def test_dhcp_missing_hostname(hass):
async def test_dhcp_invalid_option(hass):
"""Test we ignore invalid hostname option."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST)
@ -327,3 +340,167 @@ async def test_setup_fails_non_root(hass, caplog):
await hass.async_block_till_done()
wait_event.set()
assert "Cannot watch for dhcp packets without root or CAP_NET_RAW" in caplog.text
async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass):
"""Test matching based on hostname and macaddress before start."""
hass.states.async_set(
"device_tracker.august_connect",
STATE_HOME,
{
ATTR_HOST_NAME: "connect",
ATTR_IP: "192.168.210.56",
ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER,
ATTR_MAC: "B8:B7:F1:6D:B5:33",
},
)
with patch.object(
hass.config_entries.flow, "async_init", return_value=mock_coro()
) as mock_init:
device_tracker_watcher = dhcp.DeviceTrackerWatcher(
hass,
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
assert mock_init.mock_calls[0][2]["context"] == {"source": "dhcp"}
assert mock_init.mock_calls[0][2]["data"] == {
dhcp.IP_ADDRESS: "192.168.210.56",
dhcp.HOSTNAME: "connect",
dhcp.MAC_ADDRESS: "b8b7f16db533",
}
async def test_device_tracker_hostname_and_macaddress_after_start(hass):
"""Test matching based on hostname and macaddress after start."""
with patch.object(
hass.config_entries.flow, "async_init", return_value=mock_coro()
) as mock_init:
device_tracker_watcher = dhcp.DeviceTrackerWatcher(
hass,
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await hass.async_block_till_done()
hass.states.async_set(
"device_tracker.august_connect",
STATE_HOME,
{
ATTR_HOST_NAME: "connect",
ATTR_IP: "192.168.210.56",
ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER,
ATTR_MAC: "B8:B7:F1:6D:B5:33",
},
)
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
assert mock_init.mock_calls[0][2]["context"] == {"source": "dhcp"}
assert mock_init.mock_calls[0][2]["data"] == {
dhcp.IP_ADDRESS: "192.168.210.56",
dhcp.HOSTNAME: "connect",
dhcp.MAC_ADDRESS: "b8b7f16db533",
}
async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass):
"""Test matching based on hostname and macaddress after start but not home."""
with patch.object(
hass.config_entries.flow, "async_init", return_value=mock_coro()
) as mock_init:
device_tracker_watcher = dhcp.DeviceTrackerWatcher(
hass,
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await hass.async_block_till_done()
hass.states.async_set(
"device_tracker.august_connect",
STATE_NOT_HOME,
{
ATTR_HOST_NAME: "connect",
ATTR_IP: "192.168.210.56",
ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER,
ATTR_MAC: "B8:B7:F1:6D:B5:33",
},
)
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0
async def test_device_tracker_hostname_and_macaddress_after_start_not_router(hass):
"""Test matching based on hostname and macaddress after start but not router."""
with patch.object(
hass.config_entries.flow, "async_init", return_value=mock_coro()
) as mock_init:
device_tracker_watcher = dhcp.DeviceTrackerWatcher(
hass,
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await hass.async_block_till_done()
hass.states.async_set(
"device_tracker.august_connect",
STATE_HOME,
{
ATTR_HOST_NAME: "connect",
ATTR_IP: "192.168.210.56",
ATTR_SOURCE_TYPE: "something_else",
ATTR_MAC: "B8:B7:F1:6D:B5:33",
},
)
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0
async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missing(
hass,
):
"""Test matching based on hostname and macaddress after start but missing hostname."""
with patch.object(
hass.config_entries.flow, "async_init", return_value=mock_coro()
) as mock_init:
device_tracker_watcher = dhcp.DeviceTrackerWatcher(
hass,
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await hass.async_block_till_done()
hass.states.async_set(
"device_tracker.august_connect",
STATE_HOME,
{
ATTR_IP: "192.168.210.56",
ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER,
ATTR_MAC: "B8:B7:F1:6D:B5:33",
},
)
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0

View File

@ -189,6 +189,10 @@ async def test_tracked_wireless_clients(hass):
client_1 = hass.states.get("device_tracker.client_1")
assert client_1.state == "home"
assert client_1.attributes["ip"] == "10.0.0.1"
assert client_1.attributes["mac"] == "00:00:00:00:00:01"
assert client_1.attributes["hostname"] == "client_1"
assert client_1.attributes["host_name"] == "client_1"
# State change signalling works with events
controller.api.websocket._data = {