diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 71136dcdecb5..a98f184094fe 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1703,7 +1703,7 @@ class PipelineRuns: pipeline_run.abort_wake_word_detection = True -@dataclass +@dataclass(slots=True) class DeviceAudioQueue: """Audio capture queue for a satellite device.""" @@ -1717,6 +1717,14 @@ class DeviceAudioQueue: """Flag to be set if audio samples were dropped because the queue was full.""" +@dataclass(slots=True) +class AssistDevice: + """Assist device.""" + + domain: str + unique_id_prefix: str + + class PipelineData: """Store and debug data stored in hass.data.""" @@ -1724,12 +1732,12 @@ class PipelineData: """Initialize.""" self.pipeline_store = pipeline_store self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {} - self.pipeline_devices: set[str] = set() + self.pipeline_devices: dict[str, AssistDevice] = {} self.pipeline_runs = PipelineRuns(pipeline_store) self.device_audio_queues: dict[str, DeviceAudioQueue] = {} -@dataclass +@dataclass(slots=True) class PipelineRunDebug: """Debug data for a pipelinerun.""" diff --git a/homeassistant/components/assist_pipeline/select.py b/homeassistant/components/assist_pipeline/select.py index 83e1bd3ab363..43ed003f65d1 100644 --- a/homeassistant/components/assist_pipeline/select.py +++ b/homeassistant/components/assist_pipeline/select.py @@ -10,7 +10,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import collection, entity_registry as er, restore_state from .const import DOMAIN -from .pipeline import PipelineData, PipelineStorageCollection +from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection from .vad import VadSensitivity OPTION_PREFERRED = "preferred" @@ -70,8 +70,10 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): _attr_current_option = OPTION_PREFERRED _attr_options = [OPTION_PREFERRED] - def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None: + def __init__(self, hass: HomeAssistant, domain: str, unique_id_prefix: str) -> None: """Initialize a pipeline selector.""" + self._domain = domain + self._unique_id_prefix = unique_id_prefix self._attr_unique_id = f"{unique_id_prefix}-pipeline" self.hass = hass self._update_options() @@ -91,11 +93,16 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): self._attr_current_option = state.state if self.registry_entry and (device_id := self.registry_entry.device_id): - pipeline_data.pipeline_devices.add(device_id) - self.async_on_remove( - lambda: pipeline_data.pipeline_devices.discard(device_id) + pipeline_data.pipeline_devices[device_id] = AssistDevice( + self._domain, self._unique_id_prefix ) + def cleanup() -> None: + """Clean up registered device.""" + pipeline_data.pipeline_devices.pop(device_id) + + self.async_on_remove(cleanup) + async def async_select_option(self, option: str) -> None: """Select an option.""" self._attr_current_option = option diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index 89cced519df9..bfba85638753 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -15,7 +15,7 @@ import voluptuous as vol from homeassistant.components import conversation, stt, tts, websocket_api from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers import config_validation as cv +from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.util import language as language_util from .const import ( @@ -53,6 +53,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, websocket_run) websocket_api.async_register_command(hass, websocket_list_languages) websocket_api.async_register_command(hass, websocket_list_runs) + websocket_api.async_register_command(hass, websocket_list_devices) websocket_api.async_register_command(hass, websocket_get_run) websocket_api.async_register_command(hass, websocket_device_capture) @@ -287,6 +288,35 @@ def websocket_list_runs( ) +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "assist_pipeline/device/list", + } +) +def websocket_list_devices( + hass: HomeAssistant, + connection: websocket_api.connection.ActiveConnection, + msg: dict[str, Any], +) -> None: + """List assist devices.""" + pipeline_data: PipelineData = hass.data[DOMAIN] + ent_reg = er.async_get(hass) + connection.send_result( + msg["id"], + [ + { + "device_id": device_id, + "pipeline_entity": ent_reg.async_get_entity_id( + "select", info.domain, f"{info.unique_id_prefix}-pipeline" + ), + } + for device_id, info in pipeline_data.pipeline_devices.items() + ], + ) + + @callback @websocket_api.require_admin @websocket_api.websocket_command( diff --git a/homeassistant/components/esphome/select.py b/homeassistant/components/esphome/select.py index a3464b137dcb..3d4d296bb87c 100644 --- a/homeassistant/components/esphome/select.py +++ b/homeassistant/components/esphome/select.py @@ -12,6 +12,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback +from .const import DOMAIN from .domain_data import DomainData from .entity import ( EsphomeAssistEntity, @@ -75,7 +76,7 @@ class EsphomeAssistPipelineSelect(EsphomeAssistEntity, AssistPipelineSelect): def __init__(self, hass: HomeAssistant, entry_data: RuntimeEntryData) -> None: """Initialize a pipeline selector.""" EsphomeAssistEntity.__init__(self, entry_data) - AssistPipelineSelect.__init__(self, hass, self._device_info.mac_address) + AssistPipelineSelect.__init__(self, hass, DOMAIN, self._device_info.mac_address) class EsphomeVadSensitivitySelect(EsphomeAssistEntity, VadSensitivitySelect): diff --git a/homeassistant/components/voip/select.py b/homeassistant/components/voip/select.py index 94a3aacc0fd1..f145f866ae3f 100644 --- a/homeassistant/components/voip/select.py +++ b/homeassistant/components/voip/select.py @@ -51,7 +51,7 @@ class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect): def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None: """Initialize a pipeline selector.""" VoIPEntity.__init__(self, device) - AssistPipelineSelect.__init__(self, hass, device.voip_id) + AssistPipelineSelect.__init__(self, hass, DOMAIN, device.voip_id) class VoipVadSensitivitySelect(VoIPEntity, VadSensitivitySelect): diff --git a/homeassistant/components/wyoming/select.py b/homeassistant/components/wyoming/select.py index c04bad4bef8a..99f26c3e440f 100644 --- a/homeassistant/components/wyoming/select.py +++ b/homeassistant/components/wyoming/select.py @@ -57,7 +57,7 @@ class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelec self.device = device WyomingSatelliteEntity.__init__(self, device) - AssistPipelineSelect.__init__(self, hass, device.satellite_id) + AssistPipelineSelect.__init__(self, hass, DOMAIN, device.satellite_id) async def async_select_option(self, option: str) -> None: """Select an option.""" diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index 97f80a33d1dd..38c96871ed33 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -8,13 +8,15 @@ from unittest.mock import AsyncMock import pytest from homeassistant.components import stt, tts, wake_word -from homeassistant.components.assist_pipeline import DOMAIN +from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select from homeassistant.components.assist_pipeline.pipeline import ( PipelineData, PipelineStorageCollection, ) from homeassistant.config_entries import ConfigEntry, ConfigFlow +from homeassistant.const import Platform from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import device_registry as dr from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.setup import async_setup_component @@ -288,7 +290,7 @@ async def init_supporting_components( ) -> bool: """Set up test config entry.""" await hass.config_entries.async_forward_entry_setups( - config_entry, [stt.DOMAIN, wake_word.DOMAIN] + config_entry, [Platform.STT, Platform.WAKE_WORD] ) return True @@ -297,7 +299,7 @@ async def init_supporting_components( ) -> bool: """Unload up test config entry.""" await hass.config_entries.async_unload_platforms( - config_entry, [stt.DOMAIN, wake_word.DOMAIN] + config_entry, [Platform.STT, Platform.WAKE_WORD] ) return True @@ -369,6 +371,79 @@ async def init_components(hass: HomeAssistant, init_supporting_components): assert await async_setup_component(hass, "assist_pipeline", {}) +@pytest.fixture +async def assist_device(hass: HomeAssistant, init_components) -> dr.DeviceEntry: + """Create an assist device.""" + config_entry = MockConfigEntry(domain="test_assist_device") + config_entry.add_to_hass(hass) + + dev_reg = dr.async_get(hass) + device = dev_reg.async_get_or_create( + name="Test Device", + config_entry_id=config_entry.entry_id, + identifiers={("test_assist_device", "test")}, + ) + + async def async_setup_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Set up test config entry.""" + await hass.config_entries.async_forward_entry_setups( + config_entry, [Platform.SELECT] + ) + return True + + async def async_unload_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Unload up test config entry.""" + await hass.config_entries.async_unload_platforms( + config_entry, [Platform.SELECT] + ) + return True + + async def async_setup_entry_select_platform( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, + ) -> None: + """Set up test select platform via config entry.""" + entities = [ + assist_select.AssistPipelineSelect( + hass, "test_assist_device", "test-prefix" + ), + assist_select.VadSensitivitySelect(hass, "test-prefix"), + ] + for ent in entities: + ent._attr_device_info = dr.DeviceInfo( + identifiers={("test_assist_device", "test")}, + ) + async_add_entities(entities) + + mock_integration( + hass, + MockModule( + "test_assist_device", + async_setup_entry=async_setup_entry_init, + async_unload_entry=async_unload_entry_init, + ), + ) + mock_platform( + hass, + "test_assist_device.select", + MockPlatform( + async_setup_entry=async_setup_entry_select_platform, + ), + ) + mock_platform(hass, "test_assist_device.config_flow") + + with mock_config_flow("test_assist_device", ConfigFlow): + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + return device + + @pytest.fixture def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData: """Return pipeline data.""" diff --git a/tests/components/assist_pipeline/test_select.py b/tests/components/assist_pipeline/test_select.py index c4e750e10192..73c069ddd042 100644 --- a/tests/components/assist_pipeline/test_select.py +++ b/tests/components/assist_pipeline/test_select.py @@ -6,6 +6,7 @@ import pytest from homeassistant.components.assist_pipeline import Pipeline from homeassistant.components.assist_pipeline.pipeline import ( + AssistDevice, PipelineData, PipelineStorageCollection, ) @@ -33,7 +34,7 @@ class SelectPlatform(MockPlatform): async_add_entities: AddEntitiesCallback, ) -> None: """Set up fake select platform.""" - pipeline_entity = AssistPipelineSelect(hass, "test") + pipeline_entity = AssistPipelineSelect(hass, "test-domain", "test-prefix") pipeline_entity._attr_device_info = DeviceInfo( identifiers={("test", "test")}, ) @@ -109,13 +110,15 @@ async def test_select_entity_registering_device( assert device is not None # Test device is registered - assert pipeline_data.pipeline_devices == {device.id} + assert pipeline_data.pipeline_devices == { + device.id: AssistDevice("test-domain", "test-prefix") + } await hass.config_entries.async_remove(init_select.entry_id) await hass.async_block_till_done() # Test device is removed - assert pipeline_data.pipeline_devices == set() + assert pipeline_data.pipeline_devices == {} async def test_select_entity_changing_pipelines( @@ -128,7 +131,7 @@ async def test_select_entity_changing_pipelines( """Test entity tracking pipeline changes.""" config_entry = init_select # nicer naming - state = hass.states.get("select.assist_pipeline_test_pipeline") + state = hass.states.get("select.assist_pipeline_test_prefix_pipeline") assert state is not None assert state.state == "preferred" assert state.attributes["options"] == [ @@ -143,13 +146,13 @@ async def test_select_entity_changing_pipelines( "select", "select_option", { - "entity_id": "select.assist_pipeline_test_pipeline", + "entity_id": "select.assist_pipeline_test_prefix_pipeline", "option": pipeline_2.name, }, blocking=True, ) - state = hass.states.get("select.assist_pipeline_test_pipeline") + state = hass.states.get("select.assist_pipeline_test_prefix_pipeline") assert state is not None assert state.state == pipeline_2.name @@ -157,14 +160,14 @@ async def test_select_entity_changing_pipelines( assert await hass.config_entries.async_forward_entry_unload(config_entry, "select") assert await hass.config_entries.async_forward_entry_setup(config_entry, "select") - state = hass.states.get("select.assist_pipeline_test_pipeline") + state = hass.states.get("select.assist_pipeline_test_prefix_pipeline") assert state is not None assert state.state == pipeline_2.name # Remove selected pipeline await pipeline_storage.async_delete_item(pipeline_2.id) - state = hass.states.get("select.assist_pipeline_test_pipeline") + state = hass.states.get("select.assist_pipeline_test_prefix_pipeline") assert state is not None assert state.state == "preferred" assert state.attributes["options"] == [ diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 458320a9a90c..3ea6be028c1e 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -2502,3 +2502,22 @@ async def test_pipeline_empty_tts_output( assert msg["event"]["type"] == "run-end" assert msg["event"]["data"] == snapshot events.append(msg["event"]) + + +async def test_pipeline_list_devices( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + assist_device, +) -> None: + """Test list devices.""" + client = await hass_ws_client(hass) + + await client.send_json_auto_id({"type": "assist_pipeline/device/list"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == [ + { + "device_id": assist_device.id, + "pipeline_entity": "select.test_assist_device_test_prefix_pipeline", + } + ]