From 331102e592b3251f18de3d2d680e28122f6b1eb2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 17 Feb 2023 14:51:19 -0600 Subject: [PATCH] Dismiss discoveries when the underlying device disappears (#88340) * Implement discovery removals Bluetooth, HomeKit, SSDP, and Zeroconf now implement dismissing discoveries when the underlying discovered device disappears * cover * add zeroconf test * cover * cover bluetooth * fix rediscover --- homeassistant/components/bluetooth/manager.py | 10 ++ homeassistant/components/ssdp/__init__.py | 20 ++- homeassistant/components/zeroconf/__init__.py | 9 ++ homeassistant/data_entry_flow.py | 44 +++++- tests/components/bluetooth/test_manager.py | 135 ++++++++++++++++++ tests/components/ssdp/test_init.py | 73 ++++++++++ tests/components/zeroconf/test_init.py | 44 ++++++ tests/test_data_entry_flow.py | 108 +++++++++++++- 8 files changed, 432 insertions(+), 11 deletions(-) diff --git a/homeassistant/components/bluetooth/manager.py b/homeassistant/components/bluetooth/manager.py index 3fea5035588b..bc210516562f 100644 --- a/homeassistant/components/bluetooth/manager.py +++ b/homeassistant/components/bluetooth/manager.py @@ -315,6 +315,8 @@ class BluetoothManager: # the device from all the interval tracking since it is no longer # available for both connectable and non-connectable tracker.async_remove_address(address) + self._integration_matcher.async_clear_address(address) + self._async_dismiss_discoveries(address) service_info = history.pop(address) @@ -327,6 +329,14 @@ class BluetoothManager: except Exception: # pylint: disable=broad-except _LOGGER.exception("Error in unavailable callback") + def _async_dismiss_discoveries(self, address: str) -> None: + """Dismiss all discoveries for the given address.""" + for flow in self.hass.config_entries.flow.async_progress_by_init_data_type( + BluetoothServiceInfoBleak, + lambda service_info: bool(service_info.address == address), + ): + self.hass.config_entries.flow.async_abort(flow["flow_id"]) + def _prefer_previous_adv_from_different_source( self, old: BluetoothServiceInfoBleak, diff --git a/homeassistant/components/ssdp/__init__.py b/homeassistant/components/ssdp/__init__.py index c2f56bb7b4ad..b7e28f270457 100644 --- a/homeassistant/components/ssdp/__init__.py +++ b/homeassistant/components/ssdp/__init__.py @@ -518,7 +518,11 @@ class Scanner: CaseInsensitiveDict(combined_headers.as_dict(), **info_desc) ) - if not callbacks and not matching_domains: + if ( + not callbacks + and not matching_domains + and source != SsdpSource.ADVERTISEMENT_BYEBYE + ): return discovery_info = discovery_info_from_headers_and_description( @@ -534,6 +538,7 @@ class Scanner: # Config flows should only be created for alive/update messages from alive devices if source == SsdpSource.ADVERTISEMENT_BYEBYE: + self._async_dismiss_discoveries(discovery_info) return _LOGGER.debug("Discovery info: %s", discovery_info) @@ -548,6 +553,19 @@ class Scanner: discovery_info, ) + def _async_dismiss_discoveries( + self, byebye_discovery_info: SsdpServiceInfo + ) -> None: + """Dismiss all discoveries for the given address.""" + for flow in self.hass.config_entries.flow.async_progress_by_init_data_type( + SsdpServiceInfo, + lambda service_info: bool( + service_info.ssdp_st == byebye_discovery_info.ssdp_st + and service_info.ssdp_location == byebye_discovery_info.ssdp_location + ), + ): + self.hass.config_entries.flow.async_abort(flow["flow_id"]) + async def _async_get_description_dict( self, location: str | None ) -> Mapping[str, str]: diff --git a/homeassistant/components/zeroconf/__init__.py b/homeassistant/components/zeroconf/__init__.py index df7eec71bd81..badc1242714b 100644 --- a/homeassistant/components/zeroconf/__init__.py +++ b/homeassistant/components/zeroconf/__init__.py @@ -378,6 +378,14 @@ class ZeroconfDiscovery: if self.async_service_browser: await self.async_service_browser.async_cancel() + def _async_dismiss_discoveries(self, name: str) -> None: + """Dismiss all discoveries for the given name.""" + for flow in self.hass.config_entries.flow.async_progress_by_init_data_type( + ZeroconfServiceInfo, + lambda service_info: bool(service_info.name == name), + ): + self.hass.config_entries.flow.async_abort(flow["flow_id"]) + @callback def async_service_update( self, @@ -395,6 +403,7 @@ class ZeroconfDiscovery: ) if state_change == ServiceStateChange.Removed: + self._async_dismiss_discoveries(name) return try: diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index ebe67e471036..347ab89e4526 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -2,7 +2,7 @@ from __future__ import annotations import abc -from collections.abc import Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping import copy from dataclasses import dataclass import logging @@ -138,6 +138,7 @@ class FlowManager(abc.ABC): self.hass = hass self._progress: dict[str, FlowHandler] = {} self._handler_progress_index: dict[str, set[str]] = {} + self._init_data_process_index: dict[type, set[str]] = {} @abc.abstractmethod async def async_create_flow( @@ -198,6 +199,23 @@ class FlowManager(abc.ABC): self._async_progress_by_handler(handler), include_uninitialized ) + @callback + def async_progress_by_init_data_type( + self, + init_data_type: type, + matcher: Callable[[Any], bool], + include_uninitialized: bool = False, + ) -> list[FlowResult]: + """Return flows in progress init matching by data type as a partial FlowResult.""" + return _async_flow_handler_to_flow_result( + ( + self._progress[flow_id] + for flow_id in self._init_data_process_index.get(init_data_type, {}) + if matcher(self._progress[flow_id].init_data) + ), + include_uninitialized, + ) + @callback def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]: """Return the flows in progress by handler.""" @@ -301,19 +319,33 @@ class FlowManager(abc.ABC): @callback def _async_add_flow_progress(self, flow: FlowHandler) -> None: """Add a flow to in progress.""" + if flow.init_data is not None: + init_data_type = type(flow.init_data) + self._init_data_process_index.setdefault(init_data_type, set()).add( + flow.flow_id + ) self._progress[flow.flow_id] = flow self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id) + @callback + def _async_remove_flow_from_index(self, flow: FlowHandler) -> None: + """Remove a flow from in progress.""" + if flow.init_data is not None: + init_data_type = type(flow.init_data) + self._init_data_process_index[init_data_type].remove(flow.flow_id) + if not self._init_data_process_index[init_data_type]: + del self._init_data_process_index[init_data_type] + handler = flow.handler + self._handler_progress_index[handler].remove(flow.flow_id) + if not self._handler_progress_index[handler]: + del self._handler_progress_index[handler] + @callback def _async_remove_flow_progress(self, flow_id: str) -> None: """Remove a flow from in progress.""" if (flow := self._progress.pop(flow_id, None)) is None: raise UnknownFlow - handler = flow.handler - self._handler_progress_index[handler].remove(flow.flow_id) - if not self._handler_progress_index[handler]: - del self._handler_progress_index[handler] - + self._async_remove_flow_from_index(flow) try: flow.async_remove() except Exception as err: # pylint: disable=broad-except diff --git a/tests/components/bluetooth/test_manager.py b/tests/components/bluetooth/test_manager.py index 64bfd2b9281d..7e605ece4cb0 100644 --- a/tests/components/bluetooth/test_manager.py +++ b/tests/components/bluetooth/test_manager.py @@ -803,3 +803,138 @@ async def test_goes_unavailable_connectable_only_and_recovers( unsetup_connectable_scanner_2() cancel_not_connectable_scanner() unsetup_not_connectable_scanner() + + +async def test_goes_unavailable_dismisses_discovery( + hass: HomeAssistant, mock_bluetooth_adapters: None +) -> None: + """Test that unavailable will dismiss any active discoveries.""" + assert await async_setup_component(hass, bluetooth.DOMAIN, {}) + await hass.async_block_till_done() + + assert async_scanner_count(hass, connectable=False) == 0 + switchbot_device_non_connectable = BLEDevice( + "44:44:33:11:23:45", + "wohand", + {}, + rssi=-100, + ) + switchbot_device_adv = generate_advertisement_data( + local_name="wohand", + service_uuids=["050a021a-0000-1000-8000-00805f9b34fb"], + service_data={"050a021a-0000-1000-8000-00805f9b34fb": b"\n\xff"}, + manufacturer_data={1: b"\x01"}, + rssi=-100, + ) + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, + change: BluetoothChange, + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {"address": "44:44:33:11:23:45", "connectable": False}, + BluetoothScanningMode.ACTIVE, + ) + + class FakeScanner(BaseHaRemoteScanner): + def inject_advertisement( + self, device: BLEDevice, advertisement_data: AdvertisementData + ) -> None: + """Inject an advertisement.""" + self._async_on_advertisement( + device.address, + advertisement_data.rssi, + device.name, + advertisement_data.service_uuids, + advertisement_data.service_data, + advertisement_data.manufacturer_data, + advertisement_data.tx_power, + {"scanner_specific_data": "test"}, + ) + + def clear_all_devices(self) -> None: + """Clear all devices.""" + self._discovered_device_advertisement_datas.clear() + self._discovered_device_timestamps.clear() + + new_info_callback = async_get_advertisement_callback(hass) + connector = ( + HaBluetoothConnector(MockBleakClient, "mock_bleak_client", lambda: False), + ) + non_connectable_scanner = FakeScanner( + hass, + "connectable", + "connectable", + new_info_callback, + connector, + False, + ) + unsetup_connectable_scanner = non_connectable_scanner.async_setup() + cancel_connectable_scanner = _get_manager().async_register_scanner( + non_connectable_scanner, True + ) + non_connectable_scanner.inject_advertisement( + switchbot_device_non_connectable, switchbot_device_adv + ) + assert async_ble_device_from_address(hass, "44:44:33:11:23:45", False) is not None + assert async_scanner_count(hass, connectable=True) == 1 + assert len(callbacks) == 1 + + assert ( + "44:44:33:11:23:45" + in non_connectable_scanner.discovered_devices_and_advertisement_data + ) + + unavailable_callbacks: list[BluetoothServiceInfoBleak] = [] + + @callback + def _unavailable_callback(service_info: BluetoothServiceInfoBleak) -> None: + """Wrong device unavailable callback.""" + nonlocal unavailable_callbacks + unavailable_callbacks.append(service_info.address) + + cancel_unavailable = async_track_unavailable( + hass, + _unavailable_callback, + switchbot_device_non_connectable.address, + connectable=False, + ) + + assert async_scanner_count(hass, connectable=False) == 1 + + non_connectable_scanner.clear_all_devices() + assert ( + "44:44:33:11:23:45" + not in non_connectable_scanner.discovered_devices_and_advertisement_data + ) + monotonic_now = time.monotonic() + with patch.object( + hass.config_entries.flow, + "async_progress_by_init_data_type", + return_value=[{"flow_id": "mock_flow_id"}], + ) as mock_async_progress_by_init_data_type, patch.object( + hass.config_entries.flow, "async_abort" + ) as mock_async_abort, patch( + "homeassistant.components.bluetooth.manager.MONOTONIC_TIME", + return_value=monotonic_now + FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS, + ): + async_fire_time_changed( + hass, dt_util.utcnow() + timedelta(seconds=UNAVAILABLE_TRACK_SECONDS) + ) + await hass.async_block_till_done() + assert "44:44:33:11:23:45" in unavailable_callbacks + + assert len(mock_async_progress_by_init_data_type.mock_calls) == 1 + assert mock_async_abort.mock_calls[0][1][0] == "mock_flow_id" + + cancel_unavailable() + + cancel() + unsetup_connectable_scanner() + cancel_connectable_scanner() diff --git a/tests/components/ssdp/test_init.py b/tests/components/ssdp/test_init.py index 485ed278dc17..b068aed11abb 100644 --- a/tests/components/ssdp/test_init.py +++ b/tests/components/ssdp/test_init.py @@ -784,3 +784,76 @@ async def test_ipv4_does_additional_search_for_sonos( ), ) assert ssdp_listener.async_search.call_args[1] == {} + + +@pytest.mark.usefixtures("mock_get_source_ip") +@patch( + "homeassistant.components.ssdp.async_get_ssdp", + return_value={"mock-domain": [{"deviceType": "Paulus"}]}, +) +async def test_flow_dismiss_on_byebye( + mock_get_ssdp, + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + mock_flow_init, +) -> None: + """Test config flow is only started for alive devices.""" + aioclient_mock.get( + "http://1.1.1.1", + text=""" + + + Paulus + + + """, + ) + ssdp_listener = await init_ssdp_component(hass) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + # Search should start a flow + mock_ssdp_search_response = _ssdp_headers( + { + "st": "mock-st", + "location": "http://1.1.1.1", + "usn": "uuid:mock-udn::mock-st", + } + ) + ssdp_listener._on_search(mock_ssdp_search_response) + await hass.async_block_till_done() + + mock_flow_init.assert_awaited_once_with( + "mock-domain", context={"source": config_entries.SOURCE_SSDP}, data=ANY + ) + + # ssdp:alive advertisement should start a flow + mock_flow_init.reset_mock() + mock_ssdp_advertisement = _ssdp_headers( + { + "location": "http://1.1.1.1", + "usn": "uuid:mock-udn::mock-st", + "nt": "upnp:rootdevice", + "nts": "ssdp:alive", + } + ) + ssdp_listener._on_alive(mock_ssdp_advertisement) + await hass.async_block_till_done() + mock_flow_init.assert_awaited_once_with( + "mock-domain", context={"source": config_entries.SOURCE_SSDP}, data=ANY + ) + + mock_ssdp_advertisement["nts"] = "ssdp:byebye" + # ssdp:byebye advertisement should dismiss existing flows + with patch.object( + hass.config_entries.flow, + "async_progress_by_init_data_type", + return_value=[{"flow_id": "mock_flow_id"}], + ) as mock_async_progress_by_init_data_type, patch.object( + hass.config_entries.flow, "async_abort" + ) as mock_async_abort: + ssdp_listener._on_byebye(mock_ssdp_advertisement) + await hass.async_block_till_done() + + assert len(mock_async_progress_by_init_data_type.mock_calls) == 1 + assert mock_async_abort.mock_calls[0][1][0] == "mock_flow_id" diff --git a/tests/components/zeroconf/test_init.py b/tests/components/zeroconf/test_init.py index a963a498a0d3..fea6b27e208e 100644 --- a/tests/components/zeroconf/test_init.py +++ b/tests/components/zeroconf/test_init.py @@ -1339,3 +1339,47 @@ async def test_start_with_frontend( await hass.async_block_till_done() mock_async_zeroconf.async_register_service.assert_called_once() + + +async def test_zeroconf_removed(hass: HomeAssistant, mock_async_zeroconf: None) -> None: + """Test we dismiss flows when a PTR record is removed.""" + + def _device_removed_mock(ipv6, zeroconf, services, handlers): + """Call service update handler.""" + handlers[0]( + zeroconf, + "_http._tcp.local.", + "Shelly108._http._tcp.local.", + ServiceStateChange.Removed, + ) + + with patch.dict( + zc_gen.ZEROCONF, + { + "_http._tcp.local.": [ + { + "domain": "shelly", + "name": "shelly*", + } + ] + }, + clear=True, + ), patch.object( + hass.config_entries.flow, + "async_progress_by_init_data_type", + return_value=[{"flow_id": "mock_flow_id"}], + ) as mock_async_progress_by_init_data_type, patch.object( + hass.config_entries.flow, "async_abort" + ) as mock_async_abort, patch.object( + zeroconf, "HaAsyncServiceBrowser", side_effect=_device_removed_mock + ) as mock_service_browser, patch( + "homeassistant.components.zeroconf.AsyncServiceInfo", + side_effect=get_zeroconf_info_mock("FFAADDCC11DD"), + ): + assert await async_setup_component(hass, zeroconf.DOMAIN, {zeroconf.DOMAIN: {}}) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + assert len(mock_service_browser.mock_calls) == 1 + assert len(mock_async_progress_by_init_data_type.mock_calls) == 1 + assert mock_async_abort.mock_calls[0][1][0] == "mock_flow_id" diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 29bab8fbbcb3..ecb730011616 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -1,4 +1,5 @@ """Test the flow classes.""" +import dataclasses import logging from unittest.mock import Mock, patch @@ -73,7 +74,7 @@ async def test_configure_reuses_handler_instance(manager): assert len(manager.mock_created_entries) == 0 -async def test_configure_two_steps(manager): +async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None: """Test that we reuse instances.""" @manager.mock_reg_handler("test") @@ -82,7 +83,6 @@ async def test_configure_two_steps(manager): async def async_step_first(self, user_input=None): if user_input is not None: - self.init_data = user_input return await self.async_step_second() return self.async_show_form(step_id="first", data_schema=vol.Schema([str])) @@ -93,12 +93,13 @@ async def test_configure_two_steps(manager): ) return self.async_show_form(step_id="second", data_schema=vol.Schema([str])) - form = await manager.async_init("test", context={"init_step": "first"}) + form = await manager.async_init( + "test", context={"init_step": "first"}, data=["INIT-DATA"] + ) with pytest.raises(vol.Invalid): form = await manager.async_configure(form["flow_id"], "INCORRECT-DATA") - form = await manager.async_configure(form["flow_id"], ["INIT-DATA"]) form = await manager.async_configure(form["flow_id"], ["SECOND-DATA"]) assert form["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY assert len(manager.async_progress()) == 0 @@ -553,3 +554,102 @@ async def test_show_menu(hass, manager, menu_options): ) assert result["type"] == data_entry_flow.FlowResultType.FORM assert result["step_id"] == "target1" + + +async def test_find_flows_by_init_data_type( + manager: data_entry_flow.FlowManager, +) -> None: + """Test we can find flows by init data type.""" + + @dataclasses.dataclass + class BluetoothDiscoveryData: + """Bluetooth Discovery data.""" + + address: str + + @dataclasses.dataclass + class WiFiDiscoveryData: + """WiFi Discovery data.""" + + address: str + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 1 + + async def async_step_first(self, user_input=None): + if user_input is not None: + return await self.async_step_second() + return self.async_show_form(step_id="first", data_schema=vol.Schema([str])) + + async def async_step_second(self, user_input=None): + if user_input is not None: + return self.async_create_entry( + title="Test Entry", + data={"init": self.init_data, "user": user_input}, + ) + return self.async_show_form(step_id="second", data_schema=vol.Schema([str])) + + bluetooth_data = BluetoothDiscoveryData("aa:bb:cc:dd:ee:ff") + wifi_data = WiFiDiscoveryData("host") + + bluetooth_form = await manager.async_init( + "test", context={"init_step": "first"}, data=bluetooth_data + ) + await manager.async_init("test", context={"init_step": "first"}, data=wifi_data) + + assert ( + len( + manager.async_progress_by_init_data_type( + BluetoothDiscoveryData, lambda data: True + ) + ) + ) == 1 + assert ( + len( + manager.async_progress_by_init_data_type( + BluetoothDiscoveryData, + lambda data: bool(data.address == "aa:bb:cc:dd:ee:ff"), + ) + ) + ) == 1 + assert ( + len( + manager.async_progress_by_init_data_type( + BluetoothDiscoveryData, lambda data: bool(data.address == "not it") + ) + ) + ) == 0 + + wifi_flows = manager.async_progress_by_init_data_type( + WiFiDiscoveryData, lambda data: True + ) + assert len(wifi_flows) == 1 + + bluetooth_result = await manager.async_configure( + bluetooth_form["flow_id"], ["SECOND-DATA"] + ) + assert bluetooth_result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY + assert len(manager.async_progress()) == 1 + assert len(manager.mock_created_entries) == 1 + result = manager.mock_created_entries[0] + assert result["handler"] == "test" + assert result["data"] == {"init": bluetooth_data, "user": ["SECOND-DATA"]} + + bluetooth_flows = manager.async_progress_by_init_data_type( + BluetoothDiscoveryData, lambda data: True + ) + assert len(bluetooth_flows) == 0 + + wifi_flows = manager.async_progress_by_init_data_type( + WiFiDiscoveryData, lambda data: True + ) + assert len(wifi_flows) == 1 + + manager.async_abort(wifi_flows[0]["flow_id"]) + + wifi_flows = manager.async_progress_by_init_data_type( + WiFiDiscoveryData, lambda data: True + ) + assert len(wifi_flows) == 0 + assert len(manager.async_progress()) == 0