mirror of
https://github.com/home-assistant/core
synced 2024-07-12 07:21:24 +02:00
Clean up Speech-to-text integration and add tests (#79012)
This commit is contained in:
parent
1b144c0e4d
commit
5774664234
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from dataclasses import asdict, dataclass
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@ -13,7 +14,6 @@ from aiohttp.web_exceptions import (
|
||||
HTTPNotFound,
|
||||
HTTPUnsupportedMediaType,
|
||||
)
|
||||
import attr
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
@ -34,9 +34,18 @@ from .const import (
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider:
|
||||
"""Return provider."""
|
||||
if domain is None:
|
||||
domain = next(iter(hass.data[DOMAIN]))
|
||||
|
||||
return hass.data[DOMAIN][domain]
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up STT."""
|
||||
providers = {}
|
||||
providers = hass.data[DOMAIN] = {}
|
||||
|
||||
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||
"""Set up a TTS platform."""
|
||||
@ -80,24 +89,30 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@attr.s
|
||||
@dataclass
|
||||
class SpeechMetadata:
|
||||
"""Metadata of audio stream."""
|
||||
|
||||
language: str = attr.ib()
|
||||
format: AudioFormats = attr.ib()
|
||||
codec: AudioCodecs = attr.ib()
|
||||
bit_rate: AudioBitRates = attr.ib(converter=int)
|
||||
sample_rate: AudioSampleRates = attr.ib(converter=int)
|
||||
channel: AudioChannels = attr.ib(converter=int)
|
||||
language: str
|
||||
format: AudioFormats
|
||||
codec: AudioCodecs
|
||||
bit_rate: AudioBitRates
|
||||
sample_rate: AudioSampleRates
|
||||
channel: AudioChannels
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Finish initializing the metadata."""
|
||||
self.bit_rate = AudioBitRates(int(self.bit_rate))
|
||||
self.sample_rate = AudioSampleRates(int(self.sample_rate))
|
||||
self.channel = AudioChannels(int(self.channel))
|
||||
|
||||
|
||||
@attr.s
|
||||
@dataclass
|
||||
class SpeechResult:
|
||||
"""Result of audio Speech."""
|
||||
|
||||
text: str | None = attr.ib()
|
||||
result: SpeechResultState = attr.ib()
|
||||
text: str | None
|
||||
result: SpeechResultState
|
||||
|
||||
|
||||
class Provider(ABC):
|
||||
@ -171,30 +186,6 @@ class SpeechToTextView(HomeAssistantView):
|
||||
"""Initialize a tts view."""
|
||||
self.providers = providers
|
||||
|
||||
@staticmethod
|
||||
def _metadata_from_header(request: web.Request) -> SpeechMetadata | None:
|
||||
"""Extract metadata from header.
|
||||
|
||||
X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de
|
||||
"""
|
||||
try:
|
||||
data = request.headers[istr("X-Speech-Content")].split(";")
|
||||
except KeyError:
|
||||
_LOGGER.warning("Missing X-Speech-Content")
|
||||
return None
|
||||
|
||||
# Convert Header data
|
||||
args: dict[str, Any] = {}
|
||||
for value in data:
|
||||
value = value.strip()
|
||||
args[value.partition("=")[0]] = value.partition("=")[2]
|
||||
|
||||
try:
|
||||
return SpeechMetadata(**args)
|
||||
except TypeError as err:
|
||||
_LOGGER.warning("Wrong format of X-Speech-Content: %s", err)
|
||||
return None
|
||||
|
||||
async def post(self, request: web.Request, provider: str) -> web.Response:
|
||||
"""Convert Speech (audio) to text."""
|
||||
if provider not in self.providers:
|
||||
@ -202,9 +193,10 @@ class SpeechToTextView(HomeAssistantView):
|
||||
stt_provider: Provider = self.providers[provider]
|
||||
|
||||
# Get metadata
|
||||
metadata = self._metadata_from_header(request)
|
||||
if not metadata:
|
||||
raise HTTPBadRequest()
|
||||
try:
|
||||
metadata = metadata_from_header(request)
|
||||
except ValueError as err:
|
||||
raise HTTPBadRequest(text=str(err)) from err
|
||||
|
||||
# Check format
|
||||
if not stt_provider.check_metadata(metadata):
|
||||
@ -216,7 +208,7 @@ class SpeechToTextView(HomeAssistantView):
|
||||
)
|
||||
|
||||
# Return result
|
||||
return self.json(attr.asdict(result))
|
||||
return self.json(asdict(result))
|
||||
|
||||
async def get(self, request: web.Request, provider: str) -> web.Response:
|
||||
"""Return provider specific audio information."""
|
||||
@ -234,3 +226,47 @@ class SpeechToTextView(HomeAssistantView):
|
||||
"channels": stt_provider.supported_channels,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||
"""Extract STT metadata from header.
|
||||
|
||||
X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de
|
||||
"""
|
||||
try:
|
||||
data = request.headers[istr("X-Speech-Content")].split(";")
|
||||
except KeyError as err:
|
||||
raise ValueError("Missing X-Speech-Content header") from err
|
||||
|
||||
fields = (
|
||||
"language",
|
||||
"format",
|
||||
"codec",
|
||||
"bit_rate",
|
||||
"sample_rate",
|
||||
"channel",
|
||||
)
|
||||
|
||||
# Convert Header data
|
||||
args: dict[str, Any] = {}
|
||||
for entry in data:
|
||||
key, _, value = entry.strip().partition("=")
|
||||
if key not in fields:
|
||||
raise ValueError(f"Invalid field {key}")
|
||||
args[key] = value
|
||||
|
||||
for field in fields:
|
||||
if field not in args:
|
||||
raise ValueError(f"Missing {field} in X-Speech-Content header")
|
||||
|
||||
try:
|
||||
return SpeechMetadata(
|
||||
language=args["language"],
|
||||
format=args["format"],
|
||||
codec=args["codec"],
|
||||
bit_rate=args["bit_rate"],
|
||||
sample_rate=args["sample_rate"],
|
||||
channel=args["channel"],
|
||||
)
|
||||
except TypeError as err:
|
||||
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
|
||||
|
@ -1,30 +1,155 @@
|
||||
"""Test STT component setup."""
|
||||
from asyncio import StreamReader
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from homeassistant.components import stt
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.stt import (
|
||||
AudioBitRates,
|
||||
AudioChannels,
|
||||
AudioCodecs,
|
||||
AudioFormats,
|
||||
AudioSampleRates,
|
||||
Provider,
|
||||
SpeechMetadata,
|
||||
SpeechResult,
|
||||
SpeechResultState,
|
||||
async_get_provider,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
||||
async def test_setup_comp(hass):
|
||||
"""Set up demo component."""
|
||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
|
||||
from tests.common import mock_platform
|
||||
|
||||
|
||||
async def test_demo_settings_not_exists(hass, hass_client):
|
||||
"""Test retrieve settings from demo provider."""
|
||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
|
||||
class TestProvider(Provider):
|
||||
"""Test provider."""
|
||||
|
||||
fail_process_audio = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Init test provider."""
|
||||
self.calls = []
|
||||
|
||||
@property
|
||||
def supported_languages(self):
|
||||
"""Return a list of supported languages."""
|
||||
return ["en"]
|
||||
|
||||
@property
|
||||
def supported_formats(self) -> list[AudioFormats]:
|
||||
"""Return a list of supported formats."""
|
||||
return [AudioFormats.WAV, AudioFormats.OGG]
|
||||
|
||||
@property
|
||||
def supported_codecs(self) -> list[AudioCodecs]:
|
||||
"""Return a list of supported codecs."""
|
||||
return [AudioCodecs.PCM, AudioCodecs.OPUS]
|
||||
|
||||
@property
|
||||
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||
"""Return a list of supported bitrates."""
|
||||
return [AudioBitRates.BITRATE_16]
|
||||
|
||||
@property
|
||||
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||
"""Return a list of supported samplerates."""
|
||||
return [AudioSampleRates.SAMPLERATE_16000]
|
||||
|
||||
@property
|
||||
def supported_channels(self) -> list[AudioChannels]:
|
||||
"""Return a list of supported channels."""
|
||||
return [AudioChannels.CHANNEL_MONO]
|
||||
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: StreamReader
|
||||
) -> SpeechResult:
|
||||
"""Process an audio stream."""
|
||||
self.calls.append((metadata, stream))
|
||||
if self.fail_process_audio:
|
||||
return SpeechResult(None, SpeechResultState.ERROR)
|
||||
|
||||
return SpeechResult("test", SpeechResultState.SUCCESS)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_provider():
|
||||
"""Test provider fixture."""
|
||||
return TestProvider()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def mock_setup(hass, test_provider):
|
||||
"""Set up a test provider."""
|
||||
mock_platform(
|
||||
hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=test_provider))
|
||||
)
|
||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
|
||||
|
||||
|
||||
async def test_get_provider_info(hass, hass_client):
|
||||
"""Test engine that doesn't exist."""
|
||||
client = await hass_client()
|
||||
response = await client.get("/api/stt/test")
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert await response.json() == {
|
||||
"languages": ["en"],
|
||||
"formats": ["wav", "ogg"],
|
||||
"codecs": ["pcm", "opus"],
|
||||
"sample_rates": [16000],
|
||||
"bit_rates": [16],
|
||||
"channels": [1],
|
||||
}
|
||||
|
||||
response = await client.get("/api/stt/beer")
|
||||
|
||||
async def test_get_non_existing_provider_info(hass, hass_client):
|
||||
"""Test streaming to engine that doesn't exist."""
|
||||
client = await hass_client()
|
||||
response = await client.get("/api/stt/not_exist")
|
||||
assert response.status == HTTPStatus.NOT_FOUND
|
||||
|
||||
|
||||
async def test_demo_speech_not_exists(hass, hass_client):
|
||||
"""Test retrieve settings from demo provider."""
|
||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
|
||||
async def test_stream_audio(hass, hass_client):
|
||||
"""Test streaming audio and getting response."""
|
||||
client = await hass_client()
|
||||
response = await client.post(
|
||||
"/api/stt/test",
|
||||
headers={
|
||||
"X-Speech-Content": "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=en"
|
||||
},
|
||||
)
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert await response.json() == {"text": "test", "result": "success"}
|
||||
|
||||
response = await client.post("/api/stt/beer", data=b"test")
|
||||
|
||||
assert response.status == HTTPStatus.NOT_FOUND
|
||||
@pytest.mark.parametrize(
|
||||
"header,status,error",
|
||||
(
|
||||
(None, 400, "Missing X-Speech-Content header"),
|
||||
(
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100; language=en",
|
||||
400,
|
||||
"100 is not a valid AudioChannels",
|
||||
),
|
||||
(
|
||||
"format=wav; codec=pcm; sample_rate=16000",
|
||||
400,
|
||||
"Missing language in X-Speech-Content header",
|
||||
),
|
||||
),
|
||||
)
|
||||
async def test_metadata_errors(hass, hass_client, header, status, error):
|
||||
"""Test metadata errors."""
|
||||
client = await hass_client()
|
||||
headers = {}
|
||||
if header:
|
||||
headers["X-Speech-Content"] = header
|
||||
|
||||
response = await client.post("/api/stt/test", headers=headers)
|
||||
assert response.status == status
|
||||
assert await response.text() == error
|
||||
|
||||
|
||||
async def test_get_provider(hass, test_provider):
|
||||
"""Test we can get STT providers."""
|
||||
assert test_provider == async_get_provider(hass, "test")
|
||||
|
Loading…
Reference in New Issue
Block a user