1
mirror of https://github.com/home-assistant/core synced 2024-09-28 03:04:04 +02:00

Add unique ID to config entries (#29806)

* Add unique ID to config entries

* Unload existing entries with same unique ID if flow with unique ID is
finished

* Remove unused exception

* Fix typing

* silence pylint

* Fix tests

* Add unique ID to Hue

* Address typing comment

* Tweaks to comments

* lint
This commit is contained in:
Paulus Schoutsen 2019-12-16 12:27:43 +01:00 committed by GitHub
parent 87ca61ddd7
commit d851cb6f9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 305 additions and 46 deletions

View File

@ -4,11 +4,11 @@ import logging
import voluptuous as vol
from homeassistant import config_entries
from homeassistant import config_entries, core
from homeassistant.const import CONF_FILENAME, CONF_HOST
from homeassistant.helpers import config_validation as cv, device_registry as dr
from .bridge import HueBridge
from .bridge import HueBridge, normalize_bridge_id
from .config_flow import ( # Loading the config flow file will register the flow
configured_hosts,
)
@ -102,7 +102,9 @@ async def async_setup(hass, config):
return True
async def async_setup_entry(hass, entry):
async def async_setup_entry(
hass: core.HomeAssistant, entry: config_entries.ConfigEntry
):
"""Set up a bridge from a config entry."""
host = entry.data["host"]
config = hass.data[DATA_CONFIGS].get(host)
@ -121,6 +123,13 @@ async def async_setup_entry(hass, entry):
hass.data[DOMAIN][host] = bridge
config = bridge.api.config
# For backwards compat
if entry.unique_id is None:
hass.config_entries.async_update_entry(
entry, unique_id=normalize_bridge_id(config.bridgeid)
)
device_registry = await dr.async_get_registry(hass)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,

View File

@ -201,3 +201,25 @@ async def get_bridge(hass, host, username=None):
except aiohue.AiohueException:
LOGGER.exception("Unknown Hue linking error occurred")
raise AuthenticationRequired
def normalize_bridge_id(bridge_id: str):
"""Normalize a bridge identifier.
There are three sources where we receive bridge ID from:
- ssdp/upnp: <host>/description.xml, field root/device/serialNumber
- nupnp: "id" field
- Hue Bridge API: config.bridgeid
The SSDP/UPNP source does not contain the middle 4 characters compared
to the other sources. In all our tests the middle 4 characters are "fffe".
"""
if len(bridge_id) == 16:
return bridge_id[0:6] + bridge_id[-6:]
if len(bridge_id) == 12:
return bridge_id
LOGGER.warning("Unexpected bridge id number found: %s", bridge_id)
return bridge_id

View File

@ -12,7 +12,7 @@ from homeassistant.components.ssdp import ATTR_MANUFACTURERURL, ATTR_NAME
from homeassistant.core import callback
from homeassistant.helpers import aiohttp_client
from .bridge import get_bridge
from .bridge import get_bridge, normalize_bridge_id
from .const import DOMAIN, LOGGER
from .errors import AuthenticationRequired, CannotConnect
@ -154,17 +154,15 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if host in configured_hosts(self.hass):
return self.async_abort(reason="already_configured")
# This value is based off host/description.xml and is, weirdly, missing
# 4 characters in the middle of the serial compared to results returned
# from the NUPNP API or when querying the bridge API for bridgeid.
# (on first gen Hue hub)
serial = discovery_info.get("serial")
bridge_id = discovery_info.get("serial")
await self.async_set_unique_id(normalize_bridge_id(bridge_id))
return await self.async_step_import(
{
"host": host,
# This format is the legacy format that Hue used for discovery
"path": f"phue-{serial}.conf",
"path": f"phue-{bridge_id}.conf",
}
)
@ -180,6 +178,10 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if host in configured_hosts(self.hass):
return self.async_abort(reason="already_configured")
await self.async_set_unique_id(
normalize_bridge_id(homekit_info["properties"]["id"].replace(":", ""))
)
return await self.async_step_import({"host": host})
async def async_step_import(self, import_info):
@ -234,18 +236,9 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
host = bridge.host
bridge_id = bridge.config.bridgeid
same_hub_entries = [
entry.entry_id
for entry in self.hass.config_entries.async_entries(DOMAIN)
if entry.data["bridge_id"] == bridge_id or entry.data["host"] == host
]
if same_hub_entries:
await asyncio.wait(
[
self.hass.config_entries.async_remove(entry_id)
for entry_id in same_hub_entries
]
if self.unique_id is None:
await self.async_set_unique_id(
normalize_bridge_id(bridge_id), raise_on_progress=False
)
return self.async_create_entry(

View File

@ -2,7 +2,7 @@
import asyncio
import functools
import logging
from typing import Any, Callable, Dict, List, Optional, Set, cast
from typing import Any, Callable, Dict, List, Optional, Set, Union, cast
import uuid
import weakref
@ -75,6 +75,10 @@ class OperationNotAllowed(ConfigError):
"""Raised when a config entry operation is not allowed."""
class UniqueIdInProgress(data_entry_flow.AbortFlow):
"""Error to indicate that the unique Id is in progress."""
class ConfigEntry:
"""Hold a configuration entry."""
@ -85,6 +89,7 @@ class ConfigEntry:
"title",
"data",
"options",
"unique_id",
"system_options",
"source",
"connection_class",
@ -104,6 +109,7 @@ class ConfigEntry:
connection_class: str,
system_options: dict,
options: Optional[dict] = None,
unique_id: Optional[str] = None,
entry_id: Optional[str] = None,
state: str = ENTRY_STATE_NOT_LOADED,
) -> None:
@ -138,6 +144,9 @@ class ConfigEntry:
# State of the entry (LOADED, NOT_LOADED)
self.state = state
# Unique ID of this entry.
self.unique_id = unique_id
# Listeners to call on update
self.update_listeners: List = []
@ -533,11 +542,15 @@ class ConfigEntries:
self,
entry: ConfigEntry,
*,
unique_id: Union[str, dict, None] = _UNDEF,
data: dict = _UNDEF,
options: dict = _UNDEF,
system_options: dict = _UNDEF,
) -> None:
"""Update a config entry."""
if unique_id is not _UNDEF:
entry.unique_id = cast(Optional[str], unique_id)
if data is not _UNDEF:
entry.data = data
@ -602,6 +615,25 @@ class ConfigEntries:
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result
# Check if config entry exists with unique ID. Unload it.
existing_entry = None
unique_id = flow.context.get("unique_id")
if unique_id is not None:
for check_entry in self.async_entries(result["handler"]):
if check_entry.unique_id == unique_id:
existing_entry = check_entry
break
# Unload the entry before setting up the new one.
# We will remove it only after the other one is set up,
# so that device customizations are not getting lost.
if (
existing_entry is not None
and existing_entry.state not in UNRECOVERABLE_STATES
):
await self.async_unload(existing_entry.entry_id)
entry = ConfigEntry(
version=result["version"],
domain=result["handler"],
@ -611,12 +643,16 @@ class ConfigEntries:
system_options={},
source=flow.context["source"],
connection_class=flow.CONNECTION_CLASS,
unique_id=unique_id,
)
self._entries.append(entry)
self._async_schedule_save()
await self.async_setup(entry.entry_id)
if existing_entry is not None:
await self.async_remove(existing_entry.entry_id)
result["result"] = entry
return result
@ -687,6 +723,8 @@ async def _old_conf_migrator(old_config: Dict[str, Any]) -> Dict[str, Any]:
class ConfigFlow(data_entry_flow.FlowHandler):
"""Base class for config flows with some helpers."""
unique_id = None
def __init_subclass__(cls, domain: Optional[str] = None, **kwargs: Any) -> None:
"""Initialize a subclass, register if possible."""
super().__init_subclass__(**kwargs) # type: ignore
@ -701,6 +739,27 @@ class ConfigFlow(data_entry_flow.FlowHandler):
"""Get the options flow for this handler."""
raise data_entry_flow.UnknownHandler
async def async_set_unique_id(
self, unique_id: str, *, raise_on_progress: bool = True
) -> Optional[ConfigEntry]:
"""Set a unique ID for the config flow.
Returns optionally existing config entry with same ID.
"""
if raise_on_progress:
for progress in self._async_in_progress():
if progress["context"].get("unique_id") == unique_id:
raise UniqueIdInProgress("already_in_progress")
# pylint: disable=no-member
self.context["unique_id"] = unique_id
for entry in self._async_current_entries():
if entry.unique_id == unique_id:
return entry
return None
@callback
def _async_current_entries(self) -> List[ConfigEntry]:
"""Return current entries."""

View File

@ -1,6 +1,6 @@
"""Classes to help gather user submissions."""
import logging
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, cast
import uuid
import voluptuous as vol
@ -36,6 +36,16 @@ class UnknownStep(FlowError):
"""Unknown step specified."""
class AbortFlow(FlowError):
"""Exception to indicate a flow needs to be aborted."""
def __init__(self, reason: str, description_placeholders: Optional[Dict] = None):
"""Initialize an abort flow exception."""
super().__init__(f"Flow aborted: {reason}")
self.reason = reason
self.description_placeholders = description_placeholders
class FlowManager:
"""Manage all the flows that are in progress."""
@ -131,7 +141,12 @@ class FlowManager:
)
)
result: Dict = await getattr(flow, method)(user_input)
try:
result: Dict = await getattr(flow, method)(user_input)
except AbortFlow as err:
result = _create_abort_data(
flow.flow_id, flow.handler, err.reason, err.description_placeholders
)
if result["type"] not in (
RESULT_TYPE_FORM,
@ -228,13 +243,9 @@ class FlowHandler:
self, *, reason: str, description_placeholders: Optional[Dict] = None
) -> Dict[str, Any]:
"""Abort the config flow."""
return {
"type": RESULT_TYPE_ABORT,
"flow_id": self.flow_id,
"handler": self.handler,
"reason": reason,
"description_placeholders": description_placeholders,
}
return _create_abort_data(
self.flow_id, cast(str, self.handler), reason, description_placeholders
)
@callback
def async_external_step(
@ -259,3 +270,20 @@ class FlowHandler:
"handler": self.handler,
"step_id": next_step_id,
}
@callback
def _create_abort_data(
flow_id: str,
handler: str,
reason: str,
description_placeholders: Optional[Dict] = None,
) -> Dict[str, Any]:
"""Return the definition of an external step for the user to take."""
return {
"type": RESULT_TYPE_ABORT,
"flow_id": flow_id,
"handler": handler,
"reason": reason,
"description_placeholders": description_placeholders,
}

View File

@ -671,6 +671,7 @@ class MockConfigEntry(config_entries.ConfigEntry):
options={},
system_options={},
connection_class=config_entries.CONN_CLASS_UNKNOWN,
unique_id=None,
):
"""Initialize a mock config entry."""
kwargs = {
@ -682,6 +683,7 @@ class MockConfigEntry(config_entries.ConfigEntry):
"version": version,
"title": title,
"connection_class": connection_class,
"unique_id": unique_id,
}
if source is not None:
kwargs["source"] = source

View File

@ -19,6 +19,7 @@ async def test_flow_works(hass, aioclient_mock):
flow = config_flow.HueFlowHandler()
flow.hass = hass
flow.context = {}
await flow.async_step_init()
with patch("aiohue.Bridge") as mock_bridge:
@ -349,28 +350,33 @@ async def test_creating_entry_removes_entries_for_same_host_or_bridge(hass):
accessible via a single IP. So when we create a new entry, we'll remove
all existing entries that either have same IP or same bridge_id.
"""
MockConfigEntry(
domain="hue", data={"host": "0.0.0.0", "bridge_id": "id-1234"}
).add_to_hass(hass)
orig_entry = MockConfigEntry(
domain="hue",
data={"host": "0.0.0.0", "bridge_id": "id-1234"},
unique_id="id-1234",
)
orig_entry.add_to_hass(hass)
MockConfigEntry(
domain="hue", data={"host": "1.2.3.4", "bridge_id": "id-1234"}
domain="hue",
data={"host": "1.2.3.4", "bridge_id": "id-5678"},
unique_id="id-5678",
).add_to_hass(hass)
assert len(hass.config_entries.async_entries("hue")) == 2
flow = config_flow.HueFlowHandler()
flow.hass = hass
flow.context = {}
bridge = Mock()
bridge.username = "username-abc"
bridge.config.bridgeid = "id-1234"
bridge.config.name = "Mock Bridge"
bridge.host = "0.0.0.0"
with patch.object(config_flow, "get_bridge", return_value=mock_coro(bridge)):
result = await flow.async_step_import({"host": "0.0.0.0"})
with patch.object(
config_flow, "_find_username_from_config", return_value="mock-user"
), patch.object(config_flow, "get_bridge", return_value=mock_coro(bridge)):
result = await hass.config_entries.flow.async_init(
"hue", data={"host": "2.2.2.2"}, context={"source": "import"}
)
assert result["type"] == "create_entry"
assert result["title"] == "Mock Bridge"
@ -379,9 +385,11 @@ async def test_creating_entry_removes_entries_for_same_host_or_bridge(hass):
"bridge_id": "id-1234",
"username": "username-abc",
}
# We did not process the result of this entry but already removed the old
# ones. So we should have 0 entries.
assert len(hass.config_entries.async_entries("hue")) == 0
entries = hass.config_entries.async_entries("hue")
assert len(entries) == 2
new_entry = entries[-1]
assert orig_entry.entry_id != new_entry.entry_id
assert new_entry.unique_id == "id-1234"
async def test_bridge_homekit(hass):
@ -398,6 +406,7 @@ async def test_bridge_homekit(hass):
"host": "0.0.0.0",
"serial": "1234",
"manufacturerURL": config_flow.HUE_MANUFACTURERURL,
"properties": {"id": "aa:bb:cc:dd:ee:ff"},
}
)

View File

@ -175,3 +175,19 @@ async def test_unload_entry(hass):
assert await hue.async_unload_entry(hass, entry)
assert len(mock_bridge.return_value.async_reset.mock_calls) == 1
assert hass.data[hue.DOMAIN] == {}
async def test_setting_unique_id(hass):
"""Test we set unique ID if not set yet."""
entry = MockConfigEntry(domain=hue.DOMAIN, data={"host": "0.0.0.0"})
entry.add_to_hass(hass)
with patch.object(hue, "HueBridge") as mock_bridge, patch(
"homeassistant.helpers.device_registry.async_get_registry",
return_value=mock_coro(Mock()),
):
mock_bridge.return_value.async_setup.return_value = mock_coro(True)
mock_bridge.return_value.api.config = Mock(bridgeid="mock-id")
assert await async_setup_component(hass, hue.DOMAIN, {}) is True
assert entry.unique_id == "mock-id"

View File

@ -1001,3 +1001,110 @@ async def test_reload_entry_entity_registry_works(hass):
await hass.async_block_till_done()
assert len(mock_unload_entry.mock_calls) == 1
async def test_unqiue_id_persisted(hass, manager):
"""Test that a unique ID is stored in the config entry."""
mock_setup_entry = MagicMock(return_value=mock_coro(True))
mock_integration(hass, MockModule("comp", async_setup_entry=mock_setup_entry))
mock_entity_platform(hass, "config_flow.comp", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 1
async def async_step_user(self, user_input=None):
await self.async_set_unique_id("mock-unique-id")
return self.async_create_entry(title="mock-title", data={})
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
)
assert len(mock_setup_entry.mock_calls) == 1
p_hass, p_entry = mock_setup_entry.mock_calls[0][1]
assert p_hass is hass
assert p_entry.unique_id == "mock-unique-id"
async def test_unique_id_existing_entry(hass, manager):
"""Test that we remove an entry if there already is an entry with unique ID."""
hass.config.components.add("comp")
MockConfigEntry(
domain="comp",
state=config_entries.ENTRY_STATE_LOADED,
unique_id="mock-unique-id",
).add_to_hass(hass)
async_setup_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True))
async_unload_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True))
async_remove_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True))
mock_integration(
hass,
MockModule(
"comp",
async_setup_entry=async_setup_entry,
async_unload_entry=async_unload_entry,
async_remove_entry=async_remove_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 1
async def async_step_user(self, user_input=None):
existing_entry = await self.async_set_unique_id("mock-unique-id")
assert existing_entry is not None
return self.async_create_entry(title="mock-title", data={"via": "flow"})
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
result = await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
entries = hass.config_entries.async_entries("comp")
assert len(entries) == 1
assert entries[0].data == {"via": "flow"}
assert len(async_setup_entry.mock_calls) == 1
assert len(async_unload_entry.mock_calls) == 1
assert len(async_remove_entry.mock_calls) == 1
async def test_unique_id_in_progress(hass, manager):
"""Test that we abort if there is already a flow in progress with same unique id."""
mock_integration(hass, MockModule("comp"))
mock_entity_platform(hass, "config_flow.comp", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 1
async def async_step_user(self, user_input=None):
await self.async_set_unique_id("mock-unique-id")
return self.async_show_form(step_id="discovery")
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
# Create one to be in progress
result = await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
# Will be canceled
result2 = await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result2["reason"] == "already_in_progress"

View File

@ -94,7 +94,7 @@ async def test_configure_two_steps(manager):
async def test_show_form(manager):
"""Test that abort removes the flow from progress."""
"""Test that we can show a form."""
schema = vol.Schema({vol.Required("username"): str, vol.Required("password"): str})
@manager.mock_reg_handler("test")
@ -271,3 +271,17 @@ async def test_external_step(hass, manager):
result = await manager.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == "Hello"
async def test_abort_flow_exception(manager):
"""Test that the AbortFlow exception works."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
raise data_entry_flow.AbortFlow("mock-reason", {"placeholder": "yo"})
form = await manager.async_init("test")
assert form["type"] == "abort"
assert form["reason"] == "mock-reason"
assert form["description_placeholders"] == {"placeholder": "yo"}