mirror of https://github.com/home-assistant/core
Voice assistant integration with pipelines (#89822)
* Initial commit * Add websocket test tool * Small tweak * Tiny cleanup * Make pipeline work with frontend branch * Add some more info to start event * Fixes * First voice assistant tests * Remove run_task * Clean up for PR * Add config_flow.py * Remove CLI tool * Simplify by removing stt/tts for now * Clean up and fix tests * More clean up and API changes * Add quality_scale * Remove data from run-finish * Use StrEnum backport --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
81c0382e4b
commit
e16f17f5a8
|
@ -1309,6 +1309,8 @@ build.json @home-assistant/supervisor
|
|||
/tests/components/vizio/ @raman325
|
||||
/homeassistant/components/vlc_telnet/ @rodripf @MartinHjelmare
|
||||
/tests/components/vlc_telnet/ @rodripf @MartinHjelmare
|
||||
/homeassistant/components/voice_assistant/ @balloob @synesthesiam
|
||||
/tests/components/voice_assistant/ @balloob @synesthesiam
|
||||
/homeassistant/components/volumio/ @OnFreund
|
||||
/tests/components/volumio/ @OnFreund
|
||||
/homeassistant/components/volvooncall/ @molobrakos
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
"""The Voice Assistant integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DEFAULT_PIPELINE, DOMAIN
|
||||
from .pipeline import Pipeline
|
||||
from .websocket_api import async_register_websocket_api
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up Voice Assistant integration."""
|
||||
hass.data[DOMAIN] = {
|
||||
DEFAULT_PIPELINE: Pipeline(
|
||||
name=DEFAULT_PIPELINE,
|
||||
language=None,
|
||||
conversation_engine=None,
|
||||
)
|
||||
}
|
||||
async_register_websocket_api(hass)
|
||||
|
||||
return True
|
|
@ -0,0 +1,3 @@
|
|||
"""Constants for the Voice Assistant integration."""
|
||||
DOMAIN = "voice_assistant"
|
||||
DEFAULT_PIPELINE = "default"
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"domain": "voice_assistant",
|
||||
"name": "Voice Assistant",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"dependencies": ["conversation"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/voice_assistant",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal"
|
||||
}
|
|
@ -0,0 +1,124 @@
|
|||
"""Classes for voice assistant pipelines."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.backports.enum import StrEnum
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
DEFAULT_TIMEOUT = 30 # seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRequest:
|
||||
"""Request to start a pipeline run."""
|
||||
|
||||
intent_input: str
|
||||
conversation_id: str | None = None
|
||||
|
||||
|
||||
class PipelineEventType(StrEnum):
|
||||
"""Event types emitted during a pipeline run."""
|
||||
|
||||
RUN_START = "run-start"
|
||||
RUN_FINISH = "run-finish"
|
||||
INTENT_START = "intent-start"
|
||||
INTENT_FINISH = "intent-finish"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineEvent:
|
||||
"""Events emitted during a pipeline run."""
|
||||
|
||||
type: PipelineEventType
|
||||
data: dict[str, Any] | None = None
|
||||
timestamp: str = field(default_factory=lambda: utcnow().isoformat())
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dict representation of the event."""
|
||||
return {
|
||||
"type": self.type,
|
||||
"timestamp": self.timestamp,
|
||||
"data": self.data or {},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Pipeline:
|
||||
"""A voice assistant pipeline."""
|
||||
|
||||
name: str
|
||||
language: str | None
|
||||
conversation_engine: str | None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
context: Context,
|
||||
request: PipelineRequest,
|
||||
event_callback: Callable[[PipelineEvent], None],
|
||||
timeout: int | float | None = DEFAULT_TIMEOUT,
|
||||
) -> None:
|
||||
"""Run a pipeline with an optional timeout."""
|
||||
await asyncio.wait_for(
|
||||
self._run(hass, context, request, event_callback), timeout=timeout
|
||||
)
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
context: Context,
|
||||
request: PipelineRequest,
|
||||
event_callback: Callable[[PipelineEvent], None],
|
||||
) -> None:
|
||||
"""Run a pipeline."""
|
||||
language = self.language or hass.config.language
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.RUN_START,
|
||||
{
|
||||
"pipeline": self.name,
|
||||
"language": language,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
intent_input = request.intent_input
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.INTENT_START,
|
||||
{
|
||||
"engine": self.conversation_engine or "default",
|
||||
"intent_input": intent_input,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
conversation_result = await conversation.async_converse(
|
||||
hass=hass,
|
||||
text=intent_input,
|
||||
conversation_id=request.conversation_id,
|
||||
context=context,
|
||||
language=language,
|
||||
agent_id=self.conversation_engine,
|
||||
)
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.INTENT_FINISH,
|
||||
{"intent_output": conversation_result.as_dict()},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.RUN_FINISH,
|
||||
)
|
||||
)
|
|
@ -0,0 +1,67 @@
|
|||
"""Voice Assistant Websocket API."""
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .pipeline import DEFAULT_TIMEOUT, PipelineRequest
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
"""Register the websocket API."""
|
||||
websocket_api.async_register_command(hass, websocket_run)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "voice_assistant/run",
|
||||
vol.Optional("pipeline", default="default"): str,
|
||||
vol.Required("intent_input"): str,
|
||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||
vol.Optional("timeout"): vol.Any(float, int),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_run(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Run a pipeline."""
|
||||
pipeline_id = msg["pipeline"]
|
||||
pipeline = hass.data[DOMAIN].get(pipeline_id)
|
||||
if pipeline is None:
|
||||
connection.send_error(
|
||||
msg["id"], "pipeline_not_found", f"Pipeline not found: {pipeline_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Run pipeline with a timeout.
|
||||
# Events are sent over the websocket connection.
|
||||
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
|
||||
run_task = hass.async_create_task(
|
||||
pipeline.run(
|
||||
hass,
|
||||
connection.context(msg),
|
||||
request=PipelineRequest(
|
||||
intent_input=msg["intent_input"],
|
||||
conversation_id=msg.get("conversation_id"),
|
||||
),
|
||||
event_callback=lambda event: connection.send_event(
|
||||
msg["id"], event.as_dict()
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
)
|
||||
|
||||
# Cancel pipeline if user unsubscribes
|
||||
connection.subscriptions[msg["id"]] = run_task.cancel
|
||||
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
# Task contains a timeout
|
||||
await run_task
|
|
@ -65,6 +65,11 @@ class ActiveConnection:
|
|||
"""Send a result message."""
|
||||
self.send_message(messages.result_message(msg_id, result))
|
||||
|
||||
@callback
|
||||
def send_event(self, msg_id: int, event: Any | None = None) -> None:
|
||||
"""Send a event message."""
|
||||
self.send_message(messages.event_message(msg_id, event))
|
||||
|
||||
@callback
|
||||
def send_error(self, msg_id: int, code: str, message: str) -> None:
|
||||
"""Send a error message."""
|
||||
|
|
|
@ -6068,6 +6068,12 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"voice_assistant": {
|
||||
"name": "Voice Assistant",
|
||||
"integration_type": "hub",
|
||||
"config_flow": false,
|
||||
"iot_class": "local_push"
|
||||
},
|
||||
"voicerss": {
|
||||
"name": "VoiceRSS",
|
||||
"integration_type": "hub",
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
"""Tests for the Voice Assistant integration."""
|
|
@ -0,0 +1,152 @@
|
|||
"""Websocket tests for Voice Assistant integration."""
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def init_components(hass):
|
||||
"""Initialize relevant components with empty configs."""
|
||||
assert await async_setup_component(hass, "voice_assistant", {})
|
||||
|
||||
|
||||
async def test_text_only_pipeline(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json(
|
||||
{"id": 5, "type": "voice_assistant/run", "intent_input": "Are the lights on?"}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"pipeline": "default",
|
||||
"language": hass.config.language,
|
||||
}
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"intent_input": "Are the lights on?",
|
||||
}
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-finish"
|
||||
assert msg["event"]["data"] == {
|
||||
"intent_output": {
|
||||
"response": {
|
||||
"speech": {
|
||||
"plain": {
|
||||
"speech": "Sorry, I couldn't understand that",
|
||||
"extra_data": None,
|
||||
}
|
||||
},
|
||||
"card": {},
|
||||
"language": "en",
|
||||
"response_type": "error",
|
||||
"data": {"code": "no_intent_match"},
|
||||
},
|
||||
"conversation_id": None,
|
||||
}
|
||||
}
|
||||
|
||||
# run finish
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-finish"
|
||||
assert msg["event"]["data"] == {}
|
||||
|
||||
|
||||
async def test_conversation_timeout(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
"""Test partial pipeline run with conversation agent timeout."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
async def sleepy_converse(*args, **kwargs):
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.conversation.async_converse", new=sleepy_converse
|
||||
):
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "voice_assistant/run",
|
||||
"intent_input": "Are the lights on?",
|
||||
"timeout": 0.00001,
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"pipeline": "default",
|
||||
"language": hass.config.language,
|
||||
}
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"intent_input": "Are the lights on?",
|
||||
}
|
||||
|
||||
# timeout error
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "timeout"
|
||||
|
||||
|
||||
async def test_pipeline_timeout(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
"""Test pipeline run with immediate timeout."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
async def sleepy_run(*args, **kwargs):
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voice_assistant.pipeline.Pipeline._run",
|
||||
new=sleepy_run,
|
||||
):
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "voice_assistant/run",
|
||||
"intent_input": "Are the lights on?",
|
||||
"timeout": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# timeout error
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "timeout"
|
Loading…
Reference in New Issue