1
mirror of https://github.com/home-assistant/core synced 2024-08-15 18:25:44 +02:00

Clean up Speech-to-text integration and add tests (#79012)

This commit is contained in:
Paulus Schoutsen 2022-09-24 03:58:01 -04:00 committed by GitHub
parent 1b144c0e4d
commit 5774664234
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 215 additions and 54 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from dataclasses import asdict, dataclass
import logging import logging
from typing import Any from typing import Any
@ -13,7 +14,6 @@ from aiohttp.web_exceptions import (
HTTPNotFound, HTTPNotFound,
HTTPUnsupportedMediaType, HTTPUnsupportedMediaType,
) )
import attr
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -34,9 +34,18 @@ from .const import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@callback
def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider:
"""Return provider."""
if domain is None:
domain = next(iter(hass.data[DOMAIN]))
return hass.data[DOMAIN][domain]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT.""" """Set up STT."""
providers = {} providers = hass.data[DOMAIN] = {}
async def async_setup_platform(p_type, p_config=None, discovery_info=None): async def async_setup_platform(p_type, p_config=None, discovery_info=None):
"""Set up a TTS platform.""" """Set up a TTS platform."""
@ -80,24 +89,30 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
@attr.s @dataclass
class SpeechMetadata: class SpeechMetadata:
"""Metadata of audio stream.""" """Metadata of audio stream."""
language: str = attr.ib() language: str
format: AudioFormats = attr.ib() format: AudioFormats
codec: AudioCodecs = attr.ib() codec: AudioCodecs
bit_rate: AudioBitRates = attr.ib(converter=int) bit_rate: AudioBitRates
sample_rate: AudioSampleRates = attr.ib(converter=int) sample_rate: AudioSampleRates
channel: AudioChannels = attr.ib(converter=int) channel: AudioChannels
def __post_init__(self) -> None:
"""Finish initializing the metadata."""
self.bit_rate = AudioBitRates(int(self.bit_rate))
self.sample_rate = AudioSampleRates(int(self.sample_rate))
self.channel = AudioChannels(int(self.channel))
@attr.s @dataclass
class SpeechResult: class SpeechResult:
"""Result of audio Speech.""" """Result of audio Speech."""
text: str | None = attr.ib() text: str | None
result: SpeechResultState = attr.ib() result: SpeechResultState
class Provider(ABC): class Provider(ABC):
@ -171,30 +186,6 @@ class SpeechToTextView(HomeAssistantView):
"""Initialize a tts view.""" """Initialize a tts view."""
self.providers = providers self.providers = providers
@staticmethod
def _metadata_from_header(request: web.Request) -> SpeechMetadata | None:
"""Extract metadata from header.
X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de
"""
try:
data = request.headers[istr("X-Speech-Content")].split(";")
except KeyError:
_LOGGER.warning("Missing X-Speech-Content")
return None
# Convert Header data
args: dict[str, Any] = {}
for value in data:
value = value.strip()
args[value.partition("=")[0]] = value.partition("=")[2]
try:
return SpeechMetadata(**args)
except TypeError as err:
_LOGGER.warning("Wrong format of X-Speech-Content: %s", err)
return None
async def post(self, request: web.Request, provider: str) -> web.Response: async def post(self, request: web.Request, provider: str) -> web.Response:
"""Convert Speech (audio) to text.""" """Convert Speech (audio) to text."""
if provider not in self.providers: if provider not in self.providers:
@ -202,9 +193,10 @@ class SpeechToTextView(HomeAssistantView):
stt_provider: Provider = self.providers[provider] stt_provider: Provider = self.providers[provider]
# Get metadata # Get metadata
metadata = self._metadata_from_header(request) try:
if not metadata: metadata = metadata_from_header(request)
raise HTTPBadRequest() except ValueError as err:
raise HTTPBadRequest(text=str(err)) from err
# Check format # Check format
if not stt_provider.check_metadata(metadata): if not stt_provider.check_metadata(metadata):
@ -216,7 +208,7 @@ class SpeechToTextView(HomeAssistantView):
) )
# Return result # Return result
return self.json(attr.asdict(result)) return self.json(asdict(result))
async def get(self, request: web.Request, provider: str) -> web.Response: async def get(self, request: web.Request, provider: str) -> web.Response:
"""Return provider specific audio information.""" """Return provider specific audio information."""
@ -234,3 +226,47 @@ class SpeechToTextView(HomeAssistantView):
"channels": stt_provider.supported_channels, "channels": stt_provider.supported_channels,
} }
) )
def metadata_from_header(request: web.Request) -> SpeechMetadata:
"""Extract STT metadata from header.
X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de
"""
try:
data = request.headers[istr("X-Speech-Content")].split(";")
except KeyError as err:
raise ValueError("Missing X-Speech-Content header") from err
fields = (
"language",
"format",
"codec",
"bit_rate",
"sample_rate",
"channel",
)
# Convert Header data
args: dict[str, Any] = {}
for entry in data:
key, _, value = entry.strip().partition("=")
if key not in fields:
raise ValueError(f"Invalid field {key}")
args[key] = value
for field in fields:
if field not in args:
raise ValueError(f"Missing {field} in X-Speech-Content header")
try:
return SpeechMetadata(
language=args["language"],
format=args["format"],
codec=args["codec"],
bit_rate=args["bit_rate"],
sample_rate=args["sample_rate"],
channel=args["channel"],
)
except TypeError as err:
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err

View File

@ -1,30 +1,155 @@
"""Test STT component setup.""" """Test STT component setup."""
from asyncio import StreamReader
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import AsyncMock, Mock
from homeassistant.components import stt import pytest
from homeassistant.components.stt import (
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechMetadata,
SpeechResult,
SpeechResultState,
async_get_provider,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import mock_platform
async def test_setup_comp(hass):
"""Set up demo component."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
async def test_demo_settings_not_exists(hass, hass_client): class TestProvider(Provider):
"""Test retrieve settings from demo provider.""" """Test provider."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
fail_process_audio = False
def __init__(self) -> None:
"""Init test provider."""
self.calls = []
@property
def supported_languages(self):
"""Return a list of supported languages."""
return ["en"]
@property
def supported_formats(self) -> list[AudioFormats]:
"""Return a list of supported formats."""
return [AudioFormats.WAV, AudioFormats.OGG]
@property
def supported_codecs(self) -> list[AudioCodecs]:
"""Return a list of supported codecs."""
return [AudioCodecs.PCM, AudioCodecs.OPUS]
@property
def supported_bit_rates(self) -> list[AudioBitRates]:
"""Return a list of supported bitrates."""
return [AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[AudioSampleRates]:
"""Return a list of supported samplerates."""
return [AudioSampleRates.SAMPLERATE_16000]
@property
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
return [AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: StreamReader
) -> SpeechResult:
"""Process an audio stream."""
self.calls.append((metadata, stream))
if self.fail_process_audio:
return SpeechResult(None, SpeechResultState.ERROR)
return SpeechResult("test", SpeechResultState.SUCCESS)
@pytest.fixture
def test_provider():
"""Test provider fixture."""
return TestProvider()
@pytest.fixture(autouse=True)
async def mock_setup(hass, test_provider):
"""Set up a test provider."""
mock_platform(
hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=test_provider))
)
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
async def test_get_provider_info(hass, hass_client):
"""Test engine that doesn't exist."""
client = await hass_client() client = await hass_client()
response = await client.get("/api/stt/test")
assert response.status == HTTPStatus.OK
assert await response.json() == {
"languages": ["en"],
"formats": ["wav", "ogg"],
"codecs": ["pcm", "opus"],
"sample_rates": [16000],
"bit_rates": [16],
"channels": [1],
}
response = await client.get("/api/stt/beer")
async def test_get_non_existing_provider_info(hass, hass_client):
"""Test streaming to engine that doesn't exist."""
client = await hass_client()
response = await client.get("/api/stt/not_exist")
assert response.status == HTTPStatus.NOT_FOUND assert response.status == HTTPStatus.NOT_FOUND
async def test_demo_speech_not_exists(hass, hass_client): async def test_stream_audio(hass, hass_client):
"""Test retrieve settings from demo provider.""" """Test streaming audio and getting response."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
client = await hass_client() client = await hass_client()
response = await client.post(
"/api/stt/test",
headers={
"X-Speech-Content": "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=en"
},
)
assert response.status == HTTPStatus.OK
assert await response.json() == {"text": "test", "result": "success"}
response = await client.post("/api/stt/beer", data=b"test")
assert response.status == HTTPStatus.NOT_FOUND @pytest.mark.parametrize(
"header,status,error",
(
(None, 400, "Missing X-Speech-Content header"),
(
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100; language=en",
400,
"100 is not a valid AudioChannels",
),
(
"format=wav; codec=pcm; sample_rate=16000",
400,
"Missing language in X-Speech-Content header",
),
),
)
async def test_metadata_errors(hass, hass_client, header, status, error):
"""Test metadata errors."""
client = await hass_client()
headers = {}
if header:
headers["X-Speech-Content"] = header
response = await client.post("/api/stt/test", headers=headers)
assert response.status == status
assert await response.text() == error
async def test_get_provider(hass, test_provider):
"""Test we can get STT providers."""
assert test_provider == async_get_provider(hass, "test")