1
mirror of https://github.com/home-assistant/core synced 2024-08-28 03:36:46 +02:00

Move cloud stt engine to config entry (#99608)

* Migrate cloud stt to config entry

* Update default engine

* Test config flow

* Migrate pipelines with cloud stt engine to new engine id

* Fix test after rebase

* Update and add comment

* Remove cloud specifics from default stt engine

* Refactor cloud assist pipeline

* Fix cloud stt entity_id

* Try to wait for platforms before creating default pipeline

* Clean up import

* Move function in cloud assist pipeline

* Wait for tts platform loaded in stt migration

* Update deprecation dates

* Clean up not used fixture

* Add test for async_update_pipeline

* Define pipeline update interface better

* Remove leftover

* Fix tests

* Change default engine test

* Add test for missing stt entity during login

* Add and update comments

* Update config entry title
This commit is contained in:
Martin Hjelmare 2023-12-21 13:39:02 +01:00 committed by GitHub
parent f0104d6851
commit e1f31194f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 650 additions and 61 deletions

View File

@ -31,6 +31,7 @@ from .pipeline import (
async_get_pipeline,
async_get_pipelines,
async_setup_pipeline_store,
async_update_pipeline,
)
from .websocket_api import async_register_websocket_api
@ -40,6 +41,7 @@ __all__ = (
"async_get_pipelines",
"async_setup",
"async_pipeline_from_audio_stream",
"async_update_pipeline",
"AudioSettings",
"Pipeline",
"PipelineEvent",

View File

@ -43,6 +43,7 @@ from homeassistant.helpers.collection import (
)
from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
from homeassistant.util import (
dt as dt_util,
language as language_util,
@ -276,6 +277,48 @@ def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
return pipeline_data.pipeline_store.data.values()
async def async_update_pipeline(
hass: HomeAssistant,
pipeline: Pipeline,
*,
conversation_engine: str | UndefinedType = UNDEFINED,
conversation_language: str | UndefinedType = UNDEFINED,
language: str | UndefinedType = UNDEFINED,
name: str | UndefinedType = UNDEFINED,
stt_engine: str | None | UndefinedType = UNDEFINED,
stt_language: str | None | UndefinedType = UNDEFINED,
tts_engine: str | None | UndefinedType = UNDEFINED,
tts_language: str | None | UndefinedType = UNDEFINED,
tts_voice: str | None | UndefinedType = UNDEFINED,
wake_word_entity: str | None | UndefinedType = UNDEFINED,
wake_word_id: str | None | UndefinedType = UNDEFINED,
) -> None:
"""Update a pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
updates: dict[str, Any] = pipeline.to_json()
updates.pop("id")
# Refactor this once we bump to Python 3.12
# and have https://peps.python.org/pep-0692/
for key, val in (
("conversation_engine", conversation_engine),
("conversation_language", conversation_language),
("language", language),
("name", name),
("stt_engine", stt_engine),
("stt_language", stt_language),
("tts_engine", tts_engine),
("tts_language", tts_language),
("tts_voice", tts_voice),
("wake_word_entity", wake_word_entity),
("wake_word_id", wake_word_id),
):
if val is not UNDEFINED:
updates[key] = val
await pipeline_data.pipeline_store.async_update_item(pipeline.id, updates)
class PipelineEventType(StrEnum):
"""Event types emitted during a pipeline run."""

View File

@ -10,6 +10,7 @@ from hass_nabucasa import Cloud
import voluptuous as vol
from homeassistant.components import alexa, google_assistant
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
CONF_DESCRIPTION,
CONF_MODE,
@ -51,6 +52,7 @@ from .const import (
CONF_SERVICEHANDLERS_SERVER,
CONF_THINGTALK_SERVER,
CONF_USER_POOL_ID,
DATA_PLATFORMS_SETUP,
DOMAIN,
MODE_DEV,
MODE_PROD,
@ -61,6 +63,8 @@ from .subscription import async_subscription_info
DEFAULT_MODE = MODE_PROD
PLATFORMS = [Platform.STT]
SERVICE_REMOTE_CONNECT = "remote_connect"
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
@ -262,6 +266,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_manage_legacy_subscription_issue(hass, subscription_info)
loaded = False
stt_platform_loaded = asyncio.Event()
tts_platform_loaded = asyncio.Event()
hass.data[DATA_PLATFORMS_SETUP] = {
Platform.STT: stt_platform_loaded,
Platform.TTS: tts_platform_loaded,
}
async def _on_start() -> None:
"""Discover platforms."""
@ -272,15 +282,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return
loaded = True
stt_platform_loaded = asyncio.Event()
tts_platform_loaded = asyncio.Event()
stt_info = {"platform_loaded": stt_platform_loaded}
tts_info = {"platform_loaded": tts_platform_loaded}
await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
await async_load_platform(hass, Platform.STT, DOMAIN, stt_info, config)
await async_load_platform(hass, Platform.TTS, DOMAIN, tts_info, config)
await asyncio.gather(stt_platform_loaded.wait(), tts_platform_loaded.wait())
await tts_platform_loaded.wait()
# The config entry should be loaded after the legacy tts platform is loaded
# to make sure that the tts integration is setup before we try to migrate
# old assist pipelines in the cloud stt entity.
await hass.config_entries.flow.async_init(DOMAIN, context={"source": "system"})
async def _on_connect() -> None:
"""Handle cloud connect."""
@ -304,7 +315,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
cloud.register_on_initialized(_on_initialized)
await cloud.initialize()
await http_api.async_setup(hass)
http_api.async_setup(hass)
account_link.async_setup(hass)
@ -340,3 +351,19 @@ def _remote_handle_prefs_updated(cloud: Cloud[CloudClient]) -> None:
await cloud.remote.disconnect()
cloud.client.prefs.async_listen_updates(remote_prefs_updated)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
stt_platform_loaded.set()
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
return unload_ok

View File

@ -1,31 +1,48 @@
"""Handle Cloud assist pipelines."""
import asyncio
from homeassistant.components.assist_pipeline import (
async_create_default_pipeline,
async_get_pipelines,
async_setup_pipeline_store,
async_update_pipeline,
)
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
from homeassistant.components.stt import DOMAIN as STT_DOMAIN
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
import homeassistant.helpers.entity_registry as er
from .const import DOMAIN
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
"""Create a cloud assist pipeline."""
# Wait for stt and tts platforms to set up before creating the pipeline.
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
await asyncio.gather(*(event.wait() for event in platforms_setup.values()))
# Make sure the pipeline store is loaded, needed because assist_pipeline
# is an after dependency of cloud
await async_setup_pipeline_store(hass)
entity_registry = er.async_get(hass)
new_stt_engine_id = entity_registry.async_get_entity_id(
STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID
)
if new_stt_engine_id is None:
# If there's no cloud stt entity, we can't create a cloud pipeline.
return None
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
"""Return the ID of a cloud-enabled assist pipeline or None.
Check if a cloud pipeline already exists with
legacy cloud engine id.
Check if a cloud pipeline already exists with either
legacy or current cloud engine ids.
"""
for pipeline in async_get_pipelines(hass):
if (
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
and pipeline.stt_engine == DOMAIN
and pipeline.stt_engine in (DOMAIN, new_stt_engine_id)
and pipeline.tts_engine == DOMAIN
):
return pipeline.id
@ -34,7 +51,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
if (cloud_assist_pipeline(hass)) is not None or (
cloud_pipeline := await async_create_default_pipeline(
hass,
stt_engine_id=DOMAIN,
stt_engine_id=new_stt_engine_id,
tts_engine_id=DOMAIN,
pipeline_name="Home Assistant Cloud",
)
@ -42,3 +59,27 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
return None
return cloud_pipeline.id
async def async_migrate_cloud_pipeline_stt_engine(
hass: HomeAssistant, stt_engine_id: str
) -> None:
"""Migrate the speech-to-text engine in the cloud assist pipeline."""
# Migrate existing pipelines with cloud stt to use new cloud stt engine id.
# Added in 2024.01.0. Can be removed in 2025.01.0.
# We need to make sure that tts is loaded before this migration.
# Assist pipeline will call default engine of tts when setting up the store.
# Wait for the tts platform loaded event here.
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
await platforms_setup[Platform.TTS].wait()
# Make sure the pipeline store is loaded, needed because assist_pipeline
# is an after dependency of cloud
await async_setup_pipeline_store(hass)
pipelines = async_get_pipelines(hass)
for pipeline in pipelines:
if pipeline.stt_engine != DOMAIN:
continue
await async_update_pipeline(hass, pipeline, stt_engine=stt_engine_id)

View File

@ -0,0 +1,23 @@
"""Config flow for the Cloud integration."""
from __future__ import annotations
from typing import Any
from homeassistant.config_entries import ConfigFlow
from homeassistant.data_entry_flow import FlowResult
from .const import DOMAIN
class CloudConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for the Cloud integration."""
VERSION = 1
async def async_step_system(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the system step."""
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
return self.async_create_entry(title="Home Assistant Cloud", data={})

View File

@ -1,5 +1,6 @@
"""Constants for the cloud component."""
DOMAIN = "cloud"
DATA_PLATFORMS_SETUP = "cloud_platforms_setup"
REQUEST_TIMEOUT = 10
PREF_ENABLE_ALEXA = "alexa_enabled"
@ -64,3 +65,5 @@ MODE_DEV = "development"
MODE_PROD = "production"
DISPATCHER_REMOTE_UPDATE = "cloud_remote_update"
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"

View File

@ -28,7 +28,7 @@ from homeassistant.components.homeassistant import exposed_entities
from homeassistant.components.http import HomeAssistantView, require_admin
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util.location import async_detect_location_info
@ -66,7 +66,8 @@ _CLOUD_ERRORS: dict[type[Exception], tuple[HTTPStatus, str]] = {
}
async def async_setup(hass: HomeAssistant) -> None:
@callback
def async_setup(hass: HomeAssistant) -> None:
"""Initialize the HTTP API."""
websocket_api.async_register_command(hass, websocket_cloud_status)
websocket_api.async_register_command(hass, websocket_subscription)

View File

@ -1,4 +1,10 @@
{
"config": {
"step": {},
"abort": {
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
}
},
"system_health": {
"info": {
"can_reach_cert_server": "Reach Certificate Server",

View File

@ -13,37 +13,38 @@ from homeassistant.components.stt import (
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .assist_pipeline import async_migrate_cloud_pipeline_stt_engine
from .client import CloudClient
from .const import DOMAIN
from .const import DOMAIN, STT_ENTITY_UNIQUE_ID
_LOGGER = logging.getLogger(__name__)
async def async_get_engine(
async def async_setup_entry(
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> CloudProvider:
"""Set up Cloud speech component."""
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Home Assistant Cloud speech platform via config entry."""
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
cloud_provider = CloudProvider(cloud)
if discovery_info is not None:
discovery_info["platform_loaded"].set()
return cloud_provider
async_add_entities([CloudProviderEntity(cloud)])
class CloudProvider(Provider):
class CloudProviderEntity(SpeechToTextEntity):
"""NabuCasa speech API provider."""
_attr_name = "Home Assistant Cloud"
_attr_unique_id = STT_ENTITY_UNIQUE_ID
def __init__(self, cloud: Cloud[CloudClient]) -> None:
"""Home Assistant NabuCasa Speech to text."""
self.cloud = cloud
@ -78,6 +79,10 @@ class CloudProvider(Provider):
"""Return a list of supported channels."""
return [AudioChannels.CHANNEL_MONO]
async def async_added_to_hass(self) -> None:
"""Run when entity is about to be added to hass."""
await async_migrate_cloud_pipeline_stt_engine(self.hass, self.entity_id)
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:

View File

@ -29,9 +29,6 @@ _LOGGER = logging.getLogger(__name__)
@callback
def async_default_provider(hass: HomeAssistant) -> str | None:
"""Return the domain of the default provider."""
if "cloud" in hass.data[DATA_PROVIDERS]:
return "cloud"
return next(iter(hass.data[DATA_PROVIDERS]), None)

View File

@ -1,4 +1,5 @@
"""Websocket tests for Voice Assistant integration."""
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import ANY, patch
@ -16,6 +17,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
async_create_default_pipeline,
async_get_pipeline,
async_get_pipelines,
async_update_pipeline,
)
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
@ -26,6 +28,13 @@ from .conftest import MockSttProvider, MockTTSProvider
from tests.common import flush_store
@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None, None]:
"""Load the homeassistant integration."""
with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
yield
@pytest.fixture(autouse=True)
async def load_homeassistant(hass) -> None:
"""Load the homeassistant integration."""
@ -478,3 +487,125 @@ async def test_default_pipeline_unsupported_tts_language(
wake_word_entity=None,
wake_word_id=None,
)
async def test_update_pipeline(
hass: HomeAssistant,
hass_storage: dict[str, Any],
) -> None:
"""Test async_update_pipeline."""
assert await async_setup_component(hass, "assist_pipeline", {})
pipelines = async_get_pipelines(hass)
pipelines = list(pipelines)
assert pipelines == [
Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=ANY,
language="en",
name="Home Assistant",
stt_engine=None,
stt_language=None,
tts_engine=None,
tts_language=None,
tts_voice=None,
wake_word_entity=None,
wake_word_id=None,
)
]
pipeline = pipelines[0]
await async_update_pipeline(
hass,
pipeline,
conversation_engine="homeassistant_1",
conversation_language="de",
language="de",
name="Home Assistant 1",
stt_engine="stt.test_1",
stt_language="de",
tts_engine="test_1",
tts_language="de",
tts_voice="test_voice",
wake_word_entity="wake_work.test_1",
wake_word_id="wake_word_id_1",
)
pipelines = async_get_pipelines(hass)
pipelines = list(pipelines)
pipeline = pipelines[0]
assert pipelines == [
Pipeline(
conversation_engine="homeassistant_1",
conversation_language="de",
id=pipeline.id,
language="de",
name="Home Assistant 1",
stt_engine="stt.test_1",
stt_language="de",
tts_engine="test_1",
tts_language="de",
tts_voice="test_voice",
wake_word_entity="wake_work.test_1",
wake_word_id="wake_word_id_1",
)
]
assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1
assert hass_storage[STORAGE_KEY]["data"]["items"][0] == {
"conversation_engine": "homeassistant_1",
"conversation_language": "de",
"id": pipeline.id,
"language": "de",
"name": "Home Assistant 1",
"stt_engine": "stt.test_1",
"stt_language": "de",
"tts_engine": "test_1",
"tts_language": "de",
"tts_voice": "test_voice",
"wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1",
}
await async_update_pipeline(
hass,
pipeline,
stt_engine="stt.test_2",
stt_language="en",
tts_engine="test_2",
tts_language="en",
)
pipelines = async_get_pipelines(hass)
pipelines = list(pipelines)
assert pipelines == [
Pipeline(
conversation_engine="homeassistant_1",
conversation_language="de",
id=pipeline.id,
language="de",
name="Home Assistant 1",
stt_engine="stt.test_2",
stt_language="en",
tts_engine="test_2",
tts_language="en",
tts_voice="test_voice",
wake_word_entity="wake_work.test_1",
wake_word_id="wake_word_id_1",
)
]
assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1
assert hass_storage[STORAGE_KEY]["data"]["items"][0] == {
"conversation_engine": "homeassistant_1",
"conversation_language": "de",
"id": pipeline.id,
"language": "de",
"name": "Home Assistant 1",
"stt_engine": "stt.test_2",
"stt_language": "en",
"tts_engine": "test_2",
"tts_language": "en",
"tts_voice": "test_voice",
"wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1",
}

View File

@ -1,7 +1,8 @@
"""Tests for the cloud component."""
from unittest.mock import AsyncMock, patch
from hass_nabucasa import Cloud
from homeassistant.components import cloud
from homeassistant.components.cloud import const, prefs as cloud_prefs
from homeassistant.setup import async_setup_component
@ -14,7 +15,7 @@ async def mock_cloud(hass, config=None):
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, cloud.DOMAIN, {"cloud": config or {}})
cloud_inst = hass.data["cloud"]
cloud_inst: Cloud = hass.data["cloud"]
with patch("hass_nabucasa.Cloud.run_executor", AsyncMock(return_value=None)):
await cloud_inst.initialize()

View File

@ -0,0 +1,40 @@
"""Test the Home Assistant Cloud config flow."""
from unittest.mock import patch
from homeassistant.components.cloud.const import DOMAIN
from homeassistant.core import HomeAssistant
from tests.common import MockConfigEntry
async def test_config_flow(hass: HomeAssistant) -> None:
"""Test create cloud entry."""
with patch(
"homeassistant.components.cloud.async_setup", return_value=True
) as mock_setup, patch(
"homeassistant.components.cloud.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": "system"}
)
assert result["type"] == "create_entry"
assert result["title"] == "Home Assistant Cloud"
assert result["data"] == {}
await hass.async_block_till_done()
assert len(mock_setup.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
async def test_multiple_entries(hass: HomeAssistant) -> None:
"""Test creating multiple cloud entries."""
config_entry = MockConfigEntry(domain=DOMAIN)
config_entry.add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": "system"}
)
assert result["type"] == "abort"
assert result["reason"] == "single_instance_allowed"

View File

@ -46,6 +46,26 @@ PIPELINE_DATA_LEGACY = {
"preferred_item": "12345",
}
PIPELINE_DATA = {
"items": [
{
"conversation_engine": "homeassistant",
"conversation_language": "language_1",
"id": "12345",
"language": "language_1",
"name": "Home Assistant Cloud",
"stt_engine": "stt.home_assistant_cloud",
"stt_language": "language_1",
"tts_engine": "cloud",
"tts_language": "language_1",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
},
],
"preferred_item": "12345",
}
PIPELINE_DATA_OTHER = {
"items": [
{
@ -127,7 +147,34 @@ async def test_google_actions_sync_fails(
assert mock_request_sync.call_count == 1
@pytest.mark.parametrize("pipeline_data", [PIPELINE_DATA_LEGACY])
async def test_login_view_missing_stt_entity(
hass: HomeAssistant,
setup_cloud: None,
entity_registry: er.EntityRegistry,
hass_client: ClientSessionGenerator,
) -> None:
"""Test logging in when the cloud stt entity is missing."""
# Make sure that the cloud stt entity does not exist.
entity_registry.async_remove("stt.home_assistant_cloud")
await hass.async_block_till_done()
cloud_client = await hass_client()
# We assume the user needs to login again for some reason.
with patch(
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
) as create_pipeline_mock:
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": None}
create_pipeline_mock.assert_not_awaited()
@pytest.mark.parametrize("pipeline_data", [PIPELINE_DATA, PIPELINE_DATA_LEGACY])
async def test_login_view_existing_pipeline(
hass: HomeAssistant,
cloud: MagicMock,
@ -195,7 +242,7 @@ async def test_login_view_create_pipeline(
assert result == {"success": True, "cloud_pipeline": "12345"}
create_pipeline_mock.assert_awaited_once_with(
hass,
stt_engine_id="cloud",
stt_engine_id="stt.home_assistant_cloud",
tts_engine_id="cloud",
pipeline_name="Home Assistant Cloud",
)
@ -234,7 +281,7 @@ async def test_login_view_create_pipeline_fail(
assert result == {"success": True, "cloud_pipeline": None}
create_pipeline_mock.assert_awaited_once_with(
hass,
stt_engine_id="cloud",
stt_engine_id="stt.home_assistant_cloud",
tts_engine_id="cloud",
pipeline_name="Home Assistant Cloud",
)

View File

@ -0,0 +1,201 @@
"""Test the speech-to-text platform for the cloud integration."""
from collections.abc import AsyncGenerator
from copy import deepcopy
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from hass_nabucasa.voice import STTResponse, VoiceError
import pytest
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
from homeassistant.components.cloud import DOMAIN
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.typing import ClientSessionGenerator
PIPELINE_DATA = {
"items": [
{
"conversation_engine": "conversation_engine_1",
"conversation_language": "language_1",
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
"language": "language_1",
"name": "Home Assistant Cloud",
"stt_engine": "cloud",
"stt_language": "language_1",
"tts_engine": "cloud",
"tts_language": "language_1",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
},
{
"conversation_engine": "conversation_engine_2",
"conversation_language": "language_2",
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
"language": "language_2",
"name": "name_2",
"stt_engine": "stt_engine_2",
"stt_language": "language_2",
"tts_engine": "tts_engine_2",
"tts_language": "language_2",
"tts_voice": "The Voice",
"wake_word_entity": None,
"wake_word_id": None,
},
{
"conversation_engine": "conversation_engine_3",
"conversation_language": "language_3",
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
"language": "language_3",
"name": "name_3",
"stt_engine": None,
"stt_language": None,
"tts_engine": None,
"tts_language": None,
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
},
],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
}
@pytest.fixture(autouse=True)
async def load_homeassistant(hass: HomeAssistant) -> None:
"""Load the homeassistant integration."""
assert await async_setup_component(hass, "homeassistant", {})
@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None, None]:
"""Load the homeassistant integration."""
with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
yield
@pytest.mark.parametrize(
("mock_process_stt", "expected_response_data"),
[
(
AsyncMock(return_value=STTResponse(True, "Turn the Kitchen Lights on")),
{"text": "Turn the Kitchen Lights on", "result": "success"},
),
(AsyncMock(side_effect=VoiceError("Boom!")), {"text": None, "result": "error"}),
],
)
async def test_cloud_speech(
hass: HomeAssistant,
cloud: MagicMock,
hass_client: ClientSessionGenerator,
mock_process_stt: AsyncMock,
expected_response_data: dict[str, Any],
) -> None:
"""Test cloud text-to-speech."""
cloud.voice.process_stt = mock_process_stt
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
on_start_callback = cloud.register_on_start.call_args[0][0]
await on_start_callback()
state = hass.states.get("stt.home_assistant_cloud")
assert state
assert state.state == STATE_UNKNOWN
client = await hass_client()
response = await client.post(
"/api/stt/stt.home_assistant_cloud",
headers={
"X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
" language=de-DE"
)
},
data=b"Test",
)
response_data = await response.json()
assert mock_process_stt.call_count == 1
assert (
mock_process_stt.call_args.kwargs["content_type"]
== "audio/wav; codecs=audio/pcm; samplerate=16000"
)
assert mock_process_stt.call_args.kwargs["language"] == "de-DE"
assert response.status == HTTPStatus.OK
assert response_data == expected_response_data
state = hass.states.get("stt.home_assistant_cloud")
assert state
assert state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
async def test_migrating_pipelines(
hass: HomeAssistant,
cloud: MagicMock,
hass_client: ClientSessionGenerator,
hass_storage: dict[str, Any],
) -> None:
"""Test migrating pipelines when cloud stt entity is added."""
cloud.voice.process_stt = AsyncMock(
return_value=STTResponse(True, "Turn the Kitchen Lights on")
)
hass_storage[STORAGE_KEY] = {
"version": 1,
"minor_version": 1,
"key": "assist_pipeline.pipelines",
"data": deepcopy(PIPELINE_DATA),
}
assert await async_setup_component(hass, "assist_pipeline", {})
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
on_start_callback = cloud.register_on_start.call_args[0][0]
await on_start_callback()
await hass.async_block_till_done()
state = hass.states.get("stt.home_assistant_cloud")
assert state
assert state.state == STATE_UNKNOWN
# The stt engine should be updated to the new cloud stt engine id.
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"]
== "stt.home_assistant_cloud"
)
# The other items should stay the same.
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_engine"]
== "conversation_engine_1"
)
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_language"]
== "language_1"
)
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["id"]
== "01GX8ZWBAQYWNB1XV3EXEZ75DY"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["language"] == "language_1"
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == "cloud"
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]
== "Arnold Schwarzenegger"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_entity"] is None
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_id"] is None
assert hass_storage[STORAGE_KEY]["data"]["items"][1] == PIPELINE_DATA["items"][1]
assert hass_storage[STORAGE_KEY]["data"]["items"][2] == PIPELINE_DATA["items"][2]

View File

@ -121,12 +121,20 @@ class STTFlow(ConfigFlow):
"""Test flow."""
@pytest.fixture(autouse=True)
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
"""Mock config flow."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
@pytest.fixture(name="config_flow_test_domain")
def config_flow_test_domain_fixture() -> str:
"""Test domain fixture."""
return TEST_DOMAIN
with mock_config_flow(TEST_DOMAIN, STTFlow):
@pytest.fixture(autouse=True)
def config_flow_fixture(
hass: HomeAssistant, config_flow_test_domain: str
) -> Generator[None, None, None]:
"""Mock config flow."""
mock_platform(hass, f"{config_flow_test_domain}.config_flow")
with mock_config_flow(config_flow_test_domain, STTFlow):
yield
@ -137,6 +145,7 @@ async def setup_fixture(
request: pytest.FixtureRequest,
) -> MockProvider | MockProviderEntity:
"""Set up the test environment."""
provider: MockProvider | MockProviderEntity
if request.param == "mock_setup":
provider = MockProvider()
await mock_setup(hass, tmp_path, provider)
@ -166,7 +175,10 @@ async def mock_setup(
async def mock_config_entry_setup(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
test_domain: str = TEST_DOMAIN,
) -> MockConfigEntry:
"""Set up a test provider via config entry."""
@ -187,7 +199,7 @@ async def mock_config_entry_setup(
mock_integration(
hass,
MockModule(
TEST_DOMAIN,
test_domain,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
@ -201,9 +213,9 @@ async def mock_config_entry_setup(
"""Set up test stt platform via config entry."""
async_add_entities([mock_provider_entity])
mock_stt_entity_platform(hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform)
mock_stt_entity_platform(hass, tmp_path, test_domain, async_setup_entry_platform)
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
config_entry = MockConfigEntry(domain=test_domain)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
@ -456,7 +468,11 @@ async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
assert async_default_engine(hass) is None
async def test_default_engine(hass: HomeAssistant, tmp_path: Path) -> None:
async def test_default_engine(
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
) -> None:
"""Test async_default_engine."""
mock_stt_platform(
hass,
@ -479,26 +495,31 @@ async def test_default_engine_entity(
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
async def test_default_engine_prefer_cloud(hass: HomeAssistant, tmp_path: Path) -> None:
@pytest.mark.parametrize("config_flow_test_domain", ["new_test"])
async def test_default_engine_prefer_provider(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
mock_provider: MockProvider,
config_flow_test_domain: str,
) -> None:
"""Test async_default_engine."""
mock_stt_platform(
hass,
tmp_path,
TEST_DOMAIN,
async_get_engine=AsyncMock(return_value=mock_provider),
)
mock_stt_platform(
hass,
tmp_path,
"cloud",
async_get_engine=AsyncMock(return_value=mock_provider),
)
assert await async_setup_component(
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
mock_provider_entity.url_path = "stt.new_test"
mock_provider_entity._attr_name = "New test"
await mock_setup(hass, tmp_path, mock_provider)
await mock_config_entry_setup(
hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain
)
await hass.async_block_till_done()
assert async_default_engine(hass) == "cloud"
entity_engine = async_get_speech_to_text_engine(hass, "stt.new_test")
assert entity_engine is not None
assert entity_engine.name == "New test"
provider_engine = async_get_speech_to_text_engine(hass, "test")
assert provider_engine is not None
assert provider_engine.name == "test"
assert async_default_engine(hass) == "test"
async def test_get_engine_legacy(