mirror of https://github.com/home-assistant/core
Add wildcards to sentence triggers (#97236)
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
af286a8feb
commit
7e3fdd85fc
|
@ -322,7 +322,11 @@ async def websocket_hass_agent_debug(
|
|||
"intent": {
|
||||
"name": result.intent.name,
|
||||
},
|
||||
"entities": {
|
||||
"slots": { # direct access to values
|
||||
entity_key: entity.value
|
||||
for entity_key, entity in result.entities.items()
|
||||
},
|
||||
"details": {
|
||||
entity_key: {
|
||||
"name": entity.name,
|
||||
"value": entity.value,
|
||||
|
|
|
@ -11,7 +11,14 @@ from pathlib import Path
|
|||
import re
|
||||
from typing import IO, Any
|
||||
|
||||
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
|
||||
from hassil.expression import Expression, ListReference, Sequence
|
||||
from hassil.intents import (
|
||||
Intents,
|
||||
ResponseType,
|
||||
SlotList,
|
||||
TextSlotList,
|
||||
WildcardSlotList,
|
||||
)
|
||||
from hassil.recognize import RecognizeResult, recognize_all
|
||||
from hassil.util import merge_dict
|
||||
from home_assistant_intents import get_domains_and_languages, get_intents
|
||||
|
@ -48,7 +55,7 @@ _ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
|
|||
|
||||
REGEX_TYPE = type(re.compile(""))
|
||||
TRIGGER_CALLBACK_TYPE = Callable[ # pylint: disable=invalid-name
|
||||
[str], Awaitable[str | None]
|
||||
[str, RecognizeResult], Awaitable[str | None]
|
||||
]
|
||||
|
||||
|
||||
|
@ -657,6 +664,17 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
}
|
||||
|
||||
self._trigger_intents = Intents.from_dict(intents_dict)
|
||||
|
||||
# Assume slot list references are wildcards
|
||||
wildcard_names: set[str] = set()
|
||||
for trigger_intent in self._trigger_intents.intents.values():
|
||||
for intent_data in trigger_intent.data:
|
||||
for sentence in intent_data.sentences:
|
||||
_collect_list_references(sentence, wildcard_names)
|
||||
|
||||
for wildcard_name in wildcard_names:
|
||||
self._trigger_intents.slot_lists[wildcard_name] = WildcardSlotList()
|
||||
|
||||
_LOGGER.debug("Rebuilt trigger intents: %s", intents_dict)
|
||||
|
||||
def _unregister_trigger(self, trigger_data: TriggerData) -> None:
|
||||
|
@ -682,14 +700,14 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
|
||||
assert self._trigger_intents is not None
|
||||
|
||||
matched_triggers: set[int] = set()
|
||||
matched_triggers: dict[int, RecognizeResult] = {}
|
||||
for result in recognize_all(sentence, self._trigger_intents):
|
||||
trigger_id = int(result.intent.name)
|
||||
if trigger_id in matched_triggers:
|
||||
# Already matched a sentence from this trigger
|
||||
break
|
||||
|
||||
matched_triggers.add(trigger_id)
|
||||
matched_triggers[trigger_id] = result
|
||||
|
||||
if not matched_triggers:
|
||||
# Sentence did not match any trigger sentences
|
||||
|
@ -699,14 +717,14 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
"'%s' matched %s trigger(s): %s",
|
||||
sentence,
|
||||
len(matched_triggers),
|
||||
matched_triggers,
|
||||
list(matched_triggers),
|
||||
)
|
||||
|
||||
# Gather callback responses in parallel
|
||||
trigger_responses = await asyncio.gather(
|
||||
*(
|
||||
self._trigger_sentences[trigger_id].callback(sentence)
|
||||
for trigger_id in matched_triggers
|
||||
self._trigger_sentences[trigger_id].callback(sentence, result)
|
||||
for trigger_id, result in matched_triggers.items()
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -733,3 +751,15 @@ def _make_error_result(
|
|||
response.async_set_error(error_code, response_text)
|
||||
|
||||
return ConversationResult(response, conversation_id)
|
||||
|
||||
|
||||
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
|
||||
"""Collect list reference names recursively."""
|
||||
if isinstance(expression, Sequence):
|
||||
seq: Sequence = expression
|
||||
for item in seq.items:
|
||||
_collect_list_references(item, list_names)
|
||||
elif isinstance(expression, ListReference):
|
||||
# {list}
|
||||
list_ref: ListReference = expression
|
||||
list_names.add(list_ref.slot_name)
|
||||
|
|
|
@ -7,5 +7,5 @@
|
|||
"integration_type": "system",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal",
|
||||
"requirements": ["hassil==1.2.2", "home-assistant-intents==2023.7.25"]
|
||||
"requirements": ["hassil==1.2.5", "home-assistant-intents==2023.7.25"]
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from hassil.recognize import PUNCTUATION
|
||||
from hassil.recognize import PUNCTUATION, RecognizeResult
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
|
||||
|
@ -49,12 +49,29 @@ async def async_attach_trigger(
|
|||
job = HassJob(action)
|
||||
|
||||
@callback
|
||||
async def call_action(sentence: str) -> str | None:
|
||||
async def call_action(sentence: str, result: RecognizeResult) -> str | None:
|
||||
"""Call action with right context."""
|
||||
|
||||
# Add slot values as extra trigger data
|
||||
details = {
|
||||
entity_name: {
|
||||
"name": entity_name,
|
||||
"text": entity.text.strip(), # remove whitespace
|
||||
"value": entity.value.strip()
|
||||
if isinstance(entity.value, str)
|
||||
else entity.value,
|
||||
}
|
||||
for entity_name, entity in result.entities.items()
|
||||
}
|
||||
|
||||
trigger_input: dict[str, Any] = { # Satisfy type checker
|
||||
**trigger_data,
|
||||
"platform": DOMAIN,
|
||||
"sentence": sentence,
|
||||
"details": details,
|
||||
"slots": { # direct access to values
|
||||
entity_name: entity["value"] for entity_name, entity in details.items()
|
||||
},
|
||||
}
|
||||
|
||||
# Wait for the automation to complete
|
||||
|
|
|
@ -20,7 +20,7 @@ dbus-fast==1.87.2
|
|||
fnv-hash-fast==0.4.0
|
||||
ha-av==10.1.0
|
||||
hass-nabucasa==0.69.0
|
||||
hassil==1.2.2
|
||||
hassil==1.2.5
|
||||
home-assistant-bluetooth==1.10.2
|
||||
home-assistant-frontend==20230725.0
|
||||
home-assistant-intents==2023.7.25
|
||||
|
|
|
@ -958,7 +958,7 @@ hass-nabucasa==0.69.0
|
|||
hass-splunk==0.1.1
|
||||
|
||||
# homeassistant.components.conversation
|
||||
hassil==1.2.2
|
||||
hassil==1.2.5
|
||||
|
||||
# homeassistant.components.jewish_calendar
|
||||
hdate==0.10.4
|
||||
|
|
|
@ -753,7 +753,7 @@ habitipy==0.2.0
|
|||
hass-nabucasa==0.69.0
|
||||
|
||||
# homeassistant.components.conversation
|
||||
hassil==1.2.2
|
||||
hassil==1.2.5
|
||||
|
||||
# homeassistant.components.jewish_calendar
|
||||
hdate==0.10.4
|
||||
|
|
|
@ -372,7 +372,7 @@
|
|||
dict({
|
||||
'results': list([
|
||||
dict({
|
||||
'entities': dict({
|
||||
'details': dict({
|
||||
'name': dict({
|
||||
'name': 'name',
|
||||
'text': 'my cool light',
|
||||
|
@ -382,6 +382,9 @@
|
|||
'intent': dict({
|
||||
'name': 'HassTurnOn',
|
||||
}),
|
||||
'slots': dict({
|
||||
'name': 'my cool light',
|
||||
}),
|
||||
'targets': dict({
|
||||
'light.kitchen': dict({
|
||||
'matched': True,
|
||||
|
@ -389,7 +392,7 @@
|
|||
}),
|
||||
}),
|
||||
dict({
|
||||
'entities': dict({
|
||||
'details': dict({
|
||||
'name': dict({
|
||||
'name': 'name',
|
||||
'text': 'my cool light',
|
||||
|
@ -399,6 +402,9 @@
|
|||
'intent': dict({
|
||||
'name': 'HassTurnOff',
|
||||
}),
|
||||
'slots': dict({
|
||||
'name': 'my cool light',
|
||||
}),
|
||||
'targets': dict({
|
||||
'light.kitchen': dict({
|
||||
'matched': True,
|
||||
|
@ -406,7 +412,7 @@
|
|||
}),
|
||||
}),
|
||||
dict({
|
||||
'entities': dict({
|
||||
'details': dict({
|
||||
'area': dict({
|
||||
'name': 'area',
|
||||
'text': 'kitchen',
|
||||
|
@ -421,6 +427,10 @@
|
|||
'intent': dict({
|
||||
'name': 'HassTurnOn',
|
||||
}),
|
||||
'slots': dict({
|
||||
'area': 'kitchen',
|
||||
'domain': 'light',
|
||||
}),
|
||||
'targets': dict({
|
||||
'light.kitchen': dict({
|
||||
'matched': True,
|
||||
|
@ -428,7 +438,7 @@
|
|||
}),
|
||||
}),
|
||||
dict({
|
||||
'entities': dict({
|
||||
'details': dict({
|
||||
'area': dict({
|
||||
'name': 'area',
|
||||
'text': 'kitchen',
|
||||
|
@ -448,6 +458,11 @@
|
|||
'intent': dict({
|
||||
'name': 'HassGetState',
|
||||
}),
|
||||
'slots': dict({
|
||||
'area': 'kitchen',
|
||||
'domain': 'light',
|
||||
'state': 'on',
|
||||
}),
|
||||
'targets': dict({
|
||||
'light.kitchen': dict({
|
||||
'matched': False,
|
||||
|
|
|
@ -246,7 +246,8 @@ async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
|
|||
for sentence in test_sentences:
|
||||
callback.reset_mock()
|
||||
result = await conversation.async_converse(hass, sentence, None, Context())
|
||||
callback.assert_called_once_with(sentence)
|
||||
assert callback.call_count == 1
|
||||
assert callback.call_args[0][0] == sentence
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), sentence
|
||||
|
|
|
@ -61,6 +61,8 @@ async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None
|
|||
"idx": "0",
|
||||
"platform": "conversation",
|
||||
"sentence": "Ha ha ha",
|
||||
"slots": {},
|
||||
"details": {},
|
||||
}
|
||||
|
||||
|
||||
|
@ -103,6 +105,8 @@ async def test_same_trigger_multiple_sentences(
|
|||
"idx": "0",
|
||||
"platform": "conversation",
|
||||
"sentence": "hello",
|
||||
"slots": {},
|
||||
"details": {},
|
||||
}
|
||||
|
||||
|
||||
|
@ -188,3 +192,60 @@ async def test_fails_on_punctuation(hass: HomeAssistant, command: str) -> None:
|
|||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def test_wildcards(hass: HomeAssistant, calls, setup_comp) -> None:
|
||||
"""Test wildcards in trigger sentences."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "conversation",
|
||||
"command": [
|
||||
"play {album} by {artist}",
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data_template": {"data": "{{ trigger }}"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{
|
||||
"text": "play the white album by the beatles",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data["data"] == {
|
||||
"alias": None,
|
||||
"id": "0",
|
||||
"idx": "0",
|
||||
"platform": "conversation",
|
||||
"sentence": "play the white album by the beatles",
|
||||
"slots": {
|
||||
"album": "the white album",
|
||||
"artist": "the beatles",
|
||||
},
|
||||
"details": {
|
||||
"album": {
|
||||
"name": "album",
|
||||
"text": "the white album",
|
||||
"value": "the white album",
|
||||
},
|
||||
"artist": {
|
||||
"name": "artist",
|
||||
"text": "the beatles",
|
||||
"value": "the beatles",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue