1
mirror of https://github.com/home-assistant/core synced 2024-07-15 09:42:11 +02:00

Index entities by domain for entity services (#106759)

This commit is contained in:
J. Nick Koston 2024-01-02 04:28:58 -10:00 committed by GitHub
parent bf0d891f68
commit 09b65f14b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 60 deletions

View File

@ -14,6 +14,7 @@ from homeassistant.const import (
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_platform import async_get_platforms
from homeassistant.helpers.service import entity_service_call
@ -120,6 +121,14 @@ SERVICE_SEND_PROGRAM_COMMAND_SCHEMA = vol.All(
)
def async_get_entities(hass: HomeAssistant) -> dict[str, Entity]:
"""Get entities for a domain."""
entities: dict[str, Entity] = {}
for platform in async_get_platforms(hass, DOMAIN):
entities.update(platform.entities)
return entities
@callback
def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
"""Create and register services for the ISY integration."""
@ -159,7 +168,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_send_raw_node_command(call: ServiceCall) -> None:
await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_send_raw_node_command", call
hass, async_get_entities(hass), "async_send_raw_node_command", call
)
hass.services.async_register(
@ -171,7 +180,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_send_node_command(call: ServiceCall) -> None:
await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_send_node_command", call
hass, async_get_entities(hass), "async_send_node_command", call
)
hass.services.async_register(
@ -183,7 +192,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_get_zwave_parameter(call: ServiceCall) -> None:
await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_get_zwave_parameter", call
hass, async_get_entities(hass), "async_get_zwave_parameter", call
)
hass.services.async_register(
@ -195,7 +204,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_set_zwave_parameter(call: ServiceCall) -> None:
await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_set_zwave_parameter", call
hass, async_get_entities(hass), "async_set_zwave_parameter", call
)
hass.services.async_register(
@ -207,7 +216,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_rename_node(call: ServiceCall) -> None:
await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_rename_node", call
hass, async_get_entities(hass), "async_rename_node", call
)
hass.services.async_register(

View File

@ -89,12 +89,13 @@ class EntityComponent(Generic[_EntityT]):
self.config: ConfigType | None = None
domain_platform = self._async_init_entity_platform(domain, None)
self._platforms: dict[
str | tuple[str, timedelta | None, str | None], EntityPlatform
] = {domain: self._async_init_entity_platform(domain, None)}
self.async_add_entities = self._platforms[domain].async_add_entities
self.add_entities = self._platforms[domain].add_entities
] = {domain: domain_platform}
self.async_add_entities = domain_platform.async_add_entities
self.add_entities = domain_platform.add_entities
self._entities: dict[str, entity.Entity] = domain_platform.domain_entities
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
@property
@ -105,18 +106,11 @@ class EntityComponent(Generic[_EntityT]):
callers that iterate over this asynchronously should make a copy
using list() before iterating.
"""
return chain.from_iterable(
platform.entities.values() # type: ignore[misc]
for platform in self._platforms.values()
)
return self._entities.values() # type: ignore[return-value]
def get_entity(self, entity_id: str) -> _EntityT | None:
"""Get an entity."""
for platform in self._platforms.values():
entity_obj = platform.entities.get(entity_id)
if entity_obj is not None:
return entity_obj # type: ignore[return-value]
return None
return self._entities.get(entity_id) # type: ignore[return-value]
def register_shutdown(self) -> None:
"""Register shutdown on Home Assistant STOP event.
@ -237,7 +231,7 @@ class EntityComponent(Generic[_EntityT]):
"""Handle the service."""
result = await service.entity_service_call(
self.hass, self._platforms.values(), func, call, required_features
self.hass, self._entities, func, call, required_features
)
if result:
@ -270,7 +264,7 @@ class EntityComponent(Generic[_EntityT]):
) -> EntityServiceResponse | None:
"""Handle the service."""
return await service.entity_service_call(
self.hass, self._platforms.values(), func, call, required_features
self.hass, self._entities, func, call, required_features
)
self.hass.services.async_register(

View File

@ -55,6 +55,7 @@ SLOW_ADD_MIN_TIMEOUT = 500
PLATFORM_NOT_READY_RETRIES = 10
DATA_ENTITY_PLATFORM = "entity_platform"
DATA_DOMAIN_ENTITIES = "domain_entities"
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
_LOGGER = getLogger(__name__)
@ -147,6 +148,10 @@ class EntityPlatform:
self.platform_name, []
).append(self)
self.domain_entities: dict[str, Entity] = hass.data.setdefault(
DATA_DOMAIN_ENTITIES, {}
).setdefault(domain, {})
def __repr__(self) -> str:
"""Represent an EntityPlatform."""
return (
@ -734,6 +739,7 @@ class EntityPlatform:
entity_id = entity.entity_id
self.entities[entity_id] = entity
self.domain_entities[entity_id] = entity
if not restored:
# Reserve the state in the state machine
@ -746,6 +752,7 @@ class EntityPlatform:
def remove_entity_cb() -> None:
"""Remove entity from entities dict."""
self.entities.pop(entity_id)
self.domain_entities.pop(entity_id)
entity.async_on_remove(remove_entity_cb)
@ -830,11 +837,7 @@ class EntityPlatform:
"""Handle the service."""
return await service.entity_service_call(
self.hass,
[
plf
for plf in self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name]
if plf.domain == self.domain
],
self.domain_entities,
func,
call,
required_features,

View File

@ -58,7 +58,6 @@ from .typing import ConfigType, TemplateVarsType
if TYPE_CHECKING:
from .entity import Entity
from .entity_platform import EntityPlatform
_EntityT = TypeVar("_EntityT", bound=Entity)
@ -741,7 +740,7 @@ def async_set_service_schema(
def _get_permissible_entity_candidates(
call: ServiceCall,
platforms: Iterable[EntityPlatform],
entities: dict[str, Entity],
entity_perms: None | (Callable[[str, str], bool]),
target_all_entities: bool,
all_referenced: set[str] | None,
@ -754,9 +753,8 @@ def _get_permissible_entity_candidates(
# is allowed to control.
return [
entity
for platform in platforms
for entity in platform.entities.values()
if entity_perms(entity.entity_id, POLICY_CONTROL)
for entity_id, entity in entities.items()
if entity_perms(entity_id, POLICY_CONTROL)
]
assert all_referenced is not None
@ -771,29 +769,26 @@ def _get_permissible_entity_candidates(
)
elif target_all_entities:
return [
entity for platform in platforms for entity in platform.entities.values()
]
return list(entities.values())
# We have already validated they have permissions to control all_referenced
# entities so we do not need to check again.
assert all_referenced is not None
if single_entity := len(all_referenced) == 1 and list(all_referenced)[0]:
for platform in platforms:
if (entity := platform.entities.get(single_entity)) is not None:
return [entity]
if TYPE_CHECKING:
assert all_referenced is not None
if (
len(all_referenced) == 1
and (single_entity := list(all_referenced)[0])
and (entity := entities.get(single_entity)) is not None
):
return [entity]
return [
platform.entities[entity_id]
for platform in platforms
for entity_id in all_referenced.intersection(platform.entities)
]
return [entities[entity_id] for entity_id in all_referenced.intersection(entities)]
@bind_hass
async def entity_service_call(
hass: HomeAssistant,
platforms: Iterable[EntityPlatform],
registered_entities: dict[str, Entity],
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
call: ServiceCall,
required_features: Iterable[int] | None = None,
@ -832,7 +827,7 @@ async def entity_service_call(
# A list with entities to call the service on.
entity_candidates = _get_permissible_entity_candidates(
call,
platforms,
registered_entities,
entity_perms,
target_all_entities,
all_referenced,

View File

@ -1406,7 +1406,9 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
with patch.dict(
hass.data[tts.DATA_TTS_MANAGER].providers, {}, clear=True
), patch.dict(hass.data[tts.DOMAIN]._platforms, {}, clear=True):
), patch.dict(hass.data[tts.DOMAIN]._platforms, {}, clear=True), patch.dict(
hass.data[tts.DOMAIN]._entities, {}, clear=True
):
assert tts.async_resolve_engine(hass, None) is None
with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}):

View File

@ -802,7 +802,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
test_service_mock,
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A],
@ -821,7 +821,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
with pytest.raises(exceptions.HomeAssistantError):
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
test_service_mock,
ServiceCall(
"test_domain", "test_service", {"entity_id": "light.living_room"}
@ -838,7 +838,7 @@ async def test_call_with_both_required_features(
test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
test_service_mock,
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A | SUPPORT_B],
@ -857,7 +857,7 @@ async def test_call_with_one_of_required_features(
test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
test_service_mock,
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A, SUPPORT_C],
@ -878,7 +878,7 @@ async def test_call_with_sync_func(hass: HomeAssistant, mock_entities) -> None:
test_service_mock = Mock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
test_service_mock,
ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
)
@ -890,7 +890,7 @@ async def test_call_with_sync_attr(hass: HomeAssistant, mock_entities) -> None:
mock_method = mock_entities["light.kitchen"].sync_method = Mock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
"sync_method",
ServiceCall(
"test_domain",
@ -908,7 +908,7 @@ async def test_call_context_user_not_exist(hass: HomeAssistant) -> None:
with pytest.raises(exceptions.UnknownUser) as err:
await service.entity_service_call(
hass,
[],
{},
Mock(),
ServiceCall(
"test_domain",
@ -935,7 +935,7 @@ async def test_call_context_target_all(
):
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
Mock(),
ServiceCall(
"test_domain",
@ -963,7 +963,7 @@ async def test_call_context_target_specific(
):
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
Mock(),
ServiceCall(
"test_domain",
@ -987,7 +987,7 @@ async def test_call_context_target_specific_no_auth(
):
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
Mock(),
ServiceCall(
"test_domain",
@ -1007,7 +1007,7 @@ async def test_call_no_context_target_all(
"""Check we target all if no user context given."""
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
Mock(),
ServiceCall(
"test_domain", "test_service", data={"entity_id": ENTITY_MATCH_ALL}
@ -1026,7 +1026,7 @@ async def test_call_no_context_target_specific(
"""Check we can target specified entities."""
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
Mock(),
ServiceCall(
"test_domain",
@ -1048,7 +1048,7 @@ async def test_call_with_match_all(
"""Check we only target allowed entities if targeting all."""
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
Mock(),
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
)
@ -1065,7 +1065,7 @@ async def test_call_with_omit_entity_id(
"""Check service call if we do not pass an entity ID."""
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
mock_entities,
Mock(),
ServiceCall("test_domain", "test_service"),
)