Use fallback voice for selected language in cloud (#114246)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Martin Hjelmare 2024-03-28 15:44:50 +01:00 committed by GitHub
parent 52ca14de48
commit f9aa7d34f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 115 additions and 18 deletions

View File

@ -140,7 +140,6 @@ class CloudTTSEntity(TextToSpeechEntity):
"""Return a dict include default options."""
return {
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
ATTR_VOICE: self._voice,
}
@property
@ -178,7 +177,18 @@ class CloudTTSEntity(TextToSpeechEntity):
gender: Gender | str | None = options.get(ATTR_GENDER)
gender = handle_deprecated_gender(self.hass, gender)
original_voice: str | None = options.get(ATTR_VOICE)
if original_voice is None and language == self._language:
original_voice = self._voice
voice = handle_deprecated_voice(self.hass, original_voice)
if voice not in TTS_VOICES[language]:
default_voice = TTS_VOICES[language][0]
_LOGGER.debug(
"Unsupported voice %s detected, falling back to default %s for %s",
voice,
default_voice,
language,
)
voice = default_voice
# Process TTS
try:
data = await self.cloud.voice.process_tts(
@ -237,7 +247,6 @@ class CloudProvider(Provider):
"""Return a dict include default options."""
return {
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
ATTR_VOICE: self._voice,
}
async def async_get_tts_audio(
@ -248,7 +257,18 @@ class CloudProvider(Provider):
gender: Gender | str | None = options.get(ATTR_GENDER)
gender = handle_deprecated_gender(self.hass, gender)
original_voice: str | None = options.get(ATTR_VOICE)
if original_voice is None and language == self._language:
original_voice = self._voice
voice = handle_deprecated_voice(self.hass, original_voice)
if voice not in TTS_VOICES[language]:
default_voice = TTS_VOICES[language][0]
_LOGGER.debug(
"Unsupported voice %s detected, falling back to default %s for %s",
voice,
default_voice,
language,
)
voice = default_voice
# Process TTS
try:
data = await self.cloud.voice.process_tts(

View File

@ -12,10 +12,20 @@ import voluptuous as vol
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
from homeassistant.components.cloud import DOMAIN, const, tts
from homeassistant.components.tts import DOMAIN as TTS_DOMAIN
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
)
from homeassistant.components.tts import (
ATTR_LANGUAGE,
ATTR_MEDIA_PLAYER_ENTITY_ID,
ATTR_MESSAGE,
DOMAIN as TTS_DOMAIN,
)
from homeassistant.components.tts.helper import get_engine_instance
from homeassistant.config import async_process_ha_core_config
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_registry import EntityRegistry
from homeassistant.helpers.issue_registry import IssueRegistry, IssueSeverity
@ -23,6 +33,8 @@ from homeassistant.setup import async_setup_component
from . import PIPELINE_DATA
from tests.common import async_mock_service
from tests.components.tts.common import get_media_source_url
from tests.typing import ClientSessionGenerator
@ -120,13 +132,13 @@ async def test_prefs_default_voice(
assert engine is not None
# The platform config provider will be overridden by the discovery info provider.
assert engine.default_language == "en-US"
assert engine.default_options == {"audio_output": "mp3", "voice": "JennyNeural"}
assert engine.default_options == {"audio_output": "mp3"}
await set_cloud_prefs({"tts_default_voice": ("nl-NL", "MaartenNeural")})
await hass.async_block_till_done()
assert engine.default_language == "nl-NL"
assert engine.default_options == {"audio_output": "mp3", "voice": "MaartenNeural"}
assert engine.default_options == {"audio_output": "mp3"}
async def test_deprecated_platform_config(
@ -228,11 +240,11 @@ async def test_get_tts_audio(
"url": (
"http://example.local:8123/api/tts_proxy/"
"42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_5c97d21c48_{expected_url_suffix}.mp3"
f"_en-us_6e8b81ac47_{expected_url_suffix}.mp3"
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_5c97d21c48_{expected_url_suffix}.mp3"
f"_en-us_6e8b81ac47_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
@ -242,6 +254,7 @@ async def test_get_tts_audio(
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
@ -280,11 +293,11 @@ async def test_get_tts_audio_logged_out(
"url": (
"http://example.local:8123/api/tts_proxy/"
"42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_5c97d21c48_{expected_url_suffix}.mp3"
f"_en-us_6e8b81ac47_{expected_url_suffix}.mp3"
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_5c97d21c48_{expected_url_suffix}.mp3"
f"_en-us_6e8b81ac47_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
@ -294,6 +307,7 @@ async def test_get_tts_audio_logged_out(
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
@ -344,11 +358,11 @@ async def test_tts_entity(
"url": (
"http://example.local:8123/api/tts_proxy/"
"42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_5c97d21c48_{entity_id}.mp3"
f"_en-us_6e8b81ac47_{entity_id}.mp3"
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_5c97d21c48_{entity_id}.mp3"
f"_en-us_6e8b81ac47_{entity_id}.mp3"
),
}
await hass.async_block_till_done()
@ -358,6 +372,7 @@ async def test_tts_entity(
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
assert mock_process_tts.call_args.kwargs["gender"] is None
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
state = hass.states.get(entity_id)
@ -632,11 +647,11 @@ async def test_deprecated_gender(
"url": (
"http://example.local:8123/api/tts_proxy/"
"42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_{language.lower()}_5c97d21c48_{expected_url_suffix}.mp3"
f"_{language.lower()}_6e8b81ac47_{expected_url_suffix}.mp3"
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_{language.lower()}_5c97d21c48_{expected_url_suffix}.mp3"
f"_{language.lower()}_6e8b81ac47_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
@ -645,7 +660,7 @@ async def test_deprecated_gender(
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
issue = issue_registry.async_get_issue("cloud", "deprecated_gender")
assert issue is None
@ -662,11 +677,11 @@ async def test_deprecated_gender(
"url": (
"http://example.local:8123/api/tts_proxy/"
"42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_{language.lower()}_5dded72256_{expected_url_suffix}.mp3"
f"_{language.lower()}_dd0e95eb04_{expected_url_suffix}.mp3"
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_{language.lower()}_5dded72256_{expected_url_suffix}.mp3"
f"_{language.lower()}_dd0e95eb04_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
@ -678,7 +693,7 @@ async def test_deprecated_gender(
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == language
assert mock_process_tts.call_args.kwargs["gender"] == gender_option
assert mock_process_tts.call_args.kwargs["voice"] == "JennyNeural"
assert mock_process_tts.call_args.kwargs["voice"] == "XiaoxiaoNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
issue = issue_registry.async_get_issue("cloud", issue_id)
assert issue is not None
@ -733,3 +748,65 @@ async def test_deprecated_gender(
}
assert not issue_registry.async_get_issue(DOMAIN, issue_id)
@pytest.mark.parametrize(
("service", "service_data"),
[
(
"speak",
{
ATTR_ENTITY_ID: "tts.home_assistant_cloud",
ATTR_LANGUAGE: "id-ID",
ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
ATTR_MESSAGE: "There is someone at the door.",
},
),
(
"cloud_say",
{
ATTR_ENTITY_ID: "media_player.something",
ATTR_LANGUAGE: "id-ID",
ATTR_MESSAGE: "There is someone at the door.",
},
),
],
)
async def test_tts_services(
hass: HomeAssistant,
cloud: MagicMock,
hass_client: ClientSessionGenerator,
service: str,
service_data: dict[str, Any],
) -> None:
"""Test tts services."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
mock_process_tts = AsyncMock(return_value=b"")
cloud.voice.process_tts = mock_process_tts
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
await hass.async_block_till_done()
await cloud.login("test-user", "test-pass")
client = await hass_client()
await hass.services.async_call(
domain=TTS_DOMAIN,
service=service,
service_data=service_data,
blocking=True,
)
assert len(calls) == 1
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
await hass.async_block_till_done()
response = await client.get(url)
assert response.status == HTTPStatus.OK
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == service_data[ATTR_LANGUAGE]
assert mock_process_tts.call_args.kwargs["voice"] == "GadisNeural"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"