1
mirror of https://github.com/home-assistant/core synced 2024-07-15 09:42:11 +02:00

Add API to fetch Assist devices (#107333)

* Add API to fetch Assist devices

* Revert some changes to fixture, make a single fixture for an Assist device
This commit is contained in:
Paulus Schoutsen 2024-01-05 23:30:18 -05:00 committed by GitHub
parent 6201e81eca
commit f1d2868fd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 166 additions and 23 deletions

View File

@ -1703,7 +1703,7 @@ class PipelineRuns:
pipeline_run.abort_wake_word_detection = True
@dataclass
@dataclass(slots=True)
class DeviceAudioQueue:
"""Audio capture queue for a satellite device."""
@ -1717,6 +1717,14 @@ class DeviceAudioQueue:
"""Flag to be set if audio samples were dropped because the queue was full."""
@dataclass(slots=True)
class AssistDevice:
"""Assist device."""
domain: str
unique_id_prefix: str
class PipelineData:
"""Store and debug data stored in hass.data."""
@ -1724,12 +1732,12 @@ class PipelineData:
"""Initialize."""
self.pipeline_store = pipeline_store
self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
self.pipeline_devices: set[str] = set()
self.pipeline_devices: dict[str, AssistDevice] = {}
self.pipeline_runs = PipelineRuns(pipeline_store)
self.device_audio_queues: dict[str, DeviceAudioQueue] = {}
@dataclass
@dataclass(slots=True)
class PipelineRunDebug:
"""Debug data for a pipelinerun."""

View File

@ -10,7 +10,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN
from .pipeline import PipelineData, PipelineStorageCollection
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
from .vad import VadSensitivity
OPTION_PREFERRED = "preferred"
@ -70,8 +70,10 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
_attr_current_option = OPTION_PREFERRED
_attr_options = [OPTION_PREFERRED]
def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
def __init__(self, hass: HomeAssistant, domain: str, unique_id_prefix: str) -> None:
"""Initialize a pipeline selector."""
self._domain = domain
self._unique_id_prefix = unique_id_prefix
self._attr_unique_id = f"{unique_id_prefix}-pipeline"
self.hass = hass
self._update_options()
@ -91,11 +93,16 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
self._attr_current_option = state.state
if self.registry_entry and (device_id := self.registry_entry.device_id):
pipeline_data.pipeline_devices.add(device_id)
self.async_on_remove(
lambda: pipeline_data.pipeline_devices.discard(device_id)
pipeline_data.pipeline_devices[device_id] = AssistDevice(
self._domain, self._unique_id_prefix
)
def cleanup() -> None:
"""Clean up registered device."""
pipeline_data.pipeline_devices.pop(device_id)
self.async_on_remove(cleanup)
async def async_select_option(self, option: str) -> None:
"""Select an option."""
self._attr_current_option = option

View File

@ -15,7 +15,7 @@ import voluptuous as vol
from homeassistant.components import conversation, stt, tts, websocket_api
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.util import language as language_util
from .const import (
@ -53,6 +53,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_run)
websocket_api.async_register_command(hass, websocket_list_languages)
websocket_api.async_register_command(hass, websocket_list_runs)
websocket_api.async_register_command(hass, websocket_list_devices)
websocket_api.async_register_command(hass, websocket_get_run)
websocket_api.async_register_command(hass, websocket_device_capture)
@ -287,6 +288,35 @@ def websocket_list_runs(
)
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/device/list",
}
)
def websocket_list_devices(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""List assist devices."""
pipeline_data: PipelineData = hass.data[DOMAIN]
ent_reg = er.async_get(hass)
connection.send_result(
msg["id"],
[
{
"device_id": device_id,
"pipeline_entity": ent_reg.async_get_entity_id(
"select", info.domain, f"{info.unique_id_prefix}-pipeline"
),
}
for device_id, info in pipeline_data.pipeline_devices.items()
],
)
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(

View File

@ -12,6 +12,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .domain_data import DomainData
from .entity import (
EsphomeAssistEntity,
@ -75,7 +76,7 @@ class EsphomeAssistPipelineSelect(EsphomeAssistEntity, AssistPipelineSelect):
def __init__(self, hass: HomeAssistant, entry_data: RuntimeEntryData) -> None:
"""Initialize a pipeline selector."""
EsphomeAssistEntity.__init__(self, entry_data)
AssistPipelineSelect.__init__(self, hass, self._device_info.mac_address)
AssistPipelineSelect.__init__(self, hass, DOMAIN, self._device_info.mac_address)
class EsphomeVadSensitivitySelect(EsphomeAssistEntity, VadSensitivitySelect):

View File

@ -51,7 +51,7 @@ class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
"""Initialize a pipeline selector."""
VoIPEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, device.voip_id)
AssistPipelineSelect.__init__(self, hass, DOMAIN, device.voip_id)
class VoipVadSensitivitySelect(VoIPEntity, VadSensitivitySelect):

View File

@ -57,7 +57,7 @@ class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelec
self.device = device
WyomingSatelliteEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, device.satellite_id)
AssistPipelineSelect.__init__(self, hass, DOMAIN, device.satellite_id)
async def async_select_option(self, option: str) -> None:
"""Select an option."""

View File

@ -8,13 +8,15 @@ from unittest.mock import AsyncMock
import pytest
from homeassistant.components import stt, tts, wake_word
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select
from homeassistant.components.assist_pipeline.pipeline import (
PipelineData,
PipelineStorageCollection,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
@ -288,7 +290,7 @@ async def init_supporting_components(
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setups(
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
config_entry, [Platform.STT, Platform.WAKE_WORD]
)
return True
@ -297,7 +299,7 @@ async def init_supporting_components(
) -> bool:
"""Unload up test config entry."""
await hass.config_entries.async_unload_platforms(
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
config_entry, [Platform.STT, Platform.WAKE_WORD]
)
return True
@ -369,6 +371,79 @@ async def init_components(hass: HomeAssistant, init_supporting_components):
assert await async_setup_component(hass, "assist_pipeline", {})
@pytest.fixture
async def assist_device(hass: HomeAssistant, init_components) -> dr.DeviceEntry:
"""Create an assist device."""
config_entry = MockConfigEntry(domain="test_assist_device")
config_entry.add_to_hass(hass)
dev_reg = dr.async_get(hass)
device = dev_reg.async_get_or_create(
name="Test Device",
config_entry_id=config_entry.entry_id,
identifiers={("test_assist_device", "test")},
)
async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setups(
config_entry, [Platform.SELECT]
)
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload up test config entry."""
await hass.config_entries.async_unload_platforms(
config_entry, [Platform.SELECT]
)
return True
async def async_setup_entry_select_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test select platform via config entry."""
entities = [
assist_select.AssistPipelineSelect(
hass, "test_assist_device", "test-prefix"
),
assist_select.VadSensitivitySelect(hass, "test-prefix"),
]
for ent in entities:
ent._attr_device_info = dr.DeviceInfo(
identifiers={("test_assist_device", "test")},
)
async_add_entities(entities)
mock_integration(
hass,
MockModule(
"test_assist_device",
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
mock_platform(
hass,
"test_assist_device.select",
MockPlatform(
async_setup_entry=async_setup_entry_select_platform,
),
)
mock_platform(hass, "test_assist_device.config_flow")
with mock_config_flow("test_assist_device", ConfigFlow):
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
return device
@pytest.fixture
def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData:
"""Return pipeline data."""

View File

@ -6,6 +6,7 @@ import pytest
from homeassistant.components.assist_pipeline import Pipeline
from homeassistant.components.assist_pipeline.pipeline import (
AssistDevice,
PipelineData,
PipelineStorageCollection,
)
@ -33,7 +34,7 @@ class SelectPlatform(MockPlatform):
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up fake select platform."""
pipeline_entity = AssistPipelineSelect(hass, "test")
pipeline_entity = AssistPipelineSelect(hass, "test-domain", "test-prefix")
pipeline_entity._attr_device_info = DeviceInfo(
identifiers={("test", "test")},
)
@ -109,13 +110,15 @@ async def test_select_entity_registering_device(
assert device is not None
# Test device is registered
assert pipeline_data.pipeline_devices == {device.id}
assert pipeline_data.pipeline_devices == {
device.id: AssistDevice("test-domain", "test-prefix")
}
await hass.config_entries.async_remove(init_select.entry_id)
await hass.async_block_till_done()
# Test device is removed
assert pipeline_data.pipeline_devices == set()
assert pipeline_data.pipeline_devices == {}
async def test_select_entity_changing_pipelines(
@ -128,7 +131,7 @@ async def test_select_entity_changing_pipelines(
"""Test entity tracking pipeline changes."""
config_entry = init_select # nicer naming
state = hass.states.get("select.assist_pipeline_test_pipeline")
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
assert state is not None
assert state.state == "preferred"
assert state.attributes["options"] == [
@ -143,13 +146,13 @@ async def test_select_entity_changing_pipelines(
"select",
"select_option",
{
"entity_id": "select.assist_pipeline_test_pipeline",
"entity_id": "select.assist_pipeline_test_prefix_pipeline",
"option": pipeline_2.name,
},
blocking=True,
)
state = hass.states.get("select.assist_pipeline_test_pipeline")
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
assert state is not None
assert state.state == pipeline_2.name
@ -157,14 +160,14 @@ async def test_select_entity_changing_pipelines(
assert await hass.config_entries.async_forward_entry_unload(config_entry, "select")
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
state = hass.states.get("select.assist_pipeline_test_pipeline")
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
assert state is not None
assert state.state == pipeline_2.name
# Remove selected pipeline
await pipeline_storage.async_delete_item(pipeline_2.id)
state = hass.states.get("select.assist_pipeline_test_pipeline")
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
assert state is not None
assert state.state == "preferred"
assert state.attributes["options"] == [

View File

@ -2502,3 +2502,22 @@ async def test_pipeline_empty_tts_output(
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
async def test_pipeline_list_devices(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
assist_device,
) -> None:
"""Test list devices."""
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "assist_pipeline/device/list"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == [
{
"device_id": assist_device.id,
"pipeline_entity": "select.test_assist_device_test_prefix_pipeline",
}
]