From d851cb6f9e116dddb8ebeed4dd3910eb811993bd Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 16 Dec 2019 12:27:43 +0100 Subject: [PATCH] Add unique ID to config entries (#29806) * Add unique ID to config entries * Unload existing entries with same unique ID if flow with unique ID is finished * Remove unused exception * Fix typing * silence pylint * Fix tests * Add unique ID to Hue * Address typing comment * Tweaks to comments * lint --- homeassistant/components/hue/__init__.py | 15 ++- homeassistant/components/hue/bridge.py | 22 ++++ homeassistant/components/hue/config_flow.py | 31 +++--- homeassistant/config_entries.py | 61 ++++++++++- homeassistant/data_entry_flow.py | 46 +++++++-- tests/common.py | 2 + tests/components/hue/test_config_flow.py | 35 ++++--- tests/components/hue/test_init.py | 16 +++ tests/test_config_entries.py | 107 ++++++++++++++++++++ tests/test_data_entry_flow.py | 16 ++- 10 files changed, 305 insertions(+), 46 deletions(-) diff --git a/homeassistant/components/hue/__init__.py b/homeassistant/components/hue/__init__.py index f2b9bd1a229f..57057004479a 100644 --- a/homeassistant/components/hue/__init__.py +++ b/homeassistant/components/hue/__init__.py @@ -4,11 +4,11 @@ import logging import voluptuous as vol -from homeassistant import config_entries +from homeassistant import config_entries, core from homeassistant.const import CONF_FILENAME, CONF_HOST from homeassistant.helpers import config_validation as cv, device_registry as dr -from .bridge import HueBridge +from .bridge import HueBridge, normalize_bridge_id from .config_flow import ( # Loading the config flow file will register the flow configured_hosts, ) @@ -102,7 +102,9 @@ async def async_setup(hass, config): return True -async def async_setup_entry(hass, entry): +async def async_setup_entry( + hass: core.HomeAssistant, entry: config_entries.ConfigEntry +): """Set up a bridge from a config entry.""" host = entry.data["host"] config = hass.data[DATA_CONFIGS].get(host) @@ -121,6 +123,13 @@ async def async_setup_entry(hass, entry): hass.data[DOMAIN][host] = bridge config = bridge.api.config + + # For backwards compat + if entry.unique_id is None: + hass.config_entries.async_update_entry( + entry, unique_id=normalize_bridge_id(config.bridgeid) + ) + device_registry = await dr.async_get_registry(hass) device_registry.async_get_or_create( config_entry_id=entry.entry_id, diff --git a/homeassistant/components/hue/bridge.py b/homeassistant/components/hue/bridge.py index 5a5e55773a5d..0ed6e3a9911a 100644 --- a/homeassistant/components/hue/bridge.py +++ b/homeassistant/components/hue/bridge.py @@ -201,3 +201,25 @@ async def get_bridge(hass, host, username=None): except aiohue.AiohueException: LOGGER.exception("Unknown Hue linking error occurred") raise AuthenticationRequired + + +def normalize_bridge_id(bridge_id: str): + """Normalize a bridge identifier. + + There are three sources where we receive bridge ID from: + - ssdp/upnp: /description.xml, field root/device/serialNumber + - nupnp: "id" field + - Hue Bridge API: config.bridgeid + + The SSDP/UPNP source does not contain the middle 4 characters compared + to the other sources. In all our tests the middle 4 characters are "fffe". + """ + if len(bridge_id) == 16: + return bridge_id[0:6] + bridge_id[-6:] + + if len(bridge_id) == 12: + return bridge_id + + LOGGER.warning("Unexpected bridge id number found: %s", bridge_id) + + return bridge_id diff --git a/homeassistant/components/hue/config_flow.py b/homeassistant/components/hue/config_flow.py index 0423dc6fc2bd..882bf5b70db5 100644 --- a/homeassistant/components/hue/config_flow.py +++ b/homeassistant/components/hue/config_flow.py @@ -12,7 +12,7 @@ from homeassistant.components.ssdp import ATTR_MANUFACTURERURL, ATTR_NAME from homeassistant.core import callback from homeassistant.helpers import aiohttp_client -from .bridge import get_bridge +from .bridge import get_bridge, normalize_bridge_id from .const import DOMAIN, LOGGER from .errors import AuthenticationRequired, CannotConnect @@ -154,17 +154,15 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): if host in configured_hosts(self.hass): return self.async_abort(reason="already_configured") - # This value is based off host/description.xml and is, weirdly, missing - # 4 characters in the middle of the serial compared to results returned - # from the NUPNP API or when querying the bridge API for bridgeid. - # (on first gen Hue hub) - serial = discovery_info.get("serial") + bridge_id = discovery_info.get("serial") + + await self.async_set_unique_id(normalize_bridge_id(bridge_id)) return await self.async_step_import( { "host": host, # This format is the legacy format that Hue used for discovery - "path": f"phue-{serial}.conf", + "path": f"phue-{bridge_id}.conf", } ) @@ -180,6 +178,10 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): if host in configured_hosts(self.hass): return self.async_abort(reason="already_configured") + await self.async_set_unique_id( + normalize_bridge_id(homekit_info["properties"]["id"].replace(":", "")) + ) + return await self.async_step_import({"host": host}) async def async_step_import(self, import_info): @@ -234,18 +236,9 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): host = bridge.host bridge_id = bridge.config.bridgeid - same_hub_entries = [ - entry.entry_id - for entry in self.hass.config_entries.async_entries(DOMAIN) - if entry.data["bridge_id"] == bridge_id or entry.data["host"] == host - ] - - if same_hub_entries: - await asyncio.wait( - [ - self.hass.config_entries.async_remove(entry_id) - for entry_id in same_hub_entries - ] + if self.unique_id is None: + await self.async_set_unique_id( + normalize_bridge_id(bridge_id), raise_on_progress=False ) return self.async_create_entry( diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 07a287c387cb..09ee186da0ff 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -2,7 +2,7 @@ import asyncio import functools import logging -from typing import Any, Callable, Dict, List, Optional, Set, cast +from typing import Any, Callable, Dict, List, Optional, Set, Union, cast import uuid import weakref @@ -75,6 +75,10 @@ class OperationNotAllowed(ConfigError): """Raised when a config entry operation is not allowed.""" +class UniqueIdInProgress(data_entry_flow.AbortFlow): + """Error to indicate that the unique Id is in progress.""" + + class ConfigEntry: """Hold a configuration entry.""" @@ -85,6 +89,7 @@ class ConfigEntry: "title", "data", "options", + "unique_id", "system_options", "source", "connection_class", @@ -104,6 +109,7 @@ class ConfigEntry: connection_class: str, system_options: dict, options: Optional[dict] = None, + unique_id: Optional[str] = None, entry_id: Optional[str] = None, state: str = ENTRY_STATE_NOT_LOADED, ) -> None: @@ -138,6 +144,9 @@ class ConfigEntry: # State of the entry (LOADED, NOT_LOADED) self.state = state + # Unique ID of this entry. + self.unique_id = unique_id + # Listeners to call on update self.update_listeners: List = [] @@ -533,11 +542,15 @@ class ConfigEntries: self, entry: ConfigEntry, *, + unique_id: Union[str, dict, None] = _UNDEF, data: dict = _UNDEF, options: dict = _UNDEF, system_options: dict = _UNDEF, ) -> None: """Update a config entry.""" + if unique_id is not _UNDEF: + entry.unique_id = cast(Optional[str], unique_id) + if data is not _UNDEF: entry.data = data @@ -602,6 +615,25 @@ class ConfigEntries: if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: return result + # Check if config entry exists with unique ID. Unload it. + existing_entry = None + unique_id = flow.context.get("unique_id") + + if unique_id is not None: + for check_entry in self.async_entries(result["handler"]): + if check_entry.unique_id == unique_id: + existing_entry = check_entry + break + + # Unload the entry before setting up the new one. + # We will remove it only after the other one is set up, + # so that device customizations are not getting lost. + if ( + existing_entry is not None + and existing_entry.state not in UNRECOVERABLE_STATES + ): + await self.async_unload(existing_entry.entry_id) + entry = ConfigEntry( version=result["version"], domain=result["handler"], @@ -611,12 +643,16 @@ class ConfigEntries: system_options={}, source=flow.context["source"], connection_class=flow.CONNECTION_CLASS, + unique_id=unique_id, ) self._entries.append(entry) self._async_schedule_save() await self.async_setup(entry.entry_id) + if existing_entry is not None: + await self.async_remove(existing_entry.entry_id) + result["result"] = entry return result @@ -687,6 +723,8 @@ async def _old_conf_migrator(old_config: Dict[str, Any]) -> Dict[str, Any]: class ConfigFlow(data_entry_flow.FlowHandler): """Base class for config flows with some helpers.""" + unique_id = None + def __init_subclass__(cls, domain: Optional[str] = None, **kwargs: Any) -> None: """Initialize a subclass, register if possible.""" super().__init_subclass__(**kwargs) # type: ignore @@ -701,6 +739,27 @@ class ConfigFlow(data_entry_flow.FlowHandler): """Get the options flow for this handler.""" raise data_entry_flow.UnknownHandler + async def async_set_unique_id( + self, unique_id: str, *, raise_on_progress: bool = True + ) -> Optional[ConfigEntry]: + """Set a unique ID for the config flow. + + Returns optionally existing config entry with same ID. + """ + if raise_on_progress: + for progress in self._async_in_progress(): + if progress["context"].get("unique_id") == unique_id: + raise UniqueIdInProgress("already_in_progress") + + # pylint: disable=no-member + self.context["unique_id"] = unique_id + + for entry in self._async_current_entries(): + if entry.unique_id == unique_id: + return entry + + return None + @callback def _async_current_entries(self) -> List[ConfigEntry]: """Return current entries.""" diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index e7432cd52f75..7c2b4ab6ddc6 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -1,6 +1,6 @@ """Classes to help gather user submissions.""" import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, cast import uuid import voluptuous as vol @@ -36,6 +36,16 @@ class UnknownStep(FlowError): """Unknown step specified.""" +class AbortFlow(FlowError): + """Exception to indicate a flow needs to be aborted.""" + + def __init__(self, reason: str, description_placeholders: Optional[Dict] = None): + """Initialize an abort flow exception.""" + super().__init__(f"Flow aborted: {reason}") + self.reason = reason + self.description_placeholders = description_placeholders + + class FlowManager: """Manage all the flows that are in progress.""" @@ -131,7 +141,12 @@ class FlowManager: ) ) - result: Dict = await getattr(flow, method)(user_input) + try: + result: Dict = await getattr(flow, method)(user_input) + except AbortFlow as err: + result = _create_abort_data( + flow.flow_id, flow.handler, err.reason, err.description_placeholders + ) if result["type"] not in ( RESULT_TYPE_FORM, @@ -228,13 +243,9 @@ class FlowHandler: self, *, reason: str, description_placeholders: Optional[Dict] = None ) -> Dict[str, Any]: """Abort the config flow.""" - return { - "type": RESULT_TYPE_ABORT, - "flow_id": self.flow_id, - "handler": self.handler, - "reason": reason, - "description_placeholders": description_placeholders, - } + return _create_abort_data( + self.flow_id, cast(str, self.handler), reason, description_placeholders + ) @callback def async_external_step( @@ -259,3 +270,20 @@ class FlowHandler: "handler": self.handler, "step_id": next_step_id, } + + +@callback +def _create_abort_data( + flow_id: str, + handler: str, + reason: str, + description_placeholders: Optional[Dict] = None, +) -> Dict[str, Any]: + """Return the definition of an external step for the user to take.""" + return { + "type": RESULT_TYPE_ABORT, + "flow_id": flow_id, + "handler": handler, + "reason": reason, + "description_placeholders": description_placeholders, + } diff --git a/tests/common.py b/tests/common.py index a54b3899698f..5d13da74e880 100644 --- a/tests/common.py +++ b/tests/common.py @@ -671,6 +671,7 @@ class MockConfigEntry(config_entries.ConfigEntry): options={}, system_options={}, connection_class=config_entries.CONN_CLASS_UNKNOWN, + unique_id=None, ): """Initialize a mock config entry.""" kwargs = { @@ -682,6 +683,7 @@ class MockConfigEntry(config_entries.ConfigEntry): "version": version, "title": title, "connection_class": connection_class, + "unique_id": unique_id, } if source is not None: kwargs["source"] = source diff --git a/tests/components/hue/test_config_flow.py b/tests/components/hue/test_config_flow.py index a6d221ef3234..030f6ade1fa7 100644 --- a/tests/components/hue/test_config_flow.py +++ b/tests/components/hue/test_config_flow.py @@ -19,6 +19,7 @@ async def test_flow_works(hass, aioclient_mock): flow = config_flow.HueFlowHandler() flow.hass = hass + flow.context = {} await flow.async_step_init() with patch("aiohue.Bridge") as mock_bridge: @@ -349,28 +350,33 @@ async def test_creating_entry_removes_entries_for_same_host_or_bridge(hass): accessible via a single IP. So when we create a new entry, we'll remove all existing entries that either have same IP or same bridge_id. """ - MockConfigEntry( - domain="hue", data={"host": "0.0.0.0", "bridge_id": "id-1234"} - ).add_to_hass(hass) + orig_entry = MockConfigEntry( + domain="hue", + data={"host": "0.0.0.0", "bridge_id": "id-1234"}, + unique_id="id-1234", + ) + orig_entry.add_to_hass(hass) MockConfigEntry( - domain="hue", data={"host": "1.2.3.4", "bridge_id": "id-1234"} + domain="hue", + data={"host": "1.2.3.4", "bridge_id": "id-5678"}, + unique_id="id-5678", ).add_to_hass(hass) assert len(hass.config_entries.async_entries("hue")) == 2 - flow = config_flow.HueFlowHandler() - flow.hass = hass - flow.context = {} - bridge = Mock() bridge.username = "username-abc" bridge.config.bridgeid = "id-1234" bridge.config.name = "Mock Bridge" bridge.host = "0.0.0.0" - with patch.object(config_flow, "get_bridge", return_value=mock_coro(bridge)): - result = await flow.async_step_import({"host": "0.0.0.0"}) + with patch.object( + config_flow, "_find_username_from_config", return_value="mock-user" + ), patch.object(config_flow, "get_bridge", return_value=mock_coro(bridge)): + result = await hass.config_entries.flow.async_init( + "hue", data={"host": "2.2.2.2"}, context={"source": "import"} + ) assert result["type"] == "create_entry" assert result["title"] == "Mock Bridge" @@ -379,9 +385,11 @@ async def test_creating_entry_removes_entries_for_same_host_or_bridge(hass): "bridge_id": "id-1234", "username": "username-abc", } - # We did not process the result of this entry but already removed the old - # ones. So we should have 0 entries. - assert len(hass.config_entries.async_entries("hue")) == 0 + entries = hass.config_entries.async_entries("hue") + assert len(entries) == 2 + new_entry = entries[-1] + assert orig_entry.entry_id != new_entry.entry_id + assert new_entry.unique_id == "id-1234" async def test_bridge_homekit(hass): @@ -398,6 +406,7 @@ async def test_bridge_homekit(hass): "host": "0.0.0.0", "serial": "1234", "manufacturerURL": config_flow.HUE_MANUFACTURERURL, + "properties": {"id": "aa:bb:cc:dd:ee:ff"}, } ) diff --git a/tests/components/hue/test_init.py b/tests/components/hue/test_init.py index 58f004ec5402..d064ff9f3406 100644 --- a/tests/components/hue/test_init.py +++ b/tests/components/hue/test_init.py @@ -175,3 +175,19 @@ async def test_unload_entry(hass): assert await hue.async_unload_entry(hass, entry) assert len(mock_bridge.return_value.async_reset.mock_calls) == 1 assert hass.data[hue.DOMAIN] == {} + + +async def test_setting_unique_id(hass): + """Test we set unique ID if not set yet.""" + entry = MockConfigEntry(domain=hue.DOMAIN, data={"host": "0.0.0.0"}) + entry.add_to_hass(hass) + + with patch.object(hue, "HueBridge") as mock_bridge, patch( + "homeassistant.helpers.device_registry.async_get_registry", + return_value=mock_coro(Mock()), + ): + mock_bridge.return_value.async_setup.return_value = mock_coro(True) + mock_bridge.return_value.api.config = Mock(bridgeid="mock-id") + assert await async_setup_component(hass, hue.DOMAIN, {}) is True + + assert entry.unique_id == "mock-id" diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 24a0b0939bef..a9ae4eb59ac2 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1001,3 +1001,110 @@ async def test_reload_entry_entity_registry_works(hass): await hass.async_block_till_done() assert len(mock_unload_entry.mock_calls) == 1 + + +async def test_unqiue_id_persisted(hass, manager): + """Test that a unique ID is stored in the config entry.""" + mock_setup_entry = MagicMock(return_value=mock_coro(True)) + + mock_integration(hass, MockModule("comp", async_setup_entry=mock_setup_entry)) + mock_entity_platform(hass, "config_flow.comp", None) + + class TestFlow(config_entries.ConfigFlow): + + VERSION = 1 + + async def async_step_user(self, user_input=None): + await self.async_set_unique_id("mock-unique-id") + return self.async_create_entry(title="mock-title", data={}) + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + await manager.flow.async_init( + "comp", context={"source": config_entries.SOURCE_USER} + ) + + assert len(mock_setup_entry.mock_calls) == 1 + p_hass, p_entry = mock_setup_entry.mock_calls[0][1] + + assert p_hass is hass + assert p_entry.unique_id == "mock-unique-id" + + +async def test_unique_id_existing_entry(hass, manager): + """Test that we remove an entry if there already is an entry with unique ID.""" + hass.config.components.add("comp") + MockConfigEntry( + domain="comp", + state=config_entries.ENTRY_STATE_LOADED, + unique_id="mock-unique-id", + ).add_to_hass(hass) + + async_setup_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True)) + async_unload_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True)) + async_remove_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True)) + + mock_integration( + hass, + MockModule( + "comp", + async_setup_entry=async_setup_entry, + async_unload_entry=async_unload_entry, + async_remove_entry=async_remove_entry, + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + + class TestFlow(config_entries.ConfigFlow): + + VERSION = 1 + + async def async_step_user(self, user_input=None): + existing_entry = await self.async_set_unique_id("mock-unique-id") + + assert existing_entry is not None + + return self.async_create_entry(title="mock-title", data={"via": "flow"}) + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + result = await manager.flow.async_init( + "comp", context={"source": config_entries.SOURCE_USER} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + + entries = hass.config_entries.async_entries("comp") + assert len(entries) == 1 + assert entries[0].data == {"via": "flow"} + + assert len(async_setup_entry.mock_calls) == 1 + assert len(async_unload_entry.mock_calls) == 1 + assert len(async_remove_entry.mock_calls) == 1 + + +async def test_unique_id_in_progress(hass, manager): + """Test that we abort if there is already a flow in progress with same unique id.""" + mock_integration(hass, MockModule("comp")) + mock_entity_platform(hass, "config_flow.comp", None) + + class TestFlow(config_entries.ConfigFlow): + + VERSION = 1 + + async def async_step_user(self, user_input=None): + await self.async_set_unique_id("mock-unique-id") + return self.async_show_form(step_id="discovery") + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + # Create one to be in progress + result = await manager.flow.async_init( + "comp", context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + + # Will be canceled + result2 = await manager.flow.async_init( + "comp", context={"source": config_entries.SOURCE_USER} + ) + + assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result2["reason"] == "already_in_progress" diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 175efebd7554..a6bdd2b5cb69 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -94,7 +94,7 @@ async def test_configure_two_steps(manager): async def test_show_form(manager): - """Test that abort removes the flow from progress.""" + """Test that we can show a form.""" schema = vol.Schema({vol.Required("username"): str, vol.Required("password"): str}) @manager.mock_reg_handler("test") @@ -271,3 +271,17 @@ async def test_external_step(hass, manager): result = await manager.async_configure(result["flow_id"]) assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["title"] == "Hello" + + +async def test_abort_flow_exception(manager): + """Test that the AbortFlow exception works.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + raise data_entry_flow.AbortFlow("mock-reason", {"placeholder": "yo"}) + + form = await manager.async_init("test") + assert form["type"] == "abort" + assert form["reason"] == "mock-reason" + assert form["description_placeholders"] == {"placeholder": "yo"}