From be69c81db520a32d45d9021d757f81c2bdf6a7eb Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 30 Jan 2023 22:46:25 -0600 Subject: [PATCH] Prioritize entity names over area names in Assist matching (#86982) * Refactor async_match_states * Check entity name after state, before aliases * Give entity name matches priority over area names * Don't force result to have area * Add area alias in tests * Move name/area list creation back * Clean up PR * More clean up --- .../components/conversation/default_agent.py | 39 ++++++-- homeassistant/helpers/intent.py | 89 +++++++++++++------ tests/components/conversation/test_init.py | 49 ++++++++++ tests/helpers/test_intent.py | 8 ++ 4 files changed, 148 insertions(+), 37 deletions(-) diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index c897d2e3b87..cabf9089b1c 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -11,7 +11,7 @@ import re from typing import IO, Any from hassil.intents import Intents, ResponseType, SlotList, TextSlotList -from hassil.recognize import recognize +from hassil.recognize import RecognizeResult, recognize_all from hassil.util import merge_dict from home_assistant_intents import get_intents import yaml @@ -128,7 +128,10 @@ class DefaultAgent(AbstractConversationAgent): } result = await self.hass.async_add_executor_job( - recognize, user_input.text, lang_intents.intents, slot_lists + self._recognize, + user_input, + lang_intents, + slot_lists, ) if result is None: _LOGGER.debug("No intent was matched for '%s'", user_input.text) @@ -197,6 +200,26 @@ class DefaultAgent(AbstractConversationAgent): response=intent_response, conversation_id=conversation_id ) + def _recognize( + self, + user_input: ConversationInput, + lang_intents: LanguageIntents, + slot_lists: dict[str, SlotList], + ) -> RecognizeResult | None: + """Search intents for a match to user input.""" + # Prioritize matches with entity names above area names + maybe_result: RecognizeResult | None = None + for result in recognize_all( + user_input.text, lang_intents.intents, slot_lists=slot_lists + ): + if "name" in result.entities: + return result + + # Keep looking in case an entity has the same name + maybe_result = result + + return maybe_result + async def async_reload(self, language: str | None = None): """Clear cached intents for a language.""" if language is None: @@ -373,19 +396,19 @@ class DefaultAgent(AbstractConversationAgent): if self._names_list is not None: return self._names_list states = self.hass.states.async_all() - registry = entity_registry.async_get(self.hass) + entities = entity_registry.async_get(self.hass) names = [] for state in states: context = {"domain": state.domain} - entry = registry.async_get(state.entity_id) - if entry is not None: - if entry.entity_category: + entity = entities.async_get(state.entity_id) + if entity is not None: + if entity.entity_category: # Skip configuration/diagnostic entities continue - if entry.aliases: - for alias in entry.aliases: + if entity.aliases: + for alias in entity.aliases: names.append((alias, state.entity_id, context)) # Default name diff --git a/homeassistant/helpers/intent.py b/homeassistant/helpers/intent.py index 511c2b2c009..58252da4822 100644 --- a/homeassistant/helpers/intent.py +++ b/homeassistant/helpers/intent.py @@ -138,15 +138,62 @@ def _has_name( if name in (state.entity_id, state.name.casefold()): return True - # Check aliases - if (entity is not None) and entity.aliases: - for alias in entity.aliases: - if name == alias.casefold(): - return True + # Check name/aliases + if (entity is None) or (not entity.aliases): + return False + + for alias in entity.aliases: + if name == alias.casefold(): + return True return False +def _find_area( + id_or_name: str, areas: area_registry.AreaRegistry +) -> area_registry.AreaEntry | None: + """Find an area by id or name, checking aliases too.""" + area = areas.async_get_area(id_or_name) or areas.async_get_area_by_name(id_or_name) + if area is not None: + return area + + # Check area aliases + for maybe_area in areas.areas.values(): + if not maybe_area.aliases: + continue + + for area_alias in maybe_area.aliases: + if id_or_name == area_alias.casefold(): + return maybe_area + + return None + + +def _filter_by_area( + states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]], + area: area_registry.AreaEntry, + devices: device_registry.DeviceRegistry, +) -> Iterable[tuple[State, entity_registry.RegistryEntry | None]]: + """Filter state/entity pairs by an area.""" + entity_area_ids: dict[str, str | None] = {} + for _state, entity in states_and_entities: + if entity is None: + continue + + if entity.area_id: + # Use entity's area id first + entity_area_ids[entity.id] = entity.area_id + elif entity.device_id: + # Fall back to device area if not set on entity + device = devices.async_get(entity.device_id) + if device is not None: + entity_area_ids[entity.id] = device.area_id + + for state, entity in states_and_entities: + if (entity is not None) and (entity_area_ids.get(entity.id) == area.id): + yield (state, entity) + + @callback @bind_hass def async_match_states( @@ -200,45 +247,29 @@ def async_match_states( if areas is None: areas = area_registry.async_get(hass) - # id or name - area = areas.async_get_area(area_name) or areas.async_get_area_by_name( - area_name - ) + area = _find_area(area_name, areas) assert area is not None, f"No area named {area_name}" if area is not None: + # Filter by states/entities by area if devices is None: devices = device_registry.async_get(hass) - entity_area_ids: dict[str, str | None] = {} - for _state, entity in states_and_entities: - if entity is None: - continue - - if entity.area_id: - # Use entity's area id first - entity_area_ids[entity.id] = entity.area_id - elif entity.device_id: - # Fall back to device area if not set on entity - device = devices.async_get(entity.device_id) - if device is not None: - entity_area_ids[entity.id] = device.area_id - - # Filter by area - states_and_entities = [ - (state, entity) - for state, entity in states_and_entities - if (entity is not None) and (entity_area_ids.get(entity.id) == area.id) - ] + states_and_entities = list(_filter_by_area(states_and_entities, area, devices)) if name is not None: + if devices is None: + devices = device_registry.async_get(hass) + # Filter by name name = name.casefold() + # Check states for state, entity in states_and_entities: if _has_name(state, entity, name): yield state break + else: # Not filtered by name for state, _entity in states_and_entities: diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index f4b386cbe4b..54fed8a6139 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -6,6 +6,7 @@ import pytest from homeassistant.components import conversation from homeassistant.components.cover import SERVICE_OPEN_COVER +from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.core import DOMAIN as HASS_DOMAIN, Context from homeassistant.helpers import ( area_registry, @@ -777,3 +778,51 @@ async def test_turn_on_area(hass, init_components): assert call.domain == HASS_DOMAIN assert call.service == "turn_on" assert call.data == {"entity_id": "light.stove"} + + +async def test_light_area_same_name(hass, init_components): + """Test turning on a light with the same name as an area.""" + entities = entity_registry.async_get(hass) + devices = device_registry.async_get(hass) + areas = area_registry.async_get(hass) + entry = MockConfigEntry(domain="test") + + device = devices.async_get_or_create( + config_entry_id=entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + + kitchen_area = areas.async_create("kitchen") + devices.async_update_device(device.id, area_id=kitchen_area.id) + + kitchen_light = entities.async_get_or_create( + "light", "demo", "1234", original_name="kitchen light" + ) + entities.async_update_entity(kitchen_light.entity_id, area_id=kitchen_area.id) + hass.states.async_set( + kitchen_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "kitchen light"} + ) + + ceiling_light = entities.async_get_or_create( + "light", "demo", "5678", original_name="ceiling light" + ) + entities.async_update_entity(ceiling_light.entity_id, area_id=kitchen_area.id) + hass.states.async_set( + ceiling_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "ceiling light"} + ) + + calls = async_mock_service(hass, HASS_DOMAIN, "turn_on") + + await hass.services.async_call( + "conversation", + "process", + {conversation.ATTR_TEXT: "turn on kitchen light"}, + ) + await hass.async_block_till_done() + + # Should only turn on one light instead of all lights in the kitchen + assert len(calls) == 1 + call = calls[0] + assert call.domain == HASS_DOMAIN + assert call.service == "turn_on" + assert call.data == {"entity_id": kitchen_light.entity_id} diff --git a/tests/helpers/test_intent.py b/tests/helpers/test_intent.py index 11a54b3b529..14ada0b967d 100644 --- a/tests/helpers/test_intent.py +++ b/tests/helpers/test_intent.py @@ -27,6 +27,7 @@ async def test_async_match_states(hass): """Test async_match_state helper.""" areas = area_registry.async_get(hass) area_kitchen = areas.async_get_or_create("kitchen") + areas.async_update(area_kitchen.id, aliases={"food room"}) area_bedroom = areas.async_get_or_create("bedroom") state1 = State( @@ -68,6 +69,13 @@ async def test_async_match_states(hass): ) ) + # Test area alias + assert [state1] == list( + intent.async_match_states( + hass, name="kitchen light", area_name="food room", states=[state1, state2] + ) + ) + # Wrong area assert not list( intent.async_match_states(