Prioritize entity names over area names in Assist matching (#86982)

* Refactor async_match_states

* Check entity name after state, before aliases

* Give entity name matches priority over area names

* Don't force result to have area

* Add area alias in tests

* Move name/area list creation back

* Clean up PR

* More clean up
This commit is contained in:
Michael Hansen 2023-01-30 22:46:25 -06:00 committed by GitHub
parent f8c6e4c20a
commit be69c81db5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 148 additions and 37 deletions

View File

@ -11,7 +11,7 @@ import re
from typing import IO, Any
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
from hassil.recognize import recognize
from hassil.recognize import RecognizeResult, recognize_all
from hassil.util import merge_dict
from home_assistant_intents import get_intents
import yaml
@ -128,7 +128,10 @@ class DefaultAgent(AbstractConversationAgent):
}
result = await self.hass.async_add_executor_job(
recognize, user_input.text, lang_intents.intents, slot_lists
self._recognize,
user_input,
lang_intents,
slot_lists,
)
if result is None:
_LOGGER.debug("No intent was matched for '%s'", user_input.text)
@ -197,6 +200,26 @@ class DefaultAgent(AbstractConversationAgent):
response=intent_response, conversation_id=conversation_id
)
def _recognize(
self,
user_input: ConversationInput,
lang_intents: LanguageIntents,
slot_lists: dict[str, SlotList],
) -> RecognizeResult | None:
"""Search intents for a match to user input."""
# Prioritize matches with entity names above area names
maybe_result: RecognizeResult | None = None
for result in recognize_all(
user_input.text, lang_intents.intents, slot_lists=slot_lists
):
if "name" in result.entities:
return result
# Keep looking in case an entity has the same name
maybe_result = result
return maybe_result
async def async_reload(self, language: str | None = None):
"""Clear cached intents for a language."""
if language is None:
@ -373,19 +396,19 @@ class DefaultAgent(AbstractConversationAgent):
if self._names_list is not None:
return self._names_list
states = self.hass.states.async_all()
registry = entity_registry.async_get(self.hass)
entities = entity_registry.async_get(self.hass)
names = []
for state in states:
context = {"domain": state.domain}
entry = registry.async_get(state.entity_id)
if entry is not None:
if entry.entity_category:
entity = entities.async_get(state.entity_id)
if entity is not None:
if entity.entity_category:
# Skip configuration/diagnostic entities
continue
if entry.aliases:
for alias in entry.aliases:
if entity.aliases:
for alias in entity.aliases:
names.append((alias, state.entity_id, context))
# Default name

View File

@ -138,15 +138,62 @@ def _has_name(
if name in (state.entity_id, state.name.casefold()):
return True
# Check aliases
if (entity is not None) and entity.aliases:
for alias in entity.aliases:
if name == alias.casefold():
return True
# Check name/aliases
if (entity is None) or (not entity.aliases):
return False
for alias in entity.aliases:
if name == alias.casefold():
return True
return False
def _find_area(
id_or_name: str, areas: area_registry.AreaRegistry
) -> area_registry.AreaEntry | None:
"""Find an area by id or name, checking aliases too."""
area = areas.async_get_area(id_or_name) or areas.async_get_area_by_name(id_or_name)
if area is not None:
return area
# Check area aliases
for maybe_area in areas.areas.values():
if not maybe_area.aliases:
continue
for area_alias in maybe_area.aliases:
if id_or_name == area_alias.casefold():
return maybe_area
return None
def _filter_by_area(
states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]],
area: area_registry.AreaEntry,
devices: device_registry.DeviceRegistry,
) -> Iterable[tuple[State, entity_registry.RegistryEntry | None]]:
"""Filter state/entity pairs by an area."""
entity_area_ids: dict[str, str | None] = {}
for _state, entity in states_and_entities:
if entity is None:
continue
if entity.area_id:
# Use entity's area id first
entity_area_ids[entity.id] = entity.area_id
elif entity.device_id:
# Fall back to device area if not set on entity
device = devices.async_get(entity.device_id)
if device is not None:
entity_area_ids[entity.id] = device.area_id
for state, entity in states_and_entities:
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id):
yield (state, entity)
@callback
@bind_hass
def async_match_states(
@ -200,45 +247,29 @@ def async_match_states(
if areas is None:
areas = area_registry.async_get(hass)
# id or name
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
area_name
)
area = _find_area(area_name, areas)
assert area is not None, f"No area named {area_name}"
if area is not None:
# Filter by states/entities by area
if devices is None:
devices = device_registry.async_get(hass)
entity_area_ids: dict[str, str | None] = {}
for _state, entity in states_and_entities:
if entity is None:
continue
if entity.area_id:
# Use entity's area id first
entity_area_ids[entity.id] = entity.area_id
elif entity.device_id:
# Fall back to device area if not set on entity
device = devices.async_get(entity.device_id)
if device is not None:
entity_area_ids[entity.id] = device.area_id
# Filter by area
states_and_entities = [
(state, entity)
for state, entity in states_and_entities
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id)
]
states_and_entities = list(_filter_by_area(states_and_entities, area, devices))
if name is not None:
if devices is None:
devices = device_registry.async_get(hass)
# Filter by name
name = name.casefold()
# Check states
for state, entity in states_and_entities:
if _has_name(state, entity, name):
yield state
break
else:
# Not filtered by name
for state, _entity in states_and_entities:

View File

@ -6,6 +6,7 @@ import pytest
from homeassistant.components import conversation
from homeassistant.components.cover import SERVICE_OPEN_COVER
from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
from homeassistant.helpers import (
area_registry,
@ -777,3 +778,51 @@ async def test_turn_on_area(hass, init_components):
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": "light.stove"}
async def test_light_area_same_name(hass, init_components):
"""Test turning on a light with the same name as an area."""
entities = entity_registry.async_get(hass)
devices = device_registry.async_get(hass)
areas = area_registry.async_get(hass)
entry = MockConfigEntry(domain="test")
device = devices.async_get_or_create(
config_entry_id=entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
kitchen_area = areas.async_create("kitchen")
devices.async_update_device(device.id, area_id=kitchen_area.id)
kitchen_light = entities.async_get_or_create(
"light", "demo", "1234", original_name="kitchen light"
)
entities.async_update_entity(kitchen_light.entity_id, area_id=kitchen_area.id)
hass.states.async_set(
kitchen_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
)
ceiling_light = entities.async_get_or_create(
"light", "demo", "5678", original_name="ceiling light"
)
entities.async_update_entity(ceiling_light.entity_id, area_id=kitchen_area.id)
hass.states.async_set(
ceiling_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "ceiling light"}
)
calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")
await hass.services.async_call(
"conversation",
"process",
{conversation.ATTR_TEXT: "turn on kitchen light"},
)
await hass.async_block_till_done()
# Should only turn on one light instead of all lights in the kitchen
assert len(calls) == 1
call = calls[0]
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": kitchen_light.entity_id}

View File

@ -27,6 +27,7 @@ async def test_async_match_states(hass):
"""Test async_match_state helper."""
areas = area_registry.async_get(hass)
area_kitchen = areas.async_get_or_create("kitchen")
areas.async_update(area_kitchen.id, aliases={"food room"})
area_bedroom = areas.async_get_or_create("bedroom")
state1 = State(
@ -68,6 +69,13 @@ async def test_async_match_states(hass):
)
)
# Test area alias
assert [state1] == list(
intent.async_match_states(
hass, name="kitchen light", area_name="food room", states=[state1, state2]
)
)
# Wrong area
assert not list(
intent.async_match_states(