mirror of
https://github.com/home-assistant/core
synced 2024-07-15 09:42:11 +02:00
Add API to fetch Assist devices (#107333)
* Add API to fetch Assist devices * Revert some changes to fixture, make a single fixture for an Assist device
This commit is contained in:
parent
6201e81eca
commit
f1d2868fd0
@ -1703,7 +1703,7 @@ class PipelineRuns:
|
||||
pipeline_run.abort_wake_word_detection = True
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class DeviceAudioQueue:
|
||||
"""Audio capture queue for a satellite device."""
|
||||
|
||||
@ -1717,6 +1717,14 @@ class DeviceAudioQueue:
|
||||
"""Flag to be set if audio samples were dropped because the queue was full."""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AssistDevice:
|
||||
"""Assist device."""
|
||||
|
||||
domain: str
|
||||
unique_id_prefix: str
|
||||
|
||||
|
||||
class PipelineData:
|
||||
"""Store and debug data stored in hass.data."""
|
||||
|
||||
@ -1724,12 +1732,12 @@ class PipelineData:
|
||||
"""Initialize."""
|
||||
self.pipeline_store = pipeline_store
|
||||
self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
|
||||
self.pipeline_devices: set[str] = set()
|
||||
self.pipeline_devices: dict[str, AssistDevice] = {}
|
||||
self.pipeline_runs = PipelineRuns(pipeline_store)
|
||||
self.device_audio_queues: dict[str, DeviceAudioQueue] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class PipelineRunDebug:
|
||||
"""Debug data for a pipelinerun."""
|
||||
|
||||
|
@ -10,7 +10,7 @@ from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||
|
||||
from .const import DOMAIN
|
||||
from .pipeline import PipelineData, PipelineStorageCollection
|
||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
||||
from .vad import VadSensitivity
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
@ -70,8 +70,10 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
||||
_attr_current_option = OPTION_PREFERRED
|
||||
_attr_options = [OPTION_PREFERRED]
|
||||
|
||||
def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
|
||||
def __init__(self, hass: HomeAssistant, domain: str, unique_id_prefix: str) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
self._domain = domain
|
||||
self._unique_id_prefix = unique_id_prefix
|
||||
self._attr_unique_id = f"{unique_id_prefix}-pipeline"
|
||||
self.hass = hass
|
||||
self._update_options()
|
||||
@ -91,11 +93,16 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
||||
self._attr_current_option = state.state
|
||||
|
||||
if self.registry_entry and (device_id := self.registry_entry.device_id):
|
||||
pipeline_data.pipeline_devices.add(device_id)
|
||||
self.async_on_remove(
|
||||
lambda: pipeline_data.pipeline_devices.discard(device_id)
|
||||
pipeline_data.pipeline_devices[device_id] = AssistDevice(
|
||||
self._domain, self._unique_id_prefix
|
||||
)
|
||||
|
||||
def cleanup() -> None:
|
||||
"""Clean up registered device."""
|
||||
pipeline_data.pipeline_devices.pop(device_id)
|
||||
|
||||
self.async_on_remove(cleanup)
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Select an option."""
|
||||
self._attr_current_option = option
|
||||
|
@ -15,7 +15,7 @@ import voluptuous as vol
|
||||
from homeassistant.components import conversation, stt, tts, websocket_api
|
||||
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers import config_validation as cv, entity_registry as er
|
||||
from homeassistant.util import language as language_util
|
||||
|
||||
from .const import (
|
||||
@ -53,6 +53,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
websocket_api.async_register_command(hass, websocket_run)
|
||||
websocket_api.async_register_command(hass, websocket_list_languages)
|
||||
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||
websocket_api.async_register_command(hass, websocket_list_devices)
|
||||
websocket_api.async_register_command(hass, websocket_get_run)
|
||||
websocket_api.async_register_command(hass, websocket_device_capture)
|
||||
|
||||
@ -287,6 +288,35 @@ def websocket_list_runs(
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_pipeline/device/list",
|
||||
}
|
||||
)
|
||||
def websocket_list_devices(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""List assist devices."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
ent_reg = er.async_get(hass)
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
[
|
||||
{
|
||||
"device_id": device_id,
|
||||
"pipeline_entity": ent_reg.async_get_entity_id(
|
||||
"select", info.domain, f"{info.unique_id_prefix}-pipeline"
|
||||
),
|
||||
}
|
||||
for device_id, info in pipeline_data.pipeline_devices.items()
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
|
@ -12,6 +12,7 @@ from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .domain_data import DomainData
|
||||
from .entity import (
|
||||
EsphomeAssistEntity,
|
||||
@ -75,7 +76,7 @@ class EsphomeAssistPipelineSelect(EsphomeAssistEntity, AssistPipelineSelect):
|
||||
def __init__(self, hass: HomeAssistant, entry_data: RuntimeEntryData) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
EsphomeAssistEntity.__init__(self, entry_data)
|
||||
AssistPipelineSelect.__init__(self, hass, self._device_info.mac_address)
|
||||
AssistPipelineSelect.__init__(self, hass, DOMAIN, self._device_info.mac_address)
|
||||
|
||||
|
||||
class EsphomeVadSensitivitySelect(EsphomeAssistEntity, VadSensitivitySelect):
|
||||
|
@ -51,7 +51,7 @@ class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
|
||||
def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
VoIPEntity.__init__(self, device)
|
||||
AssistPipelineSelect.__init__(self, hass, device.voip_id)
|
||||
AssistPipelineSelect.__init__(self, hass, DOMAIN, device.voip_id)
|
||||
|
||||
|
||||
class VoipVadSensitivitySelect(VoIPEntity, VadSensitivitySelect):
|
||||
|
@ -57,7 +57,7 @@ class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelec
|
||||
self.device = device
|
||||
|
||||
WyomingSatelliteEntity.__init__(self, device)
|
||||
AssistPipelineSelect.__init__(self, hass, device.satellite_id)
|
||||
AssistPipelineSelect.__init__(self, hass, DOMAIN, device.satellite_id)
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Select an option."""
|
||||
|
@ -8,13 +8,15 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt, tts, wake_word
|
||||
from homeassistant.components.assist_pipeline import DOMAIN
|
||||
from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
PipelineData,
|
||||
PipelineStorageCollection,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
@ -288,7 +290,7 @@ async def init_supporting_components(
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(
|
||||
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
|
||||
config_entry, [Platform.STT, Platform.WAKE_WORD]
|
||||
)
|
||||
return True
|
||||
|
||||
@ -297,7 +299,7 @@ async def init_supporting_components(
|
||||
) -> bool:
|
||||
"""Unload up test config entry."""
|
||||
await hass.config_entries.async_unload_platforms(
|
||||
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
|
||||
config_entry, [Platform.STT, Platform.WAKE_WORD]
|
||||
)
|
||||
return True
|
||||
|
||||
@ -369,6 +371,79 @@ async def init_components(hass: HomeAssistant, init_supporting_components):
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def assist_device(hass: HomeAssistant, init_components) -> dr.DeviceEntry:
|
||||
"""Create an assist device."""
|
||||
config_entry = MockConfigEntry(domain="test_assist_device")
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
dev_reg = dr.async_get(hass)
|
||||
device = dev_reg.async_get_or_create(
|
||||
name="Test Device",
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("test_assist_device", "test")},
|
||||
)
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(
|
||||
config_entry, [Platform.SELECT]
|
||||
)
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload up test config entry."""
|
||||
await hass.config_entries.async_unload_platforms(
|
||||
config_entry, [Platform.SELECT]
|
||||
)
|
||||
return True
|
||||
|
||||
async def async_setup_entry_select_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test select platform via config entry."""
|
||||
entities = [
|
||||
assist_select.AssistPipelineSelect(
|
||||
hass, "test_assist_device", "test-prefix"
|
||||
),
|
||||
assist_select.VadSensitivitySelect(hass, "test-prefix"),
|
||||
]
|
||||
for ent in entities:
|
||||
ent._attr_device_info = dr.DeviceInfo(
|
||||
identifiers={("test_assist_device", "test")},
|
||||
)
|
||||
async_add_entities(entities)
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"test_assist_device",
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
mock_platform(
|
||||
hass,
|
||||
"test_assist_device.select",
|
||||
MockPlatform(
|
||||
async_setup_entry=async_setup_entry_select_platform,
|
||||
),
|
||||
)
|
||||
mock_platform(hass, "test_assist_device.config_flow")
|
||||
|
||||
with mock_config_flow("test_assist_device", ConfigFlow):
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData:
|
||||
"""Return pipeline data."""
|
||||
|
@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline import Pipeline
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
AssistDevice,
|
||||
PipelineData,
|
||||
PipelineStorageCollection,
|
||||
)
|
||||
@ -33,7 +34,7 @@ class SelectPlatform(MockPlatform):
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up fake select platform."""
|
||||
pipeline_entity = AssistPipelineSelect(hass, "test")
|
||||
pipeline_entity = AssistPipelineSelect(hass, "test-domain", "test-prefix")
|
||||
pipeline_entity._attr_device_info = DeviceInfo(
|
||||
identifiers={("test", "test")},
|
||||
)
|
||||
@ -109,13 +110,15 @@ async def test_select_entity_registering_device(
|
||||
assert device is not None
|
||||
|
||||
# Test device is registered
|
||||
assert pipeline_data.pipeline_devices == {device.id}
|
||||
assert pipeline_data.pipeline_devices == {
|
||||
device.id: AssistDevice("test-domain", "test-prefix")
|
||||
}
|
||||
|
||||
await hass.config_entries.async_remove(init_select.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Test device is removed
|
||||
assert pipeline_data.pipeline_devices == set()
|
||||
assert pipeline_data.pipeline_devices == {}
|
||||
|
||||
|
||||
async def test_select_entity_changing_pipelines(
|
||||
@ -128,7 +131,7 @@ async def test_select_entity_changing_pipelines(
|
||||
"""Test entity tracking pipeline changes."""
|
||||
config_entry = init_select # nicer naming
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == "preferred"
|
||||
assert state.attributes["options"] == [
|
||||
@ -143,13 +146,13 @@ async def test_select_entity_changing_pipelines(
|
||||
"select",
|
||||
"select_option",
|
||||
{
|
||||
"entity_id": "select.assist_pipeline_test_pipeline",
|
||||
"entity_id": "select.assist_pipeline_test_prefix_pipeline",
|
||||
"option": pipeline_2.name,
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == pipeline_2.name
|
||||
|
||||
@ -157,14 +160,14 @@ async def test_select_entity_changing_pipelines(
|
||||
assert await hass.config_entries.async_forward_entry_unload(config_entry, "select")
|
||||
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == pipeline_2.name
|
||||
|
||||
# Remove selected pipeline
|
||||
await pipeline_storage.async_delete_item(pipeline_2.id)
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == "preferred"
|
||||
assert state.attributes["options"] == [
|
||||
|
@ -2502,3 +2502,22 @@ async def test_pipeline_empty_tts_output(
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
|
||||
async def test_pipeline_list_devices(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
assist_device,
|
||||
) -> None:
|
||||
"""Test list devices."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id({"type": "assist_pipeline/device/list"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == [
|
||||
{
|
||||
"device_id": assist_device.id,
|
||||
"pipeline_entity": "select.test_assist_device_test_prefix_pipeline",
|
||||
}
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user