1
mirror of https://github.com/home-assistant/core synced 2024-07-30 21:18:57 +02:00

Warn when referencing missing devices/areas (#43787)

This commit is contained in:
Paulus Schoutsen 2020-12-01 08:01:27 +01:00 committed by GitHub
parent cf9598fe4f
commit cf5be049b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 162 additions and 55 deletions

View File

@ -1,5 +1,6 @@
"""Service calling related helpers."""
import asyncio
import dataclasses
from functools import partial, wraps
import logging
from typing import (
@ -37,8 +38,13 @@ from homeassistant.exceptions import (
Unauthorized,
UnknownUser,
)
from homeassistant.helpers import device_registry, entity_registry, template
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers import (
area_registry,
config_validation as cv,
device_registry,
entity_registry,
template,
)
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType
from homeassistant.loader import (
MAX_LOAD_CONCURRENTLY,
@ -64,6 +70,38 @@ _LOGGER = logging.getLogger(__name__)
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
@dataclasses.dataclass
class SelectedEntities:
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: Set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: Set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: Set[str] = dataclasses.field(default_factory=set)
missing_areas: Set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: Set[str]) -> None:
"""Log about missing items."""
parts = []
for label, items in (
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
_LOGGER.warning("Unable to find referenced %s", ", ".join(parts))
@bind_hass
def call_from_config(
hass: HomeAssistantType,
@ -186,25 +224,25 @@ async def async_extract_entities(
if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in entities if entity.available]
entity_ids = await async_extract_entity_ids(hass, service_call, expand_group)
referenced = await async_extract_referenced_entity_ids(
hass, service_call, expand_group
)
combined = referenced.referenced | referenced.indirectly_referenced
found = []
for entity in entities:
if entity.entity_id not in entity_ids:
if entity.entity_id not in combined:
continue
entity_ids.remove(entity.entity_id)
combined.remove(entity.entity_id)
if not entity.available:
continue
found.append(entity)
if entity_ids:
_LOGGER.warning(
"Unable to find referenced entities %s", ", ".join(sorted(entity_ids))
)
referenced.log_missing(referenced.referenced & combined)
return found
@ -213,10 +251,21 @@ async def async_extract_entities(
async def async_extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]:
"""Extract a list of entity ids from a service call.
"""Extract a set of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
"""
referenced = await async_extract_referenced_entity_ids(
hass, service_call, expand_group
)
return referenced.referenced | referenced.indirectly_referenced
@bind_hass
async def async_extract_referenced_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a service call."""
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
device_ids = service_call.data.get(ATTR_DEVICE_ID)
area_ids = service_call.data.get(ATTR_AREA_ID)
@ -225,12 +274,14 @@ async def async_extract_entity_ids(
selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE)
selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE)
extracted: Set[str] = set()
selected = SelectedEntities()
if not selects_entity_ids and not selects_device_ids and not selects_area_ids:
return extracted
return selected
if selects_entity_ids:
assert entity_ids is not None
# Entity ID attr can be a list or a string
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
@ -238,58 +289,68 @@ async def async_extract_entity_ids(
if expand_group:
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
extracted.update(entity_ids)
selected.referenced.update(entity_ids)
if not selects_device_ids and not selects_area_ids:
return extracted
return selected
dev_reg, ent_reg = cast(
Tuple[device_registry.DeviceRegistry, entity_registry.EntityRegistry],
area_reg, dev_reg, ent_reg = cast(
Tuple[
area_registry.AreaRegistry,
device_registry.DeviceRegistry,
entity_registry.EntityRegistry,
],
await asyncio.gather(
area_registry.async_get_registry(hass),
device_registry.async_get_registry(hass),
entity_registry.async_get_registry(hass),
),
)
if not selects_device_ids:
picked_devices = set()
elif isinstance(device_ids, str):
picked_devices = {device_ids}
else:
assert isinstance(device_ids, list)
picked_devices = set(device_ids)
picked_devices = set()
if selects_device_ids:
if isinstance(device_ids, str):
picked_devices = {device_ids}
else:
assert isinstance(device_ids, list)
picked_devices = set(device_ids)
for device_id in picked_devices:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selects_area_ids:
if isinstance(area_ids, str):
area_ids = [area_ids]
assert area_ids is not None
assert isinstance(area_ids, list)
if isinstance(area_ids, str):
area_lookup = {area_ids}
else:
area_lookup = set(area_ids)
for area_id in area_lookup:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
continue
# Find entities tied to an area
extracted.update(
entry.entity_id
for area_id in area_ids
for entry in entity_registry.async_entries_for_area(ent_reg, area_id)
)
for entity_entry in ent_reg.entities.values():
if entity_entry.area_id in area_lookup:
selected.indirectly_referenced.add(entity_entry.entity_id)
picked_devices.update(
[
device.id
for area_id in area_ids
for device in device_registry.async_entries_for_area(dev_reg, area_id)
]
)
# Find devices for this area
for device_entry in dev_reg.devices.values():
if device_entry.area_id in area_lookup:
picked_devices.add(device_entry.id)
if not picked_devices:
return extracted
return selected
extracted.update(
entity_entry.entity_id
for entity_entry in ent_reg.entities.values()
if not entity_entry.area_id and entity_entry.device_id in picked_devices
)
for entity_entry in ent_reg.entities.values():
if not entity_entry.area_id and entity_entry.device_id in picked_devices:
selected.indirectly_referenced.add(entity_entry.entity_id)
return extracted
return selected
def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE:
@ -416,9 +477,13 @@ async def entity_service_call(
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
if not target_all_entities:
if target_all_entities:
referenced: Optional[SelectedEntities] = None
all_referenced: Optional[Set[str]] = None
else:
# A set of entities we're trying to target.
entity_ids = await async_extract_entity_ids(hass, call, True)
referenced = await async_extract_referenced_entity_ids(hass, call, True)
all_referenced = referenced.referenced | referenced.indirectly_referenced
# If the service function is a string, we'll pass it the service call data
if isinstance(func, str):
@ -441,11 +506,12 @@ async def entity_service_call(
if target_all_entities:
entity_candidates.extend(platform.entities.values())
else:
assert all_referenced is not None
entity_candidates.extend(
[
entity
for entity in platform.entities.values()
if entity.entity_id in entity_ids
if entity.entity_id in all_referenced
]
)
@ -462,11 +528,13 @@ async def entity_service_call(
)
else:
assert all_referenced is not None
for platform in platforms:
platform_entities = []
for entity in platform.entities.values():
if entity.entity_id not in entity_ids:
if entity.entity_id not in all_referenced:
continue
if not entity_perms(entity.entity_id, POLICY_CONTROL):
@ -481,13 +549,15 @@ async def entity_service_call(
entity_candidates.extend(platform_entities)
if not target_all_entities:
for entity in entity_candidates:
entity_ids.remove(entity.entity_id)
assert referenced is not None
if entity_ids:
_LOGGER.warning(
"Unable to find referenced entities %s", ", ".join(sorted(entity_ids))
)
# Only report on explicit referenced entities
missing = set(referenced.referenced)
for entity in entity_candidates:
missing.discard(entity.entity_id)
referenced.log_missing(missing)
entities = []

View File

@ -960,3 +960,40 @@ async def test_extract_from_service_area_id(hass, area_mock):
"light.in_area",
"light.no_area",
]
async def test_entity_service_call_warn_referenced(hass, caplog):
"""Test we only warn for referenced entities in entity_service_call."""
call = ha.ServiceCall(
"light",
"turn_on",
{
"area_id": "non-existent-area",
"entity_id": "non.existent",
"device_id": "non-existent-device",
},
)
await service.entity_service_call(hass, {}, "", call)
assert (
"Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent"
in caplog.text
)
async def test_async_extract_entities_warn_referenced(hass, caplog):
"""Test we only warn for referenced entities in async_extract_entities."""
call = ha.ServiceCall(
"light",
"turn_on",
{
"area_id": "non-existent-area",
"entity_id": "non.existent",
"device_id": "non-existent-device",
},
)
extracted = await service.async_extract_entities(hass, {}, call)
assert len(extracted) == 0
assert (
"Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent"
in caplog.text
)