1
mirror of https://github.com/home-assistant/core synced 2024-08-28 03:36:46 +02:00
ha-core/homeassistant/components/stt/legacy.py
Franck Nijhof 2d0325a471
Mark stt entity component as strictly typed (#106723)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
2024-01-02 17:07:47 +01:00

154 lines
4.5 KiB
Python

"""Handle legacy speech-to-text platforms."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Coroutine
import logging
from typing import Any
from homeassistant.config import config_per_platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import discovery
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_prepare_setup_platform
from .const import (
DATA_PROVIDERS,
DOMAIN,
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
)
from .models import SpeechMetadata, SpeechResult
_LOGGER = logging.getLogger(__name__)
@callback
def async_default_provider(hass: HomeAssistant) -> str | None:
    """Return the domain of the default provider.

    The default is the first key of the provider registry stored in
    ``hass.data``; ``None`` is returned when the registry is empty.
    """
    for provider_domain in hass.data[DATA_PROVIDERS]:
        # The first key yielded by the registry is the default.
        return provider_domain
    return None
@callback
def async_get_provider(
    hass: HomeAssistant, domain: str | None = None
) -> Provider | None:
    """Look up a provider in the registry stored in ``hass.data``.

    When ``domain`` is truthy, the provider registered under that key is
    returned, or ``None`` if there is no such key. Otherwise the default
    provider is returned, or ``None`` when no provider is registered.
    """
    registry: dict[str, Provider] = hass.data[DATA_PROVIDERS]

    if not domain:
        # No explicit domain requested: fall back to the default provider.
        default_domain = async_default_provider(hass)
        if default_domain is None:
            return None
        return registry[default_domain]

    return registry.get(domain)
@callback
def async_setup_legacy(
    hass: HomeAssistant, config: ConfigType
) -> list[Coroutine[Any, Any, None]]:
    """Set up legacy speech-to-text providers.

    Replaces the provider registry in ``hass.data`` with an empty dict,
    registers a discovery listener for this domain, and returns one setup
    coroutine per configured platform. The coroutines are returned
    un-awaited.
    """
    registry = hass.data[DATA_PROVIDERS] = {}

    async def async_setup_platform(
        p_type: str,
        p_config: ConfigType | None = None,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> None:
        """Prepare one STT platform and register the engine it returns."""
        engine_config = {} if p_config is None else p_config

        platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
        if platform is None:
            _LOGGER.error("Unknown speech-to-text platform specified")
            return

        # Any error raised while creating or registering the engine is
        # logged with its traceback and not propagated.
        try:
            engine = await platform.async_get_engine(
                hass, engine_config, discovery_info
            )
            engine.name = p_type
            engine.hass = hass
            registry[engine.name] = engine
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Error setting up platform: %s", p_type)

    async def async_platform_discovered(
        platform: str, info: DiscoveryInfoType | None
    ) -> None:
        """Set up a platform reported by discovery, with no platform config."""
        await async_setup_platform(platform, discovery_info=info)

    # Subscribe before building the coroutine list for configured platforms.
    discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)

    return [
        async_setup_platform(platform_type, platform_config)
        for platform_type, platform_config in config_per_platform(config, DOMAIN)
        if platform_type
    ]
class Provider(ABC):
    """Abstract base class for a legacy speech-to-text provider."""

    # Both default to None on the class; async_setup_legacy assigns them on
    # the provider instance once its platform has produced the engine.
    hass: HomeAssistant | None = None
    name: str | None = None

    @property
    @abstractmethod
    def supported_languages(self) -> list[str]:
        """Return the languages this provider supports."""

    @property
    @abstractmethod
    def supported_formats(self) -> list[AudioFormats]:
        """Return the audio formats this provider supports."""

    @property
    @abstractmethod
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return the audio codecs this provider supports."""

    @property
    @abstractmethod
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return the audio bit rates this provider supports."""

    @property
    @abstractmethod
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return the audio sample rates this provider supports."""

    @property
    @abstractmethod
    def supported_channels(self) -> list[AudioChannels]:
        """Return the audio channel layouts this provider supports."""

    @abstractmethod
    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Process an audio stream with the STT service.

        Only streamed content is allowed: the audio arrives as an
        asynchronous iterable of byte chunks.
        """

    @callback
    def check_metadata(self, metadata: SpeechMetadata) -> bool:
        """Return True if this provider supports every field of the metadata.

        Language, format, codec, bit rate, sample rate and channel are
        checked in that order against the matching ``supported_*`` property,
        stopping at the first unsupported value.
        """
        return (
            metadata.language in self.supported_languages
            and metadata.format in self.supported_formats
            and metadata.codec in self.supported_codecs
            and metadata.bit_rate in self.supported_bit_rates
            and metadata.sample_rate in self.supported_sample_rates
            and metadata.channel in self.supported_channels
        )