Minor adjustment of tts typing (#93450)

This commit is contained in:
Erik Montnemery 2023-05-24 21:02:55 +02:00 committed by GitHub
parent 68379dd55a
commit 30d9d7d905
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 49 additions and 62 deletions

View File

@ -167,12 +167,9 @@ class AmazonPollyProvider(Provider):
self,
message: str,
language: str,
options: dict[str, Any] | None = None,
options: dict[str, Any],
) -> TtsAudioType:
"""Request TTS file from Polly."""
if options is None or language is None:
_LOGGER.debug("language and/or options were missing")
return None, None
voice_id = options.get(CONF_VOICE, self.default_voice)
voice_in_dict = self.all_voices[voice_id]
if language != voice_in_dict.get("LanguageCode"):

View File

@ -104,7 +104,7 @@ class BaiduTTSProvider(Provider):
"""Return a list of supported options."""
return SUPPORTED_OPTIONS
def get_tts_audio(self, message, language, options=None):
def get_tts_audio(self, message, language, options):
"""Load TTS from BaiduTTS."""
aip_speech = AipSpeech(
@ -113,14 +113,11 @@ class BaiduTTSProvider(Provider):
self._app_data["secretkey"],
)
if options is None:
result = aip_speech.synthesis(message, language, 1, self._speech_conf_data)
else:
speech_data = self._speech_conf_data.copy()
for key, value in options.items():
speech_data[_OPTIONS[key]] = value
speech_data = self._speech_conf_data.copy()
for key, value in options.items():
speech_data[_OPTIONS[key]] = value
result = aip_speech.synthesis(message, language, 1, speech_data)
result = aip_speech.synthesis(message, language, 1, speech_data)
if isinstance(result, dict):
_LOGGER.error(

View File

@ -134,12 +134,11 @@ class CloudProvider(Provider):
}
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load TTS from NabuCasa Cloud."""
# Process TTS
try:
assert options is not None
data = await self.cloud.voice.process_tts(
text=message,
language=language,

View File

@ -57,7 +57,7 @@ class DemoProvider(Provider):
return ["voice", "age"]
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load TTS from demo."""
filename = os.path.join(os.path.dirname(__file__), "tts.mp3")

View File

@ -241,7 +241,7 @@ class GoogleCloudTTSProvider(Provider):
CONF_TEXT_TYPE: self._text_type,
}
async def async_get_tts_audio(self, message, language, options=None):
async def async_get_tts_audio(self, message, language, options):
"""Load TTS from google."""
options_schema = vol.Schema(
{

View File

@ -59,13 +59,13 @@ class GoogleProvider(Provider):
"""Return a list of supported options."""
return SUPPORT_OPTIONS
def get_tts_audio(self, message, language, options=None):
def get_tts_audio(self, message, language, options):
"""Load TTS from google."""
tld = self._tld
if language in MAP_LANG_TLD:
tld = MAP_LANG_TLD[language].tld
language = MAP_LANG_TLD[language].lang
if options is not None and "tld" in options:
if "tld" in options:
tld = options["tld"]
tts = gTTS(text=message, lang=language, tld=tld)
mp3_data = BytesIO()

View File

@ -80,7 +80,7 @@ class MaryTTSProvider(Provider):
"""Return a list of supported options."""
return SUPPORT_OPTIONS
def get_tts_audio(self, message, language, options=None):
def get_tts_audio(self, message, language, options):
"""Load TTS from MaryTTS."""
effects = options[CONF_EFFECT]

View File

@ -176,7 +176,7 @@ class MicrosoftProvider(Provider):
"""Return a dict include default options."""
return {CONF_GENDER: self._gender, CONF_TYPE: self._type}
def get_tts_audio(self, message, language, options=None):
def get_tts_audio(self, message, language, options):
"""Load TTS from Microsoft."""
if language is None:
language = self._lang

View File

@ -46,7 +46,7 @@ class PicoProvider(Provider):
"""Return list of supported languages."""
return SUPPORT_LANGUAGES
def get_tts_audio(self, message, language, options=None):
def get_tts_audio(self, message, language, options):
"""Load TTS using pico2wave."""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpf:
fname = tmpf.name

View File

@ -364,7 +364,7 @@ class TextToSpeechEntity(RestoreEntity):
@final
async def internal_async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Process an audio stream to TTS service.
@ -377,13 +377,13 @@ class TextToSpeechEntity(RestoreEntity):
)
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine."""
raise NotImplementedError()
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine.
@ -478,9 +478,9 @@ class SpeechManager:
def process_options(
self,
engine_instance: TextToSpeechEntity | Provider,
language: str | None = None,
options: dict | None = None,
) -> tuple[str, dict | None]:
language: str | None,
options: dict | None,
) -> tuple[str, dict[str, Any]]:
"""Validate and process options."""
# Languages
language = language or engine_instance.default_language
@ -491,23 +491,18 @@ class SpeechManager:
):
raise HomeAssistantError(f"Language '{language}' not supported")
# Options
if (default_options := engine_instance.default_options) and options:
merged_options = dict(default_options)
merged_options.update(options)
options = merged_options
if not options:
options = None if default_options is None else dict(default_options)
# Update default options with provided options
merged_options = dict(engine_instance.default_options or {})
merged_options.update(options or {})
if options is not None:
supported_options = engine_instance.supported_options or []
invalid_opts = [
opt_name for opt_name in options if opt_name not in supported_options
]
if invalid_opts:
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
supported_options = engine_instance.supported_options or []
invalid_opts = [
opt_name for opt_name in merged_options if opt_name not in supported_options
]
if invalid_opts:
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
return language, options
return language, merged_options
async def async_get_url_path(
self,
@ -602,7 +597,7 @@ class SpeechManager:
message: str,
cache: bool,
language: str,
options: dict | None,
options: dict[str, Any],
) -> str:
"""Receive TTS, store for view in cache and return filename.

View File

@ -240,13 +240,13 @@ class Provider:
return None
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from provider."""
raise NotImplementedError()
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from provider.

View File

@ -187,7 +187,7 @@ class VoiceRSSProvider(Provider):
"""Return list of supported languages."""
return SUPPORT_LANGUAGES
async def async_get_tts_audio(self, message, language, options=None):
async def async_get_tts_audio(self, message, language, options):
"""Load TTS from VoiceRSS."""
websession = async_get_clientsession(self.hass)
form_data = self._form_data.copy()

View File

@ -180,7 +180,7 @@ class WatsonTTSProvider(Provider):
"""Return a list of supported options."""
return [CONF_VOICE]
def get_tts_audio(self, message, language=None, options=None):
def get_tts_audio(self, message, language, options):
"""Request TTS file from Watson TTS."""
response = self.service.synthesize(
text=message, accept=self.output_format, voice=options[CONF_VOICE]

View File

@ -94,7 +94,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
"""Return a list of supported voices for a language."""
return self._voices.get(language)
async def async_get_tts_audio(self, message, language, options=None):
async def async_get_tts_audio(self, message, language, options):
"""Load TTS from UNIX socket."""
try:
async with AsyncTcpClient(self.service.host, self.service.port) as client:
@ -129,7 +129,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
except (OSError, WyomingError):
return (None, None)
if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"):
if options[tts.ATTR_AUDIO_OUTPUT] == "wav":
return ("wav", data)
# Raw output (convert to 16Khz, 16-bit mono)

View File

@ -114,11 +114,10 @@ class YandexSpeechKitProvider(Provider):
"""Return list of supported options."""
return SUPPORTED_OPTIONS
async def async_get_tts_audio(self, message, language, options=None):
async def async_get_tts_audio(self, message, language, options):
"""Load TTS from yandex."""
websession = async_get_clientsession(self.hass)
actual_language = language
options = options or {}
try:
async with async_timeout.timeout(10):

View File

@ -2449,7 +2449,7 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = {
),
TypeHintMatch(
function_name="get_tts_audio",
arg_types={1: "str", 2: "str", 3: "dict[str, Any] | None"},
arg_types={1: "str", 2: "str", 3: "dict[str, Any]"},
return_type="TtsAudioType",
has_async_counterpart=True,
),

View File

@ -127,7 +127,7 @@ class MockTTSProvider(tts.Provider):
return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
"""Load TTS data."""
return ("mp3", b"")

View File

@ -76,7 +76,7 @@ class BaseProvider:
return ["voice", "age"]
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load TTS dat."""
return ("mp3", b"")

View File

@ -1021,7 +1021,7 @@ class MockProviderBoom(MockProvider):
"""Mock provider that blows up."""
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
"""Load TTS dat."""
# This should not be called, data should be fetched from cache
@ -1032,7 +1032,7 @@ class MockEntityBoom(MockTTSEntity):
"""Mock entity that blows up."""
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
"""Load TTS dat."""
# This should not be called, data should be fetched from cache
@ -1116,7 +1116,7 @@ class MockProviderEmpty(MockProvider):
"""Mock provider with empty get_tts_audio."""
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
"""Load TTS dat."""
return (None, None)
@ -1126,7 +1126,7 @@ class MockEntityEmpty(MockTTSEntity):
"""Mock entity with empty get_tts_audio."""
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
"""Load TTS dat."""
return (None, None)
@ -1486,7 +1486,7 @@ async def test_legacy_fetching_in_async(
return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
return ("mp3", await tts_audio)
@ -1559,7 +1559,7 @@ async def test_fetching_in_async(
return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
return ("mp3", await tts_audio)

View File

@ -103,7 +103,7 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Hello World"
assert language == "en_US"
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {}
# Pass language and options
mock_get_tts_audio.reset_mock()
@ -138,7 +138,7 @@ async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None
message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Hello World"
assert language == "en_US"
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {}
# Pass language and options
mock_get_tts_audio.reset_mock()