Add wildcards to sentence triggers (#97236)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Michael Hansen 2023-07-27 13:30:42 -05:00 committed by GitHub
parent af286a8feb
commit 7e3fdd85fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 147 additions and 19 deletions

View File

@ -322,7 +322,11 @@ async def websocket_hass_agent_debug(
"intent": {
"name": result.intent.name,
},
"entities": {
"slots": { # direct access to values
entity_key: entity.value
for entity_key, entity in result.entities.items()
},
"details": {
entity_key: {
"name": entity.name,
"value": entity.value,

View File

@ -11,7 +11,14 @@ from pathlib import Path
import re
from typing import IO, Any
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
from hassil.expression import Expression, ListReference, Sequence
from hassil.intents import (
Intents,
ResponseType,
SlotList,
TextSlotList,
WildcardSlotList,
)
from hassil.recognize import RecognizeResult, recognize_all
from hassil.util import merge_dict
from home_assistant_intents import get_domains_and_languages, get_intents
@ -48,7 +55,7 @@ _ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
REGEX_TYPE = type(re.compile(""))
TRIGGER_CALLBACK_TYPE = Callable[ # pylint: disable=invalid-name
[str], Awaitable[str | None]
[str, RecognizeResult], Awaitable[str | None]
]
@ -657,6 +664,17 @@ class DefaultAgent(AbstractConversationAgent):
}
self._trigger_intents = Intents.from_dict(intents_dict)
# Assume slot list references are wildcards
wildcard_names: set[str] = set()
for trigger_intent in self._trigger_intents.intents.values():
for intent_data in trigger_intent.data:
for sentence in intent_data.sentences:
_collect_list_references(sentence, wildcard_names)
for wildcard_name in wildcard_names:
self._trigger_intents.slot_lists[wildcard_name] = WildcardSlotList()
_LOGGER.debug("Rebuilt trigger intents: %s", intents_dict)
def _unregister_trigger(self, trigger_data: TriggerData) -> None:
@ -682,14 +700,14 @@ class DefaultAgent(AbstractConversationAgent):
assert self._trigger_intents is not None
matched_triggers: set[int] = set()
matched_triggers: dict[int, RecognizeResult] = {}
for result in recognize_all(sentence, self._trigger_intents):
trigger_id = int(result.intent.name)
if trigger_id in matched_triggers:
# Already matched a sentence from this trigger
break
matched_triggers.add(trigger_id)
matched_triggers[trigger_id] = result
if not matched_triggers:
# Sentence did not match any trigger sentences
@ -699,14 +717,14 @@ class DefaultAgent(AbstractConversationAgent):
"'%s' matched %s trigger(s): %s",
sentence,
len(matched_triggers),
matched_triggers,
list(matched_triggers),
)
# Gather callback responses in parallel
trigger_responses = await asyncio.gather(
*(
self._trigger_sentences[trigger_id].callback(sentence)
for trigger_id in matched_triggers
self._trigger_sentences[trigger_id].callback(sentence, result)
for trigger_id, result in matched_triggers.items()
)
)
@ -733,3 +751,15 @@ def _make_error_result(
response.async_set_error(error_code, response_text)
return ConversationResult(response, conversation_id)
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
"""Collect list reference names recursively."""
if isinstance(expression, Sequence):
seq: Sequence = expression
for item in seq.items:
_collect_list_references(item, list_names)
elif isinstance(expression, ListReference):
# {list}
list_ref: ListReference = expression
list_names.add(list_ref.slot_name)

View File

@ -7,5 +7,5 @@
"integration_type": "system",
"iot_class": "local_push",
"quality_scale": "internal",
"requirements": ["hassil==1.2.2", "home-assistant-intents==2023.7.25"]
"requirements": ["hassil==1.2.5", "home-assistant-intents==2023.7.25"]
}

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any
from hassil.recognize import PUNCTUATION
from hassil.recognize import PUNCTUATION, RecognizeResult
import voluptuous as vol
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
@ -49,12 +49,29 @@ async def async_attach_trigger(
job = HassJob(action)
@callback
async def call_action(sentence: str) -> str | None:
async def call_action(sentence: str, result: RecognizeResult) -> str | None:
"""Call action with right context."""
# Add slot values as extra trigger data
details = {
entity_name: {
"name": entity_name,
"text": entity.text.strip(), # remove whitespace
"value": entity.value.strip()
if isinstance(entity.value, str)
else entity.value,
}
for entity_name, entity in result.entities.items()
}
trigger_input: dict[str, Any] = { # Satisfy type checker
**trigger_data,
"platform": DOMAIN,
"sentence": sentence,
"details": details,
"slots": { # direct access to values
entity_name: entity["value"] for entity_name, entity in details.items()
},
}
# Wait for the automation to complete

View File

@ -20,7 +20,7 @@ dbus-fast==1.87.2
fnv-hash-fast==0.4.0
ha-av==10.1.0
hass-nabucasa==0.69.0
hassil==1.2.2
hassil==1.2.5
home-assistant-bluetooth==1.10.2
home-assistant-frontend==20230725.0
home-assistant-intents==2023.7.25

View File

@ -958,7 +958,7 @@ hass-nabucasa==0.69.0
hass-splunk==0.1.1
# homeassistant.components.conversation
hassil==1.2.2
hassil==1.2.5
# homeassistant.components.jewish_calendar
hdate==0.10.4

View File

@ -753,7 +753,7 @@ habitipy==0.2.0
hass-nabucasa==0.69.0
# homeassistant.components.conversation
hassil==1.2.2
hassil==1.2.5
# homeassistant.components.jewish_calendar
hdate==0.10.4

View File

@ -372,7 +372,7 @@
dict({
'results': list([
dict({
'entities': dict({
'details': dict({
'name': dict({
'name': 'name',
'text': 'my cool light',
@ -382,6 +382,9 @@
'intent': dict({
'name': 'HassTurnOn',
}),
'slots': dict({
'name': 'my cool light',
}),
'targets': dict({
'light.kitchen': dict({
'matched': True,
@ -389,7 +392,7 @@
}),
}),
dict({
'entities': dict({
'details': dict({
'name': dict({
'name': 'name',
'text': 'my cool light',
@ -399,6 +402,9 @@
'intent': dict({
'name': 'HassTurnOff',
}),
'slots': dict({
'name': 'my cool light',
}),
'targets': dict({
'light.kitchen': dict({
'matched': True,
@ -406,7 +412,7 @@
}),
}),
dict({
'entities': dict({
'details': dict({
'area': dict({
'name': 'area',
'text': 'kitchen',
@ -421,6 +427,10 @@
'intent': dict({
'name': 'HassTurnOn',
}),
'slots': dict({
'area': 'kitchen',
'domain': 'light',
}),
'targets': dict({
'light.kitchen': dict({
'matched': True,
@ -428,7 +438,7 @@
}),
}),
dict({
'entities': dict({
'details': dict({
'area': dict({
'name': 'area',
'text': 'kitchen',
@ -448,6 +458,11 @@
'intent': dict({
'name': 'HassGetState',
}),
'slots': dict({
'area': 'kitchen',
'domain': 'light',
'state': 'on',
}),
'targets': dict({
'light.kitchen': dict({
'matched': False,

View File

@ -246,7 +246,8 @@ async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
for sentence in test_sentences:
callback.reset_mock()
result = await conversation.async_converse(hass, sentence, None, Context())
callback.assert_called_once_with(sentence)
assert callback.call_count == 1
assert callback.call_args[0][0] == sentence
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), sentence

View File

@ -61,6 +61,8 @@ async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None
"idx": "0",
"platform": "conversation",
"sentence": "Ha ha ha",
"slots": {},
"details": {},
}
@ -103,6 +105,8 @@ async def test_same_trigger_multiple_sentences(
"idx": "0",
"platform": "conversation",
"sentence": "hello",
"slots": {},
"details": {},
}
@ -188,3 +192,60 @@ async def test_fails_on_punctuation(hass: HomeAssistant, command: str) -> None:
},
],
)
async def test_wildcards(hass: HomeAssistant, calls, setup_comp) -> None:
"""Test wildcards in trigger sentences."""
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": [
"play {album} by {artist}",
],
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
},
}
},
)
await hass.services.async_call(
"conversation",
"process",
{
"text": "play the white album by the beatles",
},
blocking=True,
)
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"platform": "conversation",
"sentence": "play the white album by the beatles",
"slots": {
"album": "the white album",
"artist": "the beatles",
},
"details": {
"album": {
"name": "album",
"text": "the white album",
"value": "the white album",
},
"artist": {
"name": "artist",
"text": "the beatles",
"value": "the beatles",
},
},
}