From 09b65f14b96d53812d44eb02f6fe603f1f2c5ace Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 2 Jan 2024 04:28:58 -1000 Subject: [PATCH] Index entities by domain for entity services (#106759) --- homeassistant/components/isy994/services.py | 19 ++++++++--- homeassistant/helpers/entity_component.py | 24 ++++++-------- homeassistant/helpers/entity_platform.py | 13 +++++--- homeassistant/helpers/service.py | 35 +++++++++------------ tests/components/tts/test_init.py | 4 ++- tests/helpers/test_service.py | 28 ++++++++--------- 6 files changed, 63 insertions(+), 60 deletions(-) diff --git a/homeassistant/components/isy994/services.py b/homeassistant/components/isy994/services.py index 7d7696755cfb..fec6c141915f 100644 --- a/homeassistant/components/isy994/services.py +++ b/homeassistant/components/isy994/services.py @@ -14,6 +14,7 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant, ServiceCall, callback import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_platform import async_get_platforms from homeassistant.helpers.service import entity_service_call @@ -120,6 +121,14 @@ SERVICE_SEND_PROGRAM_COMMAND_SCHEMA = vol.All( ) +def async_get_entities(hass: HomeAssistant) -> dict[str, Entity]: + """Get entities for a domain.""" + entities: dict[str, Entity] = {} + for platform in async_get_platforms(hass, DOMAIN): + entities.update(platform.entities) + return entities + + @callback def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901 """Create and register services for the ISY integration.""" @@ -159,7 +168,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901 async def _async_send_raw_node_command(call: ServiceCall) -> None: await entity_service_call( - hass, async_get_platforms(hass, DOMAIN), "async_send_raw_node_command", call + hass, async_get_entities(hass), "async_send_raw_node_command", call ) hass.services.async_register( @@ -171,7 +180,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901 async def _async_send_node_command(call: ServiceCall) -> None: await entity_service_call( - hass, async_get_platforms(hass, DOMAIN), "async_send_node_command", call + hass, async_get_entities(hass), "async_send_node_command", call ) hass.services.async_register( @@ -183,7 +192,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901 async def _async_get_zwave_parameter(call: ServiceCall) -> None: await entity_service_call( - hass, async_get_platforms(hass, DOMAIN), "async_get_zwave_parameter", call + hass, async_get_entities(hass), "async_get_zwave_parameter", call ) hass.services.async_register( @@ -195,7 +204,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901 async def _async_set_zwave_parameter(call: ServiceCall) -> None: await entity_service_call( - hass, async_get_platforms(hass, DOMAIN), "async_set_zwave_parameter", call + hass, async_get_entities(hass), "async_set_zwave_parameter", call ) hass.services.async_register( @@ -207,7 +216,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901 async def _async_rename_node(call: ServiceCall) -> None: await entity_service_call( - hass, async_get_platforms(hass, DOMAIN), "async_rename_node", call + hass, async_get_entities(hass), "async_rename_node", call ) hass.services.async_register( diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 30e892a8840d..b3eb8722997f 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -89,12 +89,13 @@ class EntityComponent(Generic[_EntityT]): self.config: ConfigType | None = None + domain_platform = self._async_init_entity_platform(domain, None) self._platforms: dict[ str | tuple[str, timedelta | None, str | None], EntityPlatform - ] = {domain: self._async_init_entity_platform(domain, None)} - self.async_add_entities = self._platforms[domain].async_add_entities - self.add_entities = self._platforms[domain].add_entities - + ] = {domain: domain_platform} + self.async_add_entities = domain_platform.async_add_entities + self.add_entities = domain_platform.add_entities + self._entities: dict[str, entity.Entity] = domain_platform.domain_entities hass.data.setdefault(DATA_INSTANCES, {})[domain] = self @property @@ -105,18 +106,11 @@ class EntityComponent(Generic[_EntityT]): callers that iterate over this asynchronously should make a copy using list() before iterating. """ - return chain.from_iterable( - platform.entities.values() # type: ignore[misc] - for platform in self._platforms.values() - ) + return self._entities.values() # type: ignore[return-value] def get_entity(self, entity_id: str) -> _EntityT | None: """Get an entity.""" - for platform in self._platforms.values(): - entity_obj = platform.entities.get(entity_id) - if entity_obj is not None: - return entity_obj # type: ignore[return-value] - return None + return self._entities.get(entity_id) # type: ignore[return-value] def register_shutdown(self) -> None: """Register shutdown on Home Assistant STOP event. @@ -237,7 +231,7 @@ class EntityComponent(Generic[_EntityT]): """Handle the service.""" result = await service.entity_service_call( - self.hass, self._platforms.values(), func, call, required_features + self.hass, self._entities, func, call, required_features ) if result: @@ -270,7 +264,7 @@ class EntityComponent(Generic[_EntityT]): ) -> EntityServiceResponse | None: """Handle the service.""" return await service.entity_service_call( - self.hass, self._platforms.values(), func, call, required_features + self.hass, self._entities, func, call, required_features ) self.hass.services.async_register( diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 221203902c59..1bf7d95135ba 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -55,6 +55,7 @@ SLOW_ADD_MIN_TIMEOUT = 500 PLATFORM_NOT_READY_RETRIES = 10 DATA_ENTITY_PLATFORM = "entity_platform" +DATA_DOMAIN_ENTITIES = "domain_entities" PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds _LOGGER = getLogger(__name__) @@ -147,6 +148,10 @@ class EntityPlatform: self.platform_name, [] ).append(self) + self.domain_entities: dict[str, Entity] = hass.data.setdefault( + DATA_DOMAIN_ENTITIES, {} + ).setdefault(domain, {}) + def __repr__(self) -> str: """Represent an EntityPlatform.""" return ( @@ -734,6 +739,7 @@ class EntityPlatform: entity_id = entity.entity_id self.entities[entity_id] = entity + self.domain_entities[entity_id] = entity if not restored: # Reserve the state in the state machine @@ -746,6 +752,7 @@ class EntityPlatform: def remove_entity_cb() -> None: """Remove entity from entities dict.""" self.entities.pop(entity_id) + self.domain_entities.pop(entity_id) entity.async_on_remove(remove_entity_cb) @@ -830,11 +837,7 @@ class EntityPlatform: """Handle the service.""" return await service.entity_service_call( self.hass, - [ - plf - for plf in self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name] - if plf.domain == self.domain - ], + self.domain_entities, func, call, required_features, diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 9af69acc6b28..59fd061d8c90 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -58,7 +58,6 @@ from .typing import ConfigType, TemplateVarsType if TYPE_CHECKING: from .entity import Entity - from .entity_platform import EntityPlatform _EntityT = TypeVar("_EntityT", bound=Entity) @@ -741,7 +740,7 @@ def async_set_service_schema( def _get_permissible_entity_candidates( call: ServiceCall, - platforms: Iterable[EntityPlatform], + entities: dict[str, Entity], entity_perms: None | (Callable[[str, str], bool]), target_all_entities: bool, all_referenced: set[str] | None, @@ -754,9 +753,8 @@ def _get_permissible_entity_candidates( # is allowed to control. return [ entity - for platform in platforms - for entity in platform.entities.values() - if entity_perms(entity.entity_id, POLICY_CONTROL) + for entity_id, entity in entities.items() + if entity_perms(entity_id, POLICY_CONTROL) ] assert all_referenced is not None @@ -771,29 +769,26 @@ def _get_permissible_entity_candidates( ) elif target_all_entities: - return [ - entity for platform in platforms for entity in platform.entities.values() - ] + return list(entities.values()) # We have already validated they have permissions to control all_referenced # entities so we do not need to check again. - assert all_referenced is not None - if single_entity := len(all_referenced) == 1 and list(all_referenced)[0]: - for platform in platforms: - if (entity := platform.entities.get(single_entity)) is not None: - return [entity] + if TYPE_CHECKING: + assert all_referenced is not None + if ( + len(all_referenced) == 1 + and (single_entity := list(all_referenced)[0]) + and (entity := entities.get(single_entity)) is not None + ): + return [entity] - return [ - platform.entities[entity_id] - for platform in platforms - for entity_id in all_referenced.intersection(platform.entities) - ] + return [entities[entity_id] for entity_id in all_referenced.intersection(entities)] @bind_hass async def entity_service_call( hass: HomeAssistant, - platforms: Iterable[EntityPlatform], + registered_entities: dict[str, Entity], func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]], call: ServiceCall, required_features: Iterable[int] | None = None, @@ -832,7 +827,7 @@ async def entity_service_call( # A list with entities to call the service on. entity_candidates = _get_permissible_entity_candidates( call, - platforms, + registered_entities, entity_perms, target_all_entities, all_referenced, diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 990d8d273ed2..d56542b2a57b 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -1406,7 +1406,9 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None with patch.dict( hass.data[tts.DATA_TTS_MANAGER].providers, {}, clear=True - ), patch.dict(hass.data[tts.DOMAIN]._platforms, {}, clear=True): + ), patch.dict(hass.data[tts.DOMAIN]._platforms, {}, clear=True), patch.dict( + hass.data[tts.DOMAIN]._entities, {}, clear=True + ): assert tts.async_resolve_engine(hass, None) is None with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}): diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 04324cdbfa33..628ead473d75 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -802,7 +802,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) - test_service_mock = AsyncMock(return_value=None) await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, test_service_mock, ServiceCall("test_domain", "test_service", {"entity_id": "all"}), required_features=[SUPPORT_A], @@ -821,7 +821,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) - with pytest.raises(exceptions.HomeAssistantError): await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, test_service_mock, ServiceCall( "test_domain", "test_service", {"entity_id": "light.living_room"} @@ -838,7 +838,7 @@ async def test_call_with_both_required_features( test_service_mock = AsyncMock(return_value=None) await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, test_service_mock, ServiceCall("test_domain", "test_service", {"entity_id": "all"}), required_features=[SUPPORT_A | SUPPORT_B], @@ -857,7 +857,7 @@ async def test_call_with_one_of_required_features( test_service_mock = AsyncMock(return_value=None) await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, test_service_mock, ServiceCall("test_domain", "test_service", {"entity_id": "all"}), required_features=[SUPPORT_A, SUPPORT_C], @@ -878,7 +878,7 @@ async def test_call_with_sync_func(hass: HomeAssistant, mock_entities) -> None: test_service_mock = Mock(return_value=None) await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, test_service_mock, ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}), ) @@ -890,7 +890,7 @@ async def test_call_with_sync_attr(hass: HomeAssistant, mock_entities) -> None: mock_method = mock_entities["light.kitchen"].sync_method = Mock(return_value=None) await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, "sync_method", ServiceCall( "test_domain", @@ -908,7 +908,7 @@ async def test_call_context_user_not_exist(hass: HomeAssistant) -> None: with pytest.raises(exceptions.UnknownUser) as err: await service.entity_service_call( hass, - [], + {}, Mock(), ServiceCall( "test_domain", @@ -935,7 +935,7 @@ async def test_call_context_target_all( ): await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, Mock(), ServiceCall( "test_domain", @@ -963,7 +963,7 @@ async def test_call_context_target_specific( ): await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, Mock(), ServiceCall( "test_domain", @@ -987,7 +987,7 @@ async def test_call_context_target_specific_no_auth( ): await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, Mock(), ServiceCall( "test_domain", @@ -1007,7 +1007,7 @@ async def test_call_no_context_target_all( """Check we target all if no user context given.""" await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, Mock(), ServiceCall( "test_domain", "test_service", data={"entity_id": ENTITY_MATCH_ALL} @@ -1026,7 +1026,7 @@ async def test_call_no_context_target_specific( """Check we can target specified entities.""" await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, Mock(), ServiceCall( "test_domain", @@ -1048,7 +1048,7 @@ async def test_call_with_match_all( """Check we only target allowed entities if targeting all.""" await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, Mock(), ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ) @@ -1065,7 +1065,7 @@ async def test_call_with_omit_entity_id( """Check service call if we do not pass an entity ID.""" await service.entity_service_call( hass, - [Mock(entities=mock_entities)], + mock_entities, Mock(), ServiceCall("test_domain", "test_service"), )