1
mirror of https://github.com/home-assistant/core synced 2024-07-30 21:18:57 +02:00

ESPHome voice assistant (#90691)

* Add ESPHome push-to-talk

* Send pipeline events to device

* Bump aioesphomeapi to 13.7.0

* Log error instead of print

* Rename variable

* lint

* Rename

* Fix type and cast

* Move event data manipulation into voice_assistant callback
Process full url

* Add a test?

* Remove import

* More tests

* Update import

* Update manifest

* fix tests

* Ugh

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Jesse Hills 2023-04-14 11:18:56 +12:00 committed by GitHub
parent 1c0b2630da
commit 0ddccb26fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 352 additions and 4 deletions

View File

@ -22,6 +22,7 @@ from aioesphomeapi import (
RequiresEncryptionAPIError,
UserService,
UserServiceArgType,
VoiceAssistantEventType,
)
from awesomeversion import AwesomeVersion
import voluptuous as vol
@ -64,6 +65,7 @@ from .domain_data import DomainData
# Import config flow so that it's added to the registry
from .entry_data import RuntimeEntryData
from .enum_mapper import EsphomeEnumMapper
from .voice_assistant import VoiceAssistantUDPServer
CONF_DEVICE_NAME = "device_name"
CONF_NOISE_PSK = "noise_psk"
@ -284,6 +286,39 @@ async def async_setup_entry( # noqa: C901
_send_home_assistant_state(entity_id, attribute, hass.states.get(entity_id))
)
voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
def handle_pipeline_event(
event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
"""Handle a voice assistant pipeline event."""
cli.send_voice_assistant_event(event_type, data)
async def handle_pipeline_start() -> int | None:
"""Start a voice assistant pipeline."""
nonlocal voice_assistant_udp_server
if voice_assistant_udp_server is not None:
return None
voice_assistant_udp_server = VoiceAssistantUDPServer(hass)
port = await voice_assistant_udp_server.start_server()
hass.async_create_background_task(
voice_assistant_udp_server.run_pipeline(handle_pipeline_event),
"esphome.voice_assistant_udp_server.run_pipeline",
)
return port
async def handle_pipeline_stop() -> None:
"""Stop a voice assistant pipeline."""
nonlocal voice_assistant_udp_server
if voice_assistant_udp_server is not None:
voice_assistant_udp_server.stop()
voice_assistant_udp_server = None
async def on_connect() -> None:
"""Subscribe to states and list entities on successful API login."""
nonlocal device_id
@ -328,6 +363,14 @@ async def async_setup_entry( # noqa: C901
await cli.subscribe_service_calls(async_on_service_call)
await cli.subscribe_home_assistant_states(async_on_state_subscription)
if device_info.voice_assistant_version:
entry_data.disconnect_callbacks.append(
await cli.subscribe_voice_assistant(
handle_pipeline_start,
handle_pipeline_stop,
)
)
hass.async_create_task(entry_data.async_save_to_store())
except APIConnectionError as err:
_LOGGER.warning("Error getting initial data for %s: %s", host, err)

View File

@ -1,7 +1,7 @@
{
"domain": "esphome",
"name": "ESPHome",
"after_dependencies": ["zeroconf", "tag"],
"after_dependencies": ["zeroconf", "tag", "assist_pipeline"],
"codeowners": ["@OttoWinter", "@jesserockz"],
"config_flow": true,
"dependencies": ["bluetooth"],
@ -14,6 +14,6 @@
"integration_type": "device",
"iot_class": "local_push",
"loggers": ["aioesphomeapi", "noiseprotocol"],
"requirements": ["aioesphomeapi==13.6.1", "esphome-dashboard-api==1.2.3"],
"requirements": ["aioesphomeapi==13.7.0", "esphome-dashboard-api==1.2.3"],
"zeroconf": ["_esphomelib._tcp.local."]
}

View File

@ -0,0 +1,164 @@
"""ESPHome voice assistant support."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable, Callable
import logging
import socket
from typing import cast
from aioesphomeapi import VoiceAssistantEventType
from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
async_pipeline_from_audio_stream,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.core import HomeAssistant, callback
from .enum_mapper import EsphomeEnumMapper
_LOGGER = logging.getLogger(__name__)
UDP_PORT = 0 # Set to 0 to let the OS pick a free random port
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
{
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
}
)
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Receive UDP packets and forward them to the voice assistant."""
started = False
queue: asyncio.Queue[bytes] | None = None
transport: asyncio.DatagramTransport | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize UDP receiver."""
self.hass = hass
self.queue = asyncio.Queue()
async def start_server(self) -> int:
"""Start accepting connections."""
def accept_connection() -> VoiceAssistantUDPServer:
"""Accept connection."""
if self.started:
raise RuntimeError("Can only start once")
if self.queue is None:
raise RuntimeError("No longer accepting connections")
self.started = True
return self
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(("", UDP_PORT))
await asyncio.get_running_loop().create_datagram_endpoint(
accept_connection, sock=sock
)
return cast(int, sock.getsockname()[1])
@callback
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Store transport for later use."""
self.transport = cast(asyncio.DatagramTransport, transport)
@callback
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet."""
if self.queue is not None:
self.queue.put_nowait(data)
def error_received(self, exc: Exception) -> None:
"""Handle when a send or receive operation raises an OSError.
(Other than BlockingIOError or InterruptedError.)
"""
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
@callback
def stop(self) -> None:
"""Stop the receiver."""
if self.queue is not None:
self.queue.put_nowait(b"")
self.queue = None
if self.transport is not None:
self.transport.close()
async def _iterate_packets(self) -> AsyncIterable[bytes]:
"""Iterate over incoming packets."""
if self.queue is None:
raise RuntimeError("Already stopped")
while data := await self.queue.get():
yield data
async def run_pipeline(
self,
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
) -> None:
"""Run the Voice Assistant pipeline."""
@callback
def handle_pipeline_event(event: PipelineEvent) -> None:
"""Handle pipeline events."""
try:
event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
except KeyError:
_LOGGER.warning("Received unknown pipeline event type: %s", event.type)
return
data_to_send = None
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
assert event.data is not None
data_to_send = {"text": event.data["stt_output"]["text"]}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
assert event.data is not None
data_to_send = {"text": event.data["tts_input"]}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
assert event.data is not None
path = event.data["tts_output"]["url"]
url = async_process_play_media_url(self.hass, path)
data_to_send = {"url": url}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
assert event.data is not None
data_to_send = {
"code": event.data["code"],
"message": event.data["message"],
}
handle_event(event_type, data_to_send)
await async_pipeline_from_audio_stream(
self.hass,
event_callback=handle_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=self._iterate_packets(),
)

View File

@ -156,7 +156,7 @@ aioecowitt==2023.01.0
aioemonitor==1.0.5
# homeassistant.components.esphome
aioesphomeapi==13.6.1
aioesphomeapi==13.7.0
# homeassistant.components.flo
aioflo==2021.11.0

View File

@ -146,7 +146,7 @@ aioecowitt==2023.01.0
aioemonitor==1.0.5
# homeassistant.components.esphome
aioesphomeapi==13.6.1
aioesphomeapi==13.7.0
# homeassistant.components.flo
aioflo==2021.11.0

View File

@ -0,0 +1,141 @@
"""Test ESPHome voice assistant server."""
import asyncio
import socket
from unittest.mock import Mock, patch
import async_timeout
import pytest
from homeassistant.components import assist_pipeline, esphome
from homeassistant.core import HomeAssistant
_TEST_INPUT_TEXT = "This is an input test"
_TEST_OUTPUT_TEXT = "This is an output test"
_TEST_OUTPUT_URL = "output.mp3"
async def test_pipeline_events(hass: HomeAssistant) -> None:
"""Test that the pipeline function is called."""
async def async_pipeline_from_audio_stream(*args, **kwargs):
event_callback = kwargs["event_callback"]
# Fake events
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_START,
data={},
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": _TEST_INPUT_TEXT}},
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_START,
data={"tts_input": _TEST_OUTPUT_TEXT},
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"url": _TEST_OUTPUT_URL}},
)
)
def handle_event(
event_type: esphome.VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
if event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
assert data is not None
assert data["text"] == _TEST_INPUT_TEXT
elif event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
assert data is not None
assert data["text"] == _TEST_OUTPUT_TEXT
elif event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
assert data is not None
assert data["url"] == _TEST_OUTPUT_URL
with patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
server.transport = Mock()
await server.run_pipeline(handle_event)
async def test_udp_server(
hass: HomeAssistant,
socket_enabled,
unused_udp_port_factory,
) -> None:
"""Test the UDP server runs and queues incoming data."""
port_to_use = unused_udp_port_factory()
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
):
port = await server.start_server()
assert port == port_to_use
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
assert server.queue.qsize() == 0
sock.sendto(b"test", ("127.0.0.1", port))
# Give the socket some time to send/receive the data
async with async_timeout.timeout(1):
while server.queue.qsize() == 0:
await asyncio.sleep(0.1)
assert server.queue.qsize() == 1
server.stop()
assert server.transport.is_closing()
async def test_udp_server_multiple(
hass: HomeAssistant,
socket_enabled,
unused_udp_port_factory,
) -> None:
"""Test that the UDP server raises an error if started twice."""
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(),
):
await server.start_server()
with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(),
), pytest.raises(RuntimeError):
pass
await server.start_server()
async def test_udp_server_after_stopped(
hass: HomeAssistant,
socket_enabled,
unused_udp_port_factory,
) -> None:
"""Test that the UDP server raises an error if started after stopped."""
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
server.stop()
with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(),
), pytest.raises(RuntimeError):
await server.start_server()