1
mirror of https://github.com/home-assistant/core synced 2024-07-30 21:18:57 +02:00

Add area id to entity registry (#42221)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Robert Svensson 2020-10-24 21:25:28 +02:00 committed by GitHub
parent b54dde10ca
commit e06c8009e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 107 additions and 5 deletions

View File

@ -71,6 +71,7 @@ async def websocket_get_entity(hass, connection, msg):
# If passed in, we update value. Passing None will remove old value.
vol.Optional("name"): vol.Any(str, None),
vol.Optional("icon"): vol.Any(str, None),
vol.Optional("area_id"): vol.Any(str, None),
vol.Optional("new_entity_id"): str,
# We only allow setting disabled_by user via API.
vol.Optional("disabled_by"): vol.Any("user", None),
@ -91,7 +92,7 @@ async def websocket_update_entity(hass, connection, msg):
changes = {}
for key in ("name", "icon", "disabled_by"):
for key in ("name", "icon", "area_id", "disabled_by"):
if key in msg:
changes[key] = msg[key]
@ -149,6 +150,7 @@ def _entry_dict(entry):
return {
"config_entry_id": entry.config_entry_id,
"device_id": entry.device_id,
"area_id": entry.area_id,
"disabled_by": entry.disabled_by,
"entity_id": entry.entity_id,
"name": entry.name,

View File

@ -122,6 +122,10 @@ class Searcher:
"""Resolve an area."""
for device in device_registry.async_entries_for_area(self._device_reg, area_id):
self._add_or_resolve("device", device.id)
for entity_entry in entity_registry.async_entries_for_area(
self._entity_reg, area_id
):
self._add_or_resolve("entity", entity_entry.entity_id)
@callback
def _resolve_device(self, device_id) -> None:

View File

@ -1,5 +1,5 @@
"""Provide a way to connect devices to one physical location."""
from asyncio import Event
from asyncio import Event, gather
from collections import OrderedDict
from typing import Dict, Iterable, List, MutableMapping, Optional, cast
@ -64,8 +64,12 @@ class AreaRegistry:
async def async_delete(self, area_id: str) -> None:
"""Delete area."""
device_registry = await self.hass.helpers.device_registry.async_get_registry()
device_registry, entity_registry = await gather(
self.hass.helpers.device_registry.async_get_registry(),
self.hass.helpers.entity_registry.async_get_registry(),
)
device_registry.async_clear_area_id(area_id)
entity_registry.async_clear_area_id(area_id)
del self.areas[area_id]

View File

@ -83,6 +83,7 @@ class RegistryEntry:
name: Optional[str] = attr.ib(default=None)
icon: Optional[str] = attr.ib(default=None)
device_id: Optional[str] = attr.ib(default=None)
area_id: Optional[str] = attr.ib(default=None)
config_entry_id: Optional[str] = attr.ib(default=None)
disabled_by: Optional[str] = attr.ib(
default=None,
@ -204,6 +205,7 @@ class EntityRegistry:
# Data that we want entry to have
config_entry: Optional["ConfigEntry"] = None,
device_id: Optional[str] = None,
area_id: Optional[str] = None,
capabilities: Optional[Dict[str, Any]] = None,
supported_features: Optional[int] = None,
device_class: Optional[str] = None,
@ -223,6 +225,7 @@ class EntityRegistry:
entity_id,
config_entry_id=config_entry_id or _UNDEF,
device_id=device_id or _UNDEF,
area_id=area_id or _UNDEF,
capabilities=capabilities or _UNDEF,
supported_features=supported_features or _UNDEF,
device_class=device_class or _UNDEF,
@ -253,6 +256,7 @@ class EntityRegistry:
entity_id=entity_id,
config_entry_id=config_entry_id,
device_id=device_id,
area_id=area_id,
unique_id=unique_id,
platform=platform,
disabled_by=disabled_by,
@ -302,6 +306,7 @@ class EntityRegistry:
*,
name=_UNDEF,
icon=_UNDEF,
area_id=_UNDEF,
new_entity_id=_UNDEF,
new_unique_id=_UNDEF,
disabled_by=_UNDEF,
@ -313,6 +318,7 @@ class EntityRegistry:
entity_id,
name=name,
icon=icon,
area_id=area_id,
new_entity_id=new_entity_id,
new_unique_id=new_unique_id,
disabled_by=disabled_by,
@ -329,6 +335,7 @@ class EntityRegistry:
config_entry_id=_UNDEF,
new_entity_id=_UNDEF,
device_id=_UNDEF,
area_id=_UNDEF,
new_unique_id=_UNDEF,
disabled_by=_UNDEF,
capabilities=_UNDEF,
@ -348,6 +355,7 @@ class EntityRegistry:
("icon", icon),
("config_entry_id", config_entry_id),
("device_id", device_id),
("area_id", area_id),
("disabled_by", disabled_by),
("capabilities", capabilities),
("supported_features", supported_features),
@ -425,6 +433,7 @@ class EntityRegistry:
entity_id=entity["entity_id"],
config_entry_id=entity.get("config_entry_id"),
device_id=entity.get("device_id"),
area_id=entity.get("area_id"),
unique_id=entity["unique_id"],
platform=entity["platform"],
name=entity.get("name"),
@ -456,6 +465,7 @@ class EntityRegistry:
"entity_id": entry.entity_id,
"config_entry_id": entry.config_entry_id,
"device_id": entry.device_id,
"area_id": entry.area_id,
"unique_id": entry.unique_id,
"platform": entry.platform,
"name": entry.name,
@ -483,6 +493,13 @@ class EntityRegistry:
]:
self.async_remove(entity_id)
@callback
def async_clear_area_id(self, area_id: str) -> None:
"""Clear area id from registry entries."""
for entity_id, entry in self.entities.items():
if area_id == entry.area_id:
self._async_update_entity(entity_id, area_id=None) # type: ignore
def _register_entry(self, entry: RegistryEntry) -> None:
self.entities[entry.entity_id] = entry
self._add_index(entry)
@ -521,6 +538,14 @@ def async_entries_for_device(
]
@callback
def async_entries_for_area(
registry: EntityRegistry, area_id: str
) -> List[RegistryEntry]:
"""Return entries that match an area."""
return [entry for entry in registry.entities.values() if entry.area_id == area_id]
@callback
def async_entries_for_config_entry(
registry: EntityRegistry, config_entry_id: str

View File

@ -234,6 +234,15 @@ async def async_extract_entity_ids(
hass.helpers.device_registry.async_get_registry(),
hass.helpers.entity_registry.async_get_registry(),
)
extracted.update(
entry.entity_id
for area_id in area_ids
for entry in hass.helpers.entity_registry.async_entries_for_area(
ent_reg, area_id
)
)
devices = [
device
for area_id in area_ids
@ -247,6 +256,7 @@ async def async_extract_entity_ids(
for entry in hass.helpers.entity_registry.async_entries_for_device(
ent_reg, device.id
)
if not entry.area_id
)
return extracted

View File

@ -39,6 +39,7 @@ async def test_list_entities(hass, client):
{
"config_entry_id": None,
"device_id": None,
"area_id": None,
"disabled_by": None,
"entity_id": "test_domain.name",
"name": "Hello World",
@ -48,6 +49,7 @@ async def test_list_entities(hass, client):
{
"config_entry_id": None,
"device_id": None,
"area_id": None,
"disabled_by": None,
"entity_id": "test_domain.no_name",
"name": None,
@ -84,6 +86,7 @@ async def test_get_entity(hass, client):
assert msg["result"] == {
"config_entry_id": None,
"device_id": None,
"area_id": None,
"disabled_by": None,
"platform": "test_platform",
"entity_id": "test_domain.name",
@ -107,6 +110,7 @@ async def test_get_entity(hass, client):
assert msg["result"] == {
"config_entry_id": None,
"device_id": None,
"area_id": None,
"disabled_by": None,
"platform": "test_platform",
"entity_id": "test_domain.no_name",
@ -143,7 +147,7 @@ async def test_update_entity(hass, client):
assert state.name == "before update"
assert state.attributes[ATTR_ICON] == "icon:before update"
# UPDATE NAME & ICON
# UPDATE NAME & ICON & AREA
await client.send_json(
{
"id": 6,
@ -151,6 +155,7 @@ async def test_update_entity(hass, client):
"entity_id": "test_domain.world",
"name": "after update",
"icon": "icon:after update",
"area_id": "mock-area-id",
}
)
@ -159,6 +164,7 @@ async def test_update_entity(hass, client):
assert msg["result"] == {
"config_entry_id": None,
"device_id": None,
"area_id": "mock-area-id",
"disabled_by": None,
"platform": "test_platform",
"entity_id": "test_domain.world",
@ -204,6 +210,7 @@ async def test_update_entity(hass, client):
assert msg["result"] == {
"config_entry_id": None,
"device_id": None,
"area_id": "mock-area-id",
"disabled_by": None,
"platform": "test_platform",
"entity_id": "test_domain.world",
@ -252,6 +259,7 @@ async def test_update_entity_no_changes(hass, client):
assert msg["result"] == {
"config_entry_id": None,
"device_id": None,
"area_id": None,
"disabled_by": None,
"platform": "test_platform",
"entity_id": "test_domain.world",
@ -329,6 +337,7 @@ async def test_update_entity_id(hass, client):
assert msg["result"] == {
"config_entry_id": None,
"device_id": None,
"area_id": None,
"disabled_by": None,
"platform": "test_platform",
"entity_id": "test_domain.planet",

View File

@ -154,6 +154,7 @@ async def test_loading_saving_data(hass, registry):
"hue",
"5678",
device_id="mock-dev-id",
area_id="mock-area-id",
config_entry=mock_config,
capabilities={"max": 100},
supported_features=5,
@ -182,6 +183,7 @@ async def test_loading_saving_data(hass, registry):
assert orig_entry2 == new_entry2
assert new_entry2.device_id == "mock-dev-id"
assert new_entry2.area_id == "mock-area-id"
assert new_entry2.disabled_by == entity_registry.DISABLED_HASS
assert new_entry2.capabilities == {"max": 100}
assert new_entry2.supported_features == 5
@ -330,6 +332,19 @@ async def test_removing_config_entry_id(hass, registry, update_events):
assert update_events[1]["entity_id"] == entry.entity_id
async def test_removing_area_id(registry):
"""Make sure we can clear area id."""
entry = registry.async_get_or_create("light", "hue", "5678")
entry_w_area = registry.async_update_entity(entry.entity_id, area_id="12345A")
registry.async_clear_area_id("12345A")
entry_wo_area = registry.async_get(entry.entity_id)
assert not entry_wo_area.area_id
assert entry_w_area != entry_wo_area
async def test_migration(hass):
"""Test migration from old data to new."""
mock_config = MockConfigEntry(domain="test-platform", entry_id="test-config-id")

View File

@ -105,12 +105,32 @@ def area_mock(hass):
},
)
entity_in_own_area = ent_reg.RegistryEntry(
entity_id="light.in_own_area",
unique_id="in-own-area-id",
platform="test",
area_id="own-area",
)
entity_in_area = ent_reg.RegistryEntry(
entity_id="light.in_area",
unique_id="in-area-id",
platform="test",
device_id=device_in_area.id,
)
entity_in_other_area = ent_reg.RegistryEntry(
entity_id="light.in_other_area",
unique_id="in-other-area-id",
platform="test",
device_id=device_in_area.id,
area_id="other-area",
)
entity_assigned_to_area = ent_reg.RegistryEntry(
entity_id="light.assigned_to_area",
unique_id="assigned-area-id",
platform="test",
device_id=device_in_area.id,
area_id="test-area",
)
entity_no_area = ent_reg.RegistryEntry(
entity_id="light.no_area",
unique_id="no-area-id",
@ -126,7 +146,10 @@ def area_mock(hass):
mock_registry(
hass,
{
entity_in_own_area.entity_id: entity_in_own_area,
entity_in_area.entity_id: entity_in_area,
entity_in_other_area.entity_id: entity_in_other_area,
entity_assigned_to_area.entity_id: entity_assigned_to_area,
entity_no_area.entity_id: entity_no_area,
entity_diff_area.entity_id: entity_diff_area,
},
@ -298,15 +321,25 @@ async def test_extract_entity_ids(hass):
async def test_extract_entity_ids_from_area(hass, area_mock):
"""Test extract_entity_ids method with areas."""
call = ha.ServiceCall("light", "turn_on", {"area_id": "own-area"})
assert {
"light.in_own_area",
} == await service.async_extract_entity_ids(hass, call)
call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"})
assert {"light.in_area"} == await service.async_extract_entity_ids(hass, call)
assert {
"light.in_area",
"light.assigned_to_area",
} == await service.async_extract_entity_ids(hass, call)
call = ha.ServiceCall("light", "turn_on", {"area_id": ["test-area", "diff-area"]})
assert {
"light.in_area",
"light.diff_area",
"light.assigned_to_area",
} == await service.async_extract_entity_ids(hass, call)
assert (