1
mirror of https://github.com/home-assistant/core synced 2024-08-02 23:40:32 +02:00
ha-core/homeassistant/components/tts/__init__.py
epenet 8bcf495caf
Import tts (#64212)
Co-authored-by: epenet <epenet@users.noreply.github.com>
2022-01-17 22:01:28 -08:00

684 lines
22 KiB
Python

"""Provide functionality for TTS."""
from __future__ import annotations
import asyncio
import functools as ft
import hashlib
from http import HTTPStatus
import io
import logging
import mimetypes
import os
import re
from typing import TYPE_CHECKING, Optional, cast
from aiohttp import web
import mutagen
from mutagen.id3 import ID3, TextFrame as ID3Text
import voluptuous as vol
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player.const import (
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
MEDIA_TYPE_MUSIC,
SERVICE_PLAY_MEDIA,
)
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_DESCRIPTION,
CONF_NAME,
CONF_PLATFORM,
PLATFORM_FORMAT,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import get_url
from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.loader import async_get_integration
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.yaml import load_yaml
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)

# What a provider hands back: (file extension, audio bytes). Either item may
# be None, which the speech manager treats as "no audio produced".
TtsAudioType = tuple[Optional[str], Optional[bytes]]

# Keys read from the say service data and/or the tts_get_url JSON body.
ATTR_CACHE = "cache"
ATTR_LANGUAGE = "language"
ATTR_MESSAGE = "message"
ATTR_OPTIONS = "options"
ATTR_PLATFORM = "platform"

# Key under which async_setup stores the configured base URL in hass.data.
BASE_URL_KEY = "tts_base_url"

# Keys of a platform configuration entry (CONF_FIELDS is used when building
# the service description instead).
CONF_BASE_URL = "base_url"
CONF_CACHE = "cache"
CONF_CACHE_DIR = "cache_dir"
CONF_LANG = "language"
CONF_SERVICE_NAME = "service_name"
CONF_TIME_MEMORY = "time_memory"
CONF_FIELDS = "fields"

# Defaults applied when the configuration does not provide a value.
DEFAULT_CACHE = True
# A relative cache dir is resolved through hass.config.path().
DEFAULT_CACHE_DIR = "tts"
# Delay handed to loop.call_later() before an in-memory entry is dropped.
DEFAULT_TIME_MEMORY = 300

DOMAIN = "tts"

# Keys of a single record in SpeechManager.mem_cache.
MEM_CACHE_FILENAME = "filename"
MEM_CACHE_VOICE = "voice"

SERVICE_CLEAR_CACHE = "clear_cache"
SERVICE_SAY = "say"

# Cache file name: 40 lowercase hex characters, two underscore-free fields, a
# lowercase/underscore field and a 3-4 character extension. The four captured
# groups line up with the four slots of KEY_PATTERN.
_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}")
# Cache key layout: <message sha1>_<language>_<options hash>_<engine>.
KEY_PATTERN = "{0}_{1}_{2}_{3}"
def _deprecated_platform(value):
"""Validate if platform is deprecated."""
if value == "google":
raise vol.Invalid(
"google tts service has been renamed to google_translate,"
" please update your configuration."
)
return value
# Configuration schema for a single tts platform entry.
PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
    {
        # The retired "google" name is rejected by _deprecated_platform.
        vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform),
        vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean,
        vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string,
        # Coerced to int and limited to the range 60..57600.
        vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All(
            vol.Coerce(int), vol.Range(min=60, max=57600)
        ),
        vol.Optional(CONF_BASE_URL): cv.string,
        # Overrides the default "<platform>_say" service name.
        vol.Optional(CONF_SERVICE_NAME): cv.string,
    }
)
# Same keys, built on top of cv.PLATFORM_SCHEMA_BASE.
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE.extend(PLATFORM_SCHEMA.schema)

# Schema applied to the data of every per-platform say service.
SCHEMA_SERVICE_SAY = vol.Schema(
    {
        vol.Required(ATTR_MESSAGE): cv.string,
        vol.Optional(ATTR_CACHE): cv.boolean,
        vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
        vol.Optional(ATTR_LANGUAGE): cv.string,
        vol.Optional(ATTR_OPTIONS): dict,
    }
)

# The clear_cache service takes no fields.
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the TTS integration.

    Initialises the speech cache, registers both HTTP views, sets up every
    configured platform (and any platform discovered later) together with its
    own say service, and finally registers the clear_cache service.

    Returns False when the cache cannot be initialised, True otherwise.
    """
    manager = SpeechManager(hass)

    try:
        # Cache settings come from the first configured platform entry only.
        platform_confs = config.get(DOMAIN, [])
        conf = platform_confs[0] if platform_confs else {}
        cache_enabled = conf.get(CONF_CACHE, DEFAULT_CACHE)
        cache_folder = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
        memory_ttl = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
        configured_base_url = conf.get(CONF_BASE_URL)
        hass.data[BASE_URL_KEY] = configured_base_url

        await manager.async_init_cache(
            cache_enabled, cache_folder, memory_ttl, configured_base_url
        )
    except (HomeAssistantError, KeyError):
        _LOGGER.exception("Error on cache init")
        return False

    hass.http.register_view(TextToSpeechView(manager))
    hass.http.register_view(TextToSpeechUrlView(manager))

    # Load service descriptions from tts/services.yaml
    integration = await async_get_integration(hass, DOMAIN)
    services_path = str(integration.file_path / "services.yaml")
    services_dict = cast(
        dict, await hass.async_add_executor_job(load_yaml, services_path)
    )

    async def async_setup_platform(
        p_type: str,
        p_config: ConfigType | None = None,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> None:
        """Set up one TTS platform and register its say service."""
        if p_config is None:
            p_config = {}

        platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
        if platform is None:
            return

        try:
            # Prefer the coroutine factory; otherwise run the blocking one in
            # the executor.
            if hasattr(platform, "async_get_engine"):
                provider = await platform.async_get_engine(
                    hass, p_config, discovery_info
                )
            else:
                provider = await hass.async_add_executor_job(
                    platform.get_engine, hass, p_config, discovery_info
                )

            if provider is None:
                _LOGGER.error("Error setting up platform %s", p_type)
                return

            manager.async_register_engine(p_type, provider, p_config)
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Error setting up platform: %s", p_type)
            return

        async def async_say_handle(service: ServiceCall) -> None:
            """Generate speech for the message and play it on the targets."""
            call_data = service.data
            entity_ids = call_data[ATTR_ENTITY_ID]
            message = call_data[ATTR_MESSAGE]
            cache = call_data.get(ATTR_CACHE)
            language = call_data.get(ATTR_LANGUAGE)
            options = call_data.get(ATTR_OPTIONS)

            try:
                url_path = await manager.async_get_url_path(
                    p_type, message, cache=cache, language=language, options=options
                )
            except HomeAssistantError as err:
                _LOGGER.error("Error on init TTS: %s", err)
                return

            base = manager.base_url or get_url(hass)

            await hass.services.async_call(
                DOMAIN_MP,
                SERVICE_PLAY_MEDIA,
                {
                    ATTR_MEDIA_CONTENT_ID: base + url_path,
                    ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_MUSIC,
                    ATTR_ENTITY_ID: entity_ids,
                },
                blocking=True,
                context=service.context,
            )

        service_name = p_config.get(CONF_SERVICE_NAME, f"{p_type}_{SERVICE_SAY}")
        hass.services.async_register(
            DOMAIN, service_name, async_say_handle, schema=SCHEMA_SERVICE_SAY
        )

        # Register the service description
        async_set_service_schema(
            hass,
            DOMAIN,
            service_name,
            {
                CONF_NAME: f"Say an TTS message with {p_type}",
                CONF_DESCRIPTION: f"Say something using text-to-speech on a media player with {p_type}.",
                CONF_FIELDS: services_dict[SERVICE_SAY][CONF_FIELDS],
            },
        )

    # Set up all configured platforms concurrently and wait for them.
    pending = []
    for platform_name, platform_conf in config_per_platform(config, DOMAIN):
        if platform_name is None:
            continue
        pending.append(
            asyncio.create_task(async_setup_platform(platform_name, platform_conf))
        )

    if pending:
        await asyncio.wait(pending)

    async def async_platform_discovered(platform, info):
        """Handle for discovered platform."""
        await async_setup_platform(platform, discovery_info=info)

    discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)

    async def async_clear_cache_handle(service: ServiceCall) -> None:
        """Handle clear cache service call."""
        await manager.async_clear_cache()

    hass.services.async_register(
        DOMAIN,
        SERVICE_CLEAR_CACHE,
        async_clear_cache_handle,
        schema=SCHEMA_SERVICE_CLEAR_CACHE,
    )

    return True
def _hash_options(options: dict) -> str:
"""Hashes an options dictionary."""
opts_hash = hashlib.blake2s(digest_size=5)
for key, value in sorted(options.items()):
opts_hash.update(str(key).encode())
opts_hash.update(str(value).encode())
return opts_hash.hexdigest()
class SpeechManager:
    """Generate, cache and serve TTS audio for the registered providers.

    Audio lives in two layers. ``mem_cache`` maps a cache key to a record
    holding the file name and the raw audio bytes; records are dropped again
    ``time_memory`` seconds after they were stored. ``file_cache`` maps a
    cache key to the name of a file inside ``cache_dir``.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize a speech store with default settings and empty caches."""
        self.hass = hass
        self.providers: dict[str, Provider] = {}

        self.use_cache = DEFAULT_CACHE
        self.cache_dir = DEFAULT_CACHE_DIR
        self.time_memory = DEFAULT_TIME_MEMORY
        self.base_url: str | None = None
        self.file_cache: dict[str, str] = {}
        self.mem_cache: dict[str, dict[str, str | bytes]] = {}

    async def async_init_cache(
        self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
    ) -> None:
        """Store the cache settings, prepare the folder and index its files.

        Raises:
            HomeAssistantError: if the cache folder cannot be created or its
                content cannot be listed.
        """
        self.use_cache = use_cache
        self.time_memory = time_memory
        self.base_url = base_url

        try:
            self.cache_dir = await self.hass.async_add_executor_job(
                _init_tts_cache_dir, self.hass, cache_dir
            )
        except OSError as err:
            raise HomeAssistantError(f"Can't init cache dir {err}") from err

        try:
            cache_files = await self.hass.async_add_executor_job(
                _get_cache_files, self.cache_dir
            )
        except OSError as err:
            raise HomeAssistantError(f"Can't read cache dir {err}") from err

        if cache_files:
            self.file_cache.update(cache_files)

    async def async_clear_cache(self) -> None:
        """Drop the memory cache and delete every file known to the file cache."""
        self.mem_cache = {}

        def remove_files():
            """Remove files from filesystem."""
            for filename in self.file_cache.values():
                try:
                    os.remove(os.path.join(self.cache_dir, filename))
                except OSError as err:
                    # Best effort: keep going with the remaining files.
                    _LOGGER.warning("Can't remove cache file '%s': %s", filename, err)

        await self.hass.async_add_executor_job(remove_files)
        self.file_cache = {}

    @callback
    def async_register_engine(
        self, engine: str, provider: Provider, config: ConfigType
    ) -> None:
        """Register a TTS provider under the name ``engine``.

        The provider gets a reference to hass, and ``engine`` as its name when
        it does not define one itself.
        """
        provider.hass = self.hass
        if provider.name is None:
            provider.name = engine
        self.providers[engine] = provider

        self.hass.config.components.add(
            PLATFORM_FORMAT.format(domain=engine, platform=DOMAIN)
        )

    async def async_get_url_path(
        self,
        engine: str,
        message: str,
        cache: bool | None = None,
        language: str | None = None,
        options: dict | None = None,
    ) -> str:
        """Return the URL path under which the spoken message can be fetched.

        ``cache`` overrides the manager-wide ``use_cache`` setting when it is
        not None. ``language`` falls back to the provider default, and
        ``options`` are merged over the provider's default options.

        This method is a coroutine.

        Raises:
            HomeAssistantError: if the engine is not registered, the language
                or an option is not supported, or the provider returns no
                audio.
        """
        # Report an unknown engine the same way as every other failure here;
        # callers only handle HomeAssistantError.
        if (provider := self.providers.get(engine)) is None:
            raise HomeAssistantError(f"Provider {engine} not found")

        msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
        use_cache = cache if cache is not None else self.use_cache

        # Languages
        language = language or provider.default_language
        # supported_languages may be None (the base Provider returns None);
        # treat that as "nothing supported" instead of failing on "in None".
        if language is None or language not in (provider.supported_languages or []):
            raise HomeAssistantError(f"Not supported language {language}")

        # Options
        if provider.default_options and options:
            merged_options = provider.default_options.copy()
            merged_options.update(options)
            options = merged_options
        options = options or provider.default_options

        if options is not None:
            supported_options = provider.supported_options or []
            invalid_opts = [
                opt_name for opt_name in options if opt_name not in supported_options
            ]
            if invalid_opts:
                raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
            options_key = _hash_options(options)
        else:
            options_key = "-"

        # Underscores separate the key fields, so they are replaced inside
        # the language code.
        key = KEY_PATTERN.format(
            msg_hash, language.replace("_", "-"), options_key, engine
        ).lower()

        # Is speech already in memory
        if key in self.mem_cache:
            filename = cast(str, self.mem_cache[key][MEM_CACHE_FILENAME])
        # Is file store in file cache
        elif use_cache and key in self.file_cache:
            filename = self.file_cache[key]
            # Load the file into memory in the background.
            self.hass.async_create_task(self.async_file_to_mem(key))
        # Load speech from provider into memory
        else:
            filename = await self.async_get_tts_audio(
                engine, key, message, use_cache, language, options
            )

        return f"/api/tts_proxy/{filename}"

    async def async_get_tts_audio(
        self,
        engine: str,
        key: str,
        message: str,
        cache: bool,
        language: str,
        options: dict | None,
    ) -> str:
        """Fetch audio from the provider, tag it and store it in the caches.

        The audio always goes into the memory cache; it is additionally
        written to disk in the background when ``cache`` is true. Returns the
        file name the audio is stored under.

        This method is a coroutine.

        Raises:
            HomeAssistantError: if the provider returns no audio or the
                resulting file name does not match the cache file pattern.
        """
        provider = self.providers[engine]
        extension, data = await provider.async_get_tts_audio(message, language, options)

        if data is None or extension is None:
            raise HomeAssistantError(f"No TTS from {engine} for '{message}'")

        # Create file infos
        filename = f"{key}.{extension}".lower()

        # Validate filename
        if not _RE_VOICE_FILE.match(filename):
            raise HomeAssistantError(
                f"TTS filename '{filename}' from {engine} is invalid!"
            )

        # Save to memory
        data = self.write_tags(filename, data, provider, message, language, options)
        self._async_store_to_memcache(key, filename, data)

        if cache:
            self.hass.async_create_task(self.async_save_tts_audio(key, filename, data))

        return filename

    async def async_save_tts_audio(self, key: str, filename: str, data: bytes) -> None:
        """Write voice data to the cache folder and add it to ``file_cache``.

        A failed write is logged and otherwise ignored.

        This method is a coroutine.
        """
        voice_file = os.path.join(self.cache_dir, filename)

        def save_speech() -> None:
            """Store speech to filesystem."""
            with open(voice_file, "wb") as speech:
                speech.write(data)

        try:
            await self.hass.async_add_executor_job(save_speech)
            self.file_cache[key] = filename
        except OSError as err:
            _LOGGER.error("Can't write %s: %s", filename, err)

    async def async_file_to_mem(self, key: str) -> None:
        """Load voice from file cache into memory.

        This method is a coroutine.

        Raises:
            HomeAssistantError: if the key is not in the file cache or the
                file cannot be read; in the latter case the key is also
                removed from the file cache.
        """
        if not (filename := self.file_cache.get(key)):
            raise HomeAssistantError(f"Key {key} not in file cache!")

        voice_file = os.path.join(self.cache_dir, filename)

        def load_speech() -> bytes:
            """Load a speech from filesystem."""
            with open(voice_file, "rb") as speech:
                return speech.read()

        try:
            data = await self.hass.async_add_executor_job(load_speech)
        except OSError as err:
            # The cache may have been cleared while the read was running, so
            # do not let a missing key hide the actual read error.
            self.file_cache.pop(key, None)
            raise HomeAssistantError(f"Can't read {voice_file}") from err

        self._async_store_to_memcache(key, filename, data)

    @callback
    def _async_store_to_memcache(self, key: str, filename: str, data: bytes) -> None:
        """Store data to memcache and set timer to remove it."""
        self.mem_cache[key] = {MEM_CACHE_FILENAME: filename, MEM_CACHE_VOICE: data}

        @callback
        def async_remove_from_mem() -> None:
            """Cleanup memcache."""
            self.mem_cache.pop(key, None)

        self.hass.loop.call_later(self.time_memory, async_remove_from_mem)

    async def async_read_tts(self, filename: str) -> tuple[str | None, bytes]:
        """Return the guessed content type and the audio bytes of a voice file.

        The audio is loaded from the file cache into memory first when needed.

        This method is a coroutine.

        Raises:
            HomeAssistantError: if the file name does not match the cache file
                pattern or the audio is in neither cache.
        """
        if not (record := _RE_VOICE_FILE.match(filename.lower())):
            raise HomeAssistantError("Wrong tts file format!")

        key = KEY_PATTERN.format(
            record.group(1), record.group(2), record.group(3), record.group(4)
        )

        if key not in self.mem_cache:
            if key not in self.file_cache:
                raise HomeAssistantError(f"{key} not in cache!")
            await self.async_file_to_mem(key)

        content, _ = mimetypes.guess_type(filename)
        return content, cast(bytes, self.mem_cache[key][MEM_CACHE_VOICE])

    @staticmethod
    def write_tags(
        filename: str,
        data: bytes,
        provider: Provider,
        message: str,
        language: str,
        options: dict | None,
    ) -> bytes:
        """Write artist, album and title tags into the audio data.

        The album is the provider name, the title is the message and the
        artist is the ``voice`` option when present, otherwise the language.
        A tagging error is logged and the buffer content is returned as it is
        at that point.

        Async friendly.
        """
        data_bytes = io.BytesIO(data)
        # The buffer is given the file name before it is handed to mutagen.
        data_bytes.name = filename
        data_bytes.seek(0)

        album = provider.name
        artist = language

        if options is not None and (voice := options.get("voice")) is not None:
            artist = voice

        try:
            tts_file = mutagen.File(data_bytes)
            if tts_file is not None:
                if not tts_file.tags:
                    tts_file.add_tags()
                if isinstance(tts_file.tags, ID3):
                    tts_file["artist"] = ID3Text(encoding=3, text=artist)
                    tts_file["album"] = ID3Text(encoding=3, text=album)
                    tts_file["title"] = ID3Text(encoding=3, text=message)
                else:
                    tts_file["artist"] = artist
                    tts_file["album"] = album
                    tts_file["title"] = message
                tts_file.save(data_bytes)
        except mutagen.MutagenError as err:
            _LOGGER.error("ID3 tag error: %s", err)

        return data_bytes.getvalue()
class Provider:
    """Base class for a single TTS provider.

    Every property returns None here; subclasses override what they support
    and implement either ``get_tts_audio`` or ``async_get_tts_audio``.
    """

    # Both are assigned when the provider is registered with the speech
    # manager (the name only if the provider did not set one).
    hass: HomeAssistant | None = None
    name: str | None = None

    @property
    def default_language(self):
        """Return the language used when the caller does not pass one."""
        return None

    @property
    def supported_languages(self):
        """Return the languages this provider can speak."""
        return None

    @property
    def supported_options(self):
        """Return the names of the options this provider accepts, e.g. voice or emotion."""
        return None

    @property
    def default_options(self):
        """Return a dict with the options applied by default."""
        return None

    def get_tts_audio(
        self, message: str, language: str, options: dict | None = None
    ) -> TtsAudioType:
        """Load the TTS audio from the provider (blocking variant)."""
        raise NotImplementedError

    async def async_get_tts_audio(
        self, message: str, language: str, options: dict | None = None
    ) -> TtsAudioType:
        """Load the TTS audio from the provider.

        Runs ``get_tts_audio`` in the executor and returns its result, a tuple
        of file extension and data as bytes.
        """
        if TYPE_CHECKING:
            assert self.hass

        blocking_call = ft.partial(
            self.get_tts_audio, message, language, options=options
        )
        return await self.hass.async_add_executor_job(blocking_call)
def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str:
"""Init cache folder."""
if not os.path.isabs(cache_dir):
cache_dir = hass.config.path(cache_dir)
if not os.path.isdir(cache_dir):
_LOGGER.info("Create cache dir %s", cache_dir)
os.mkdir(cache_dir)
return cache_dir
def _get_cache_files(cache_dir: str) -> dict[str, str]:
"""Return a dict of given engine files."""
cache = {}
folder_data = os.listdir(cache_dir)
for file_data in folder_data:
if record := _RE_VOICE_FILE.match(file_data):
key = KEY_PATTERN.format(
record.group(1), record.group(2), record.group(3), record.group(4)
)
cache[key.lower()] = file_data.lower()
return cache
class TextToSpeechUrlView(HomeAssistantView):
    """TTS view to get a url to a generated speech file."""

    requires_auth = True
    url = "/api/tts_get_url"
    name = "api:tts:geturl"

    def __init__(self, tts: SpeechManager) -> None:
        """Initialize a tts view."""
        self.tts = tts

    async def post(self, request: web.Request) -> web.Response:
        """Generate speech and provide url.

        Expects a JSON object with a non-empty platform and message, plus
        optional cache, language and options entries. On success the response
        holds the absolute ``url`` and the relative ``path`` of the audio.
        A malformed request, or a failure to generate the speech, is answered
        with HTTP 400.
        """
        try:
            data = await request.json()
        except ValueError:
            return self.json_message("Invalid JSON specified", HTTPStatus.BAD_REQUEST)

        # Both entries are mandatory. A body that is valid JSON but not an
        # object (for example a list) cannot carry them either.
        if (
            not isinstance(data, dict)
            or not data.get(ATTR_PLATFORM)
            or not data.get(ATTR_MESSAGE)
        ):
            return self.json_message(
                "Must specify platform and message", HTTPStatus.BAD_REQUEST
            )

        p_type = data[ATTR_PLATFORM]
        message = data[ATTR_MESSAGE]
        cache = data.get(ATTR_CACHE)
        language = data.get(ATTR_LANGUAGE)
        options = data.get(ATTR_OPTIONS)

        try:
            path = await self.tts.async_get_url_path(
                p_type, message, cache=cache, language=language, options=options
            )
        except HomeAssistantError as err:
            _LOGGER.error("Error on init tts: %s", err)
            # An exception object is not JSON serializable; send its text.
            return self.json({"error": str(err)}, HTTPStatus.BAD_REQUEST)

        base = self.tts.base_url or get_url(self.tts.hass)
        url = base + path

        return self.json({"url": url, "path": path})
class TextToSpeechView(HomeAssistantView):
    """View that serves generated speech audio by file name."""

    requires_auth = False
    url = "/api/tts_proxy/{filename}"
    name = "api:tts_speech"

    def __init__(self, tts: SpeechManager) -> None:
        """Keep a reference to the speech manager the audio is read from."""
        self.tts = tts

    async def get(self, request: web.Request, filename: str) -> web.Response:
        """Return the audio stored under ``filename``.

        Answers with HTTP 404 when the speech manager cannot provide it.
        """
        try:
            mime_type, audio = await self.tts.async_read_tts(filename)
        except HomeAssistantError as err:
            _LOGGER.error("Error on load tts: %s", err)
            return web.Response(status=HTTPStatus.NOT_FOUND)
        else:
            return web.Response(body=audio, content_type=mime_type)
def get_base_url(hass: HomeAssistant) -> str:
    """Return the base URL for TTS links.

    The value stored under ``BASE_URL_KEY`` in ``hass.data`` wins when it is
    truthy; otherwise the result of ``get_url(hass)`` is used.
    """
    configured = hass.data[BASE_URL_KEY]
    if configured:
        return configured
    return get_url(hass)