Clean up Speech-to-text integration and add tests (#79012)

This commit is contained in:
Paulus Schoutsen 2022-09-24 03:58:01 -04:00 committed by GitHub
parent 1b144c0e4d
commit 5774664234
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 215 additions and 54 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from dataclasses import asdict, dataclass
import logging
from typing import Any
@ -13,7 +14,6 @@ from aiohttp.web_exceptions import (
HTTPNotFound,
HTTPUnsupportedMediaType,
)
import attr
from homeassistant.components.http import HomeAssistantView
from homeassistant.core import HomeAssistant, callback
@ -34,9 +34,18 @@ from .const import (
_LOGGER = logging.getLogger(__name__)
@callback
def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider:
    """Look up a registered speech-to-text provider.

    When ``domain`` is None, the first key found while iterating
    ``hass.data[DOMAIN]`` is used.

    Raises KeyError if ``domain`` is not registered, and StopIteration if no
    domain was given and no provider is registered at all.
    """
    providers = hass.data[DOMAIN]
    selected = next(iter(providers)) if domain is None else domain
    return providers[selected]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT."""
providers = {}
providers = hass.data[DOMAIN] = {}
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
"""Set up a TTS platform."""
@ -80,24 +89,30 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
@attr.s
@dataclass
class SpeechMetadata:
"""Metadata of audio stream."""
language: str = attr.ib()
format: AudioFormats = attr.ib()
codec: AudioCodecs = attr.ib()
bit_rate: AudioBitRates = attr.ib(converter=int)
sample_rate: AudioSampleRates = attr.ib(converter=int)
channel: AudioChannels = attr.ib(converter=int)
language: str
format: AudioFormats
codec: AudioCodecs
bit_rate: AudioBitRates
sample_rate: AudioSampleRates
channel: AudioChannels
def __post_init__(self) -> None:
"""Finish initializing the metadata."""
self.bit_rate = AudioBitRates(int(self.bit_rate))
self.sample_rate = AudioSampleRates(int(self.sample_rate))
self.channel = AudioChannels(int(self.channel))
@attr.s
@dataclass
class SpeechResult:
"""Result of audio Speech."""
text: str | None = attr.ib()
result: SpeechResultState = attr.ib()
text: str | None
result: SpeechResultState
class Provider(ABC):
@ -171,30 +186,6 @@ class SpeechToTextView(HomeAssistantView):
"""Initialize a tts view."""
self.providers = providers
@staticmethod
def _metadata_from_header(request: web.Request) -> SpeechMetadata | None:
"""Extract metadata from header.
X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de
"""
try:
data = request.headers[istr("X-Speech-Content")].split(";")
except KeyError:
_LOGGER.warning("Missing X-Speech-Content")
return None
# Convert Header data
args: dict[str, Any] = {}
for value in data:
value = value.strip()
args[value.partition("=")[0]] = value.partition("=")[2]
try:
return SpeechMetadata(**args)
except TypeError as err:
_LOGGER.warning("Wrong format of X-Speech-Content: %s", err)
return None
async def post(self, request: web.Request, provider: str) -> web.Response:
"""Convert Speech (audio) to text."""
if provider not in self.providers:
@ -202,9 +193,10 @@ class SpeechToTextView(HomeAssistantView):
stt_provider: Provider = self.providers[provider]
# Get metadata
metadata = self._metadata_from_header(request)
if not metadata:
raise HTTPBadRequest()
try:
metadata = metadata_from_header(request)
except ValueError as err:
raise HTTPBadRequest(text=str(err)) from err
# Check format
if not stt_provider.check_metadata(metadata):
@ -216,7 +208,7 @@ class SpeechToTextView(HomeAssistantView):
)
# Return result
return self.json(attr.asdict(result))
return self.json(asdict(result))
async def get(self, request: web.Request, provider: str) -> web.Response:
"""Return provider specific audio information."""
@ -234,3 +226,47 @@ class SpeechToTextView(HomeAssistantView):
"channels": stt_provider.supported_channels,
}
)
def metadata_from_header(request: web.Request) -> SpeechMetadata:
    """Extract STT metadata from the X-Speech-Content request header.

    The header carries ``key=value`` pairs separated by ``;``:

    X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de

    Empty segments (for example the one produced by a trailing ``;``) are
    ignored. If a field is given more than once, the last value wins.

    Raises ValueError when the header is missing, names an unknown field or
    lacks one of the required fields. A ValueError raised while building
    SpeechMetadata from the values is propagated unchanged.
    """
    try:
        data = request.headers[istr("X-Speech-Content")].split(";")
    except KeyError as err:
        raise ValueError("Missing X-Speech-Content header") from err

    fields = (
        "language",
        "format",
        "codec",
        "bit_rate",
        "sample_rate",
        "channel",
    )

    # Convert Header data
    args: dict[str, Any] = {}
    for entry in data:
        segment = entry.strip()
        if not segment:
            # A trailing or doubled ";" leaves an empty segment behind; skip it
            # rather than rejecting the request with an empty field name.
            continue
        key, _, value = segment.partition("=")
        if key not in fields:
            raise ValueError(f"Invalid field {key}")
        args[key] = value

    # Report the first absent field, in the order listed in ``fields``.
    for field in fields:
        if field not in args:
            raise ValueError(f"Missing {field} in X-Speech-Content header")

    try:
        return SpeechMetadata(
            language=args["language"],
            format=args["format"],
            codec=args["codec"],
            bit_rate=args["bit_rate"],
            sample_rate=args["sample_rate"],
            channel=args["channel"],
        )
    except TypeError as err:
        # Callers only handle ValueError, so normalise TypeError to it.
        raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err

View File

@ -1,30 +1,155 @@
"""Test STT component setup."""
from asyncio import StreamReader
from http import HTTPStatus
from unittest.mock import AsyncMock, Mock
from homeassistant.components import stt
import pytest
from homeassistant.components.stt import (
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechMetadata,
SpeechResult,
SpeechResultState,
async_get_provider,
)
from homeassistant.setup import async_setup_component
async def test_setup_comp(hass):
"""Set up demo component."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
from tests.common import mock_platform
async def test_demo_settings_not_exists(hass, hass_client):
"""Test retrieve settings from demo provider."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
class TestProvider(Provider):
    """Provider stub that records calls and returns canned results."""

    # The class name matches pytest's default ``Test*`` collection pattern and
    # the class defines ``__init__``; opt out of collection explicitly so
    # pytest does not warn about this helper.
    __test__ = False

    # When True, async_process_audio_stream returns an ERROR result.
    fail_process_audio = False

    def __init__(self) -> None:
        """Init test provider."""
        # One (metadata, stream) tuple per async_process_audio_stream call.
        self.calls = []

    @property
    def supported_languages(self):
        """Return a list of supported languages."""
        return ["en"]

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Return a list of supported formats."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return a list of supported codecs."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return a list of supported bitrates."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return a list of supported samplerates."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Return a list of supported channels."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: StreamReader
    ) -> SpeechResult:
        """Record the call and return a canned result.

        Returns an ERROR result with no text when ``fail_process_audio`` is
        set, otherwise a SUCCESS result with the text "test".
        """
        self.calls.append((metadata, stream))
        if self.fail_process_audio:
            return SpeechResult(None, SpeechResultState.ERROR)

        return SpeechResult("test", SpeechResultState.SUCCESS)
@pytest.fixture
def test_provider():
    """Return a newly created TestProvider instance."""
    provider = TestProvider()
    return provider
@pytest.fixture(autouse=True)
async def mock_setup(hass, test_provider):
    """Register the test provider as the "test" platform and set up stt."""
    # The mocked platform module hands back the shared test provider.
    platform = Mock(async_get_engine=AsyncMock(return_value=test_provider))
    mock_platform(hass, "test.stt", platform)

    config = {"stt": {"platform": "test"}}
    assert await async_setup_component(hass, "stt", config)
async def test_get_provider_info(hass, hass_client):
    """Test fetching the supported audio settings of a registered provider."""
    expected_info = {
        "languages": ["en"],
        "formats": ["wav", "ogg"],
        "codecs": ["pcm", "opus"],
        "sample_rates": [16000],
        "bit_rates": [16],
        "channels": [1],
    }

    client = await hass_client()
    response = await client.get("/api/stt/test")

    assert response.status == HTTPStatus.OK
    assert await response.json() == expected_info
response = await client.get("/api/stt/beer")
async def test_get_non_existing_provider_info(hass, hass_client):
    """Test that requesting info for an unknown provider returns 404."""
    unknown_provider_url = "/api/stt/not_exist"
    client = await hass_client()
    response = await client.get(unknown_provider_url)
    assert response.status == HTTPStatus.NOT_FOUND
async def test_demo_speech_not_exists(hass, hass_client):
"""Test retrieve settings from demo provider."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
async def test_stream_audio(hass, hass_client):
    """Test posting audio with valid metadata and reading the result."""
    speech_content = (
        "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=en"
    )
    request_headers = {"X-Speech-Content": speech_content}

    client = await hass_client()
    response = await client.post("/api/stt/test", headers=request_headers)

    assert response.status == HTTPStatus.OK
    assert await response.json() == {"text": "test", "result": "success"}
response = await client.post("/api/stt/beer", data=b"test")
assert response.status == HTTPStatus.NOT_FOUND
@pytest.mark.parametrize(
    "header,status,error",
    (
        # No header at all.
        (None, 400, "Missing X-Speech-Content header"),
        # All fields present, but the channel value is not a known one.
        (
            "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100; language=en",
            400,
            "100 is not a valid AudioChannels",
        ),
        # Required fields left out.
        (
            "format=wav; codec=pcm; sample_rate=16000",
            400,
            "Missing language in X-Speech-Content header",
        ),
    ),
)
async def test_metadata_errors(hass, hass_client, header, status, error):
    """Test that bad X-Speech-Content metadata yields the expected error."""
    request_headers = {"X-Speech-Content": header} if header else {}

    client = await hass_client()
    response = await client.post("/api/stt/test", headers=request_headers)

    assert response.status == status
    assert await response.text() == error
async def test_get_provider(hass, test_provider):
    """Test that async_get_provider returns the provider set up as "test"."""
    looked_up = async_get_provider(hass, "test")
    assert test_provider == looked_up