diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 52769063b7e4..45f892d783e0 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -2,7 +2,7 @@ from functools import partial import importlib import logging -from typing import Any, Awaitable, Callable +from typing import Any, Awaitable, Callable, List import voluptuous as vol @@ -19,7 +19,7 @@ from homeassistant.const import ( SERVICE_TURN_ON, STATE_ON, ) -from homeassistant.core import Context, CoreState, HomeAssistant +from homeassistant.core import Context, CoreState, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import condition, extract_domain_configs, script import homeassistant.helpers.config_validation as cv @@ -119,9 +119,75 @@ def is_on(hass, entity_id): return hass.states.is_state(entity_id, STATE_ON) +@callback +def automations_with_entity(hass: HomeAssistant, entity_id: str) -> List[str]: + """Return all automations that reference the entity.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + results = [] + + for automation_entity in component.entities: + if entity_id in automation_entity.action_script.referenced_entities: + results.append(automation_entity.entity_id) + + return results + + +@callback +def entities_in_automation(hass: HomeAssistant, entity_id: str) -> List[str]: + """Return all entities in a scene.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + automation_entity = component.get_entity(entity_id) + + if automation_entity is None: + return [] + + return list(automation_entity.action_script.referenced_entities) + + +@callback +def automations_with_device(hass: HomeAssistant, device_id: str) -> List[str]: + """Return all automations that reference the device.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + results = [] + + for automation_entity in component.entities: + if device_id in automation_entity.action_script.referenced_devices: + results.append(automation_entity.entity_id) + + return results + + +@callback +def devices_in_automation(hass: HomeAssistant, entity_id: str) -> List[str]: + """Return all devices in a scene.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + automation_entity = component.get_entity(entity_id) + + if automation_entity is None: + return [] + + return list(automation_entity.action_script.referenced_devices) + + async def async_setup(hass, config): """Set up the automation.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + hass.data[DOMAIN] = component = EntityComponent(_LOGGER, DOMAIN, hass) await _async_process_config(hass, config, component) @@ -168,7 +234,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): name, async_attach_triggers, cond_func, - async_action, + action_script, hidden, initial_state, ): @@ -178,7 +244,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): self._async_attach_triggers = async_attach_triggers self._async_detach_triggers = None self._cond_func = cond_func - self._async_action = async_action + self.action_script = action_script self._last_triggered = None self._hidden = hidden self._initial_state = initial_state @@ -277,7 +343,16 @@ class AutomationEntity(ToggleEntity, RestoreEntity): {ATTR_NAME: self._name, ATTR_ENTITY_ID: self.entity_id}, context=trigger_context, ) - await self._async_action(self.entity_id, variables, trigger_context) + + _LOGGER.info("Executing %s", self._name) + + try: + await self.action_script.async_run(variables, trigger_context) + except Exception as err: # pylint: disable=broad-except + self.action_script.async_log_exception( + _LOGGER, f"Error while executing automation {self.entity_id}", err + ) + self._last_triggered = utcnow() await self.async_update_ha_state() @@ -358,7 +433,7 @@ async def _async_process_config(hass, config, component): hidden = config_block[CONF_HIDE_ENTITY] initial_state = config_block.get(CONF_INITIAL_STATE) - action = _async_get_action(hass, config_block.get(CONF_ACTION, {}), name) + action_script = script.Script(hass, config_block.get(CONF_ACTION, {}), name) if CONF_CONDITION in config_block: cond_func = await _async_process_if(hass, config, config_block) @@ -383,7 +458,7 @@ async def _async_process_config(hass, config, component): name, async_attach_triggers, cond_func, - action, + action_script, hidden, initial_state, ) @@ -394,24 +469,6 @@ async def _async_process_config(hass, config, component): await component.async_add_entities(entities) -def _async_get_action(hass, config, name): - """Return an action based on a configuration.""" - script_obj = script.Script(hass, config, name) - - async def action(entity_id, variables, context): - """Execute an action.""" - _LOGGER.info("Executing %s", name) - - try: - await script_obj.async_run(variables, context) - except Exception as err: # pylint: disable=broad-except - script_obj.async_log_exception( - _LOGGER, f"Error while executing automation {entity_id}", err - ) - - return action - - async def _async_process_if(hass, config, p_config): """Process if checks.""" if_configs = p_config.get(CONF_CONDITION) diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 1d180b54cfd6..44684656372e 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -1,6 +1,7 @@ """Support for scripts.""" import asyncio import logging +from typing import List import voluptuous as vol @@ -15,6 +16,7 @@ from homeassistant.const import ( SERVICE_TURN_ON, STATE_ON, ) +from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.config_validation import make_entity_service_schema from homeassistant.helpers.entity import ToggleEntity @@ -69,9 +71,75 @@ def is_on(hass, entity_id): return hass.states.is_state(entity_id, STATE_ON) +@callback +def scripts_with_entity(hass: HomeAssistant, entity_id: str) -> List[str]: + """Return all scripts that reference the entity.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + results = [] + + for script_entity in component.entities: + if entity_id in script_entity.script.referenced_entities: + results.append(script_entity.entity_id) + + return results + + +@callback +def entities_in_script(hass: HomeAssistant, entity_id: str) -> List[str]: + """Return all entities in a scene.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + script_entity = component.get_entity(entity_id) + + if script_entity is None: + return [] + + return list(script_entity.script.referenced_entities) + + +@callback +def scripts_with_device(hass: HomeAssistant, device_id: str) -> List[str]: + """Return all scripts that reference the device.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + results = [] + + for script_entity in component.entities: + if device_id in script_entity.script.referenced_devices: + results.append(script_entity.entity_id) + + return results + + +@callback +def devices_in_script(hass: HomeAssistant, entity_id: str) -> List[str]: + """Return all devices in a scene.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + script_entity = component.get_entity(entity_id) + + if script_entity is None: + return [] + + return list(script_entity.script.referenced_devices) + + async def async_setup(hass, config): """Load the scripts from the configuration.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + hass.data[DOMAIN] = component = EntityComponent(_LOGGER, DOMAIN, hass) await _async_process_config(hass, config, component) diff --git a/homeassistant/components/search/__init__.py b/homeassistant/components/search/__init__.py index 47e7f6ef28d7..a3bbd3844aaf 100644 --- a/homeassistant/components/search/__init__.py +++ b/homeassistant/components/search/__init__.py @@ -1,14 +1,16 @@ """The Search integration.""" -from collections import defaultdict +from collections import defaultdict, deque +import logging import voluptuous as vol -from homeassistant.components import group, websocket_api +from homeassistant.components import automation, group, script, websocket_api from homeassistant.components.homeassistant import scene from homeassistant.core import HomeAssistant, callback, split_entity_id from homeassistant.helpers import device_registry, entity_registry DOMAIN = "search" +_LOGGER = logging.getLogger(__name__) async def async_setup(hass: HomeAssistant, config: dict): @@ -73,16 +75,17 @@ class Searcher: self._device_reg = device_reg self._entity_reg = entity_reg self.results = defaultdict(set) - self._to_resolve = set() + self._to_resolve = deque() @callback def async_search(self, item_type, item_id): """Find results.""" + _LOGGER.debug("Searching for %s/%s", item_type, item_id) self.results[item_type].add(item_id) - self._to_resolve.add((item_type, item_id)) + self._to_resolve.append((item_type, item_id)) while self._to_resolve: - search_type, search_id = self._to_resolve.pop() + search_type, search_id = self._to_resolve.popleft() getattr(self, f"_resolve_{search_type}")(search_id) # Clean up entity_id items, from the general "entity" type result, @@ -112,7 +115,7 @@ class Searcher: self.results[item_type].add(item_id) if item_type not in self.DONT_RESOLVE: - self._to_resolve.add((item_type, item_id)) + self._to_resolve.append((item_type, item_id)) @callback def _resolve_area(self, area_id) -> None: @@ -140,7 +143,11 @@ class Searcher: ): self._add_or_resolve("entity", entity_entry.entity_id) - # Extra: Find automations that reference this device + for entity_id in script.scripts_with_device(self.hass, device_id): + self._add_or_resolve("entity", entity_id) + + for entity_id in automation.automations_with_device(self.hass, device_id): + self._add_or_resolve("entity", entity_id) @callback def _resolve_entity(self, entity_id) -> None: @@ -153,6 +160,12 @@ class Searcher: for entity in group.groups_with_entity(self.hass, entity_id): self._add_or_resolve("entity", entity) + for entity in automation.automations_with_entity(self.hass, entity_id): + self._add_or_resolve("entity", entity) + + for entity in script.scripts_with_entity(self.hass, entity_id): + self._add_or_resolve("entity", entity) + # Find devices entity_entry = self._entity_reg.async_get(entity_id) if entity_entry is not None: @@ -164,7 +177,7 @@ class Searcher: domain = split_entity_id(entity_id)[0] - if domain in ("scene", "automation", "script", "group"): + if domain in self.EXIST_AS_ENTITY: self._add_or_resolve(domain, entity_id) @callback @@ -173,7 +186,13 @@ class Searcher: Will only be called if automation is an entry point. """ - # Extra: Check with automation integration what entities/devices they reference + for entity in automation.entities_in_automation( + self.hass, automation_entity_id + ): + self._add_or_resolve("entity", entity) + + for device in automation.devices_in_automation(self.hass, automation_entity_id): + self._add_or_resolve("device", device) @callback def _resolve_script(self, script_entity_id) -> None: @@ -181,7 +200,11 @@ class Searcher: Will only be called if script is an entry point. """ - # Extra: Check with script integration what entities/devices they reference + for entity in script.entities_in_script(self.hass, script_entity_id): + self._add_or_resolve("entity", entity) + + for device in script.devices_in_script(self.hass, script_entity_id): + self._add_or_resolve("device", device) @callback def _resolve_group(self, group_entity_id) -> None: diff --git a/homeassistant/components/search/manifest.json b/homeassistant/components/search/manifest.json index 337ce45f9bfa..581a702f514d 100644 --- a/homeassistant/components/search/manifest.json +++ b/homeassistant/components/search/manifest.json @@ -7,6 +7,6 @@ "zeroconf": [], "homekit": {}, "dependencies": ["websocket_api"], - "after_dependencies": ["scene", "group"], + "after_dependencies": ["scene", "group", "automation", "script"], "codeowners": ["@home-assistant/core"] } diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index c3d098539608..3500a3a4e3d7 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -1,10 +1,11 @@ """Offer reusable conditions.""" import asyncio +from collections import deque from datetime import datetime, timedelta import functools as ft import logging import sys -from typing import Callable, Container, Optional, Union, cast +from typing import Callable, Container, Optional, Set, Union, cast from homeassistant.components import zone as zone_cmp from homeassistant.components.device_automation import ( @@ -19,6 +20,7 @@ from homeassistant.const import ( CONF_BEFORE, CONF_BELOW, CONF_CONDITION, + CONF_DEVICE_ID, CONF_DOMAIN, CONF_ENTITY_ID, CONF_STATE, @@ -31,7 +33,7 @@ from homeassistant.const import ( SUN_EVENT_SUNSET, WEEKDAYS, ) -from homeassistant.core import HomeAssistant, State +from homeassistant.core import HomeAssistant, State, callback from homeassistant.exceptions import HomeAssistantError, TemplateError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.sun import get_astral_event_date @@ -529,3 +531,50 @@ async def async_validate_condition_config( return cast(ConfigType, platform.CONDITION_SCHEMA(config)) # type: ignore return config + + +@callback +def async_extract_entities(config: ConfigType) -> Set[str]: + """Extract entities from a condition.""" + referenced = set() + to_process = deque([config]) + + while to_process: + config = to_process.popleft() + condition = config[CONF_CONDITION] + + if condition in ("and", "or"): + to_process.extend(config["conditions"]) + continue + + entity_id = config.get(CONF_ENTITY_ID) + + if entity_id is not None: + referenced.add(entity_id) + + return referenced + + +@callback +def async_extract_devices(config: ConfigType) -> Set[str]: + """Extract devices from a condition.""" + referenced = set() + to_process = deque([config]) + + while to_process: + config = to_process.popleft() + condition = config[CONF_CONDITION] + + if condition in ("and", "or"): + to_process.extend(config["conditions"]) + continue + + if condition != "device": + continue + + device_id = config.get(CONF_DEVICE_ID) + + if device_id is not None: + referenced.add(device_id) + + return referenced diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 0d973afcfe9b..378a6016c204 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -156,12 +156,66 @@ class Script: ACTION_DEVICE_AUTOMATION: self._async_device_automation, ACTION_ACTIVATE_SCENE: self._async_activate_scene, } + self._referenced_entities: Optional[Set[str]] = None + self._referenced_devices: Optional[Set[str]] = None @property def is_running(self) -> bool: """Return true if script is on.""" return self._cur != -1 + @property + def referenced_devices(self): + """Return a set of referenced devices.""" + if self._referenced_devices is not None: + return self._referenced_devices + + referenced = set() + + for step in self.sequence: + action = _determine_action(step) + + if action == ACTION_CHECK_CONDITION: + referenced |= condition.async_extract_devices(step) + + elif action == ACTION_DEVICE_AUTOMATION: + referenced.add(step[CONF_DEVICE_ID]) + + self._referenced_devices = referenced + return referenced + + @property + def referenced_entities(self): + """Return a set of referenced entities.""" + if self._referenced_entities is not None: + return self._referenced_entities + + referenced = set() + + for step in self.sequence: + action = _determine_action(step) + + if action == ACTION_CALL_SERVICE: + data = step.get(service.CONF_SERVICE_DATA) + if not data: + continue + + entity_ids = data.get(ATTR_ENTITY_ID) + if isinstance(entity_ids, str): + entity_ids = [entity_ids] + + for entity_id in entity_ids: + referenced.add(entity_id) + + elif action == ACTION_CHECK_CONDITION: + referenced |= condition.async_extract_entities(step) + + elif action == ACTION_ACTIVATE_SCENE: + referenced.add(step[CONF_SCENE]) + + self._referenced_entities = referenced + return referenced + def run(self, variables=None, context=None): """Run script.""" asyncio.run_coroutine_threadsafe( diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 83db0cdf7dd3..391c9646dd4f 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, patch import pytest import homeassistant.components.automation as automation +from homeassistant.components.automation import DOMAIN from homeassistant.const import ( ATTR_ENTITY_ID, ATTR_NAME, @@ -922,3 +923,80 @@ async def test_automation_restore_last_triggered_with_initial_state(hass): assert state assert state.state == STATE_ON assert state.attributes["last_triggered"] == time + + +async def test_extraction_functions(hass): + """Test extraction functions.""" + assert await async_setup_component( + hass, + DOMAIN, + { + DOMAIN: [ + { + "alias": "test1", + "trigger": {"platform": "state", "entity_id": "sensor.trigger_1"}, + "action": [ + { + "service": "test.script", + "data": {"entity_id": "light.in_both"}, + }, + { + "service": "test.script", + "data": {"entity_id": "light.in_first"}, + }, + { + "domain": "light", + "device_id": "device-in-both", + "entity_id": "light.bla", + "type": "turn_on", + }, + ], + }, + { + "alias": "test2", + "trigger": {"platform": "state", "entity_id": "sensor.trigger_2"}, + "action": [ + { + "service": "test.script", + "data": {"entity_id": "light.in_both"}, + }, + { + "condition": "state", + "entity_id": "sensor.condition", + "state": "100", + }, + {"scene": "scene.hello"}, + { + "domain": "light", + "device_id": "device-in-both", + "entity_id": "light.bla", + "type": "turn_on", + }, + { + "domain": "light", + "device_id": "device-in-last", + "entity_id": "light.bla", + "type": "turn_on", + }, + ], + }, + ] + }, + ) + + assert set(automation.automations_with_entity(hass, "light.in_both")) == { + "automation.test1", + "automation.test2", + } + assert set(automation.entities_in_automation(hass, "automation.test1")) == { + "light.in_both", + "light.in_first", + } + assert set(automation.automations_with_device(hass, "device-in-both")) == { + "automation.test1", + "automation.test2", + } + assert set(automation.devices_in_automation(hass, "automation.test2")) == { + "device-in-both", + "device-in-last", + } diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index cb66c26b6a35..9d64f5298f44 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -358,9 +358,8 @@ async def test_turning_no_scripts_off(hass): async def test_async_get_descriptions_script(hass): """Test async_set_service_schema for the script integration.""" - script = hass.components.script script_config = { - script.DOMAIN: { + DOMAIN: { "test1": {"sequence": [{"service": "homeassistant.restart"}]}, "test2": { "description": "test2", @@ -375,18 +374,75 @@ async def test_async_get_descriptions_script(hass): } } - await async_setup_component(hass, script.DOMAIN, script_config) + await async_setup_component(hass, DOMAIN, script_config) descriptions = await hass.helpers.service.async_get_all_descriptions() - assert descriptions[script.DOMAIN]["test1"]["description"] == "" - assert not descriptions[script.DOMAIN]["test1"]["fields"] + assert descriptions[DOMAIN]["test1"]["description"] == "" + assert not descriptions[DOMAIN]["test1"]["fields"] - assert descriptions[script.DOMAIN]["test2"]["description"] == "test2" + assert descriptions[DOMAIN]["test2"]["description"] == "test2" assert ( - descriptions[script.DOMAIN]["test2"]["fields"]["param"]["description"] + descriptions[DOMAIN]["test2"]["fields"]["param"]["description"] == "param_description" ) assert ( - descriptions[script.DOMAIN]["test2"]["fields"]["param"]["example"] - == "param_example" + descriptions[DOMAIN]["test2"]["fields"]["param"]["example"] == "param_example" ) + + +async def test_extraction_functions(hass): + """Test extraction functions.""" + assert await async_setup_component( + hass, + DOMAIN, + { + DOMAIN: { + "test1": { + "sequence": [ + { + "service": "test.script", + "data": {"entity_id": "light.in_both"}, + }, + { + "service": "test.script", + "data": {"entity_id": "light.in_first"}, + }, + {"domain": "light", "device_id": "device-in-both"}, + ] + }, + "test2": { + "sequence": [ + { + "service": "test.script", + "data": {"entity_id": "light.in_both"}, + }, + { + "condition": "state", + "entity_id": "sensor.condition", + "state": "100", + }, + {"scene": "scene.hello"}, + {"domain": "light", "device_id": "device-in-both"}, + {"domain": "light", "device_id": "device-in-last"}, + ], + }, + } + }, + ) + + assert set(script.scripts_with_entity(hass, "light.in_both")) == { + "script.test1", + "script.test2", + } + assert set(script.entities_in_script(hass, "script.test1")) == { + "light.in_both", + "light.in_first", + } + assert set(script.scripts_with_device(hass, "device-in-both")) == { + "script.test1", + "script.test2", + } + assert set(script.devices_in_script(hass, "script.test2")) == { + "device-in-both", + "device-in-last", + } diff --git a/tests/components/search/test_init.py b/tests/components/search/test_init.py index 5762468ff1dd..54a32bed2295 100644 --- a/tests/components/search/test_init.py +++ b/tests/components/search/test_init.py @@ -131,6 +131,62 @@ async def test_search(hass): }, ) + await async_setup_component( + hass, + "script", + { + "script": { + "wled": { + "sequence": [ + { + "service": "test.script", + "data": {"entity_id": wled_segment_1_entity.entity_id}, + }, + ] + }, + "hue": { + "sequence": [ + { + "service": "test.script", + "data": {"entity_id": hue_segment_1_entity.entity_id}, + }, + ] + }, + } + }, + ) + + assert await async_setup_component( + hass, + "automation", + { + "automation": [ + { + "alias": "wled_entity", + "trigger": {"platform": "state", "entity_id": "sensor.trigger_1"}, + "action": [ + { + "service": "test.script", + "data": {"entity_id": wled_segment_1_entity.entity_id}, + }, + ], + }, + { + "alias": "wled_device", + "trigger": {"platform": "state", "entity_id": "sensor.trigger_1"}, + "action": [ + { + "domain": "light", + "device_id": wled_device.id, + "entity_id": wled_segment_1_entity.entity_id, + "type": "turn_on", + }, + ], + }, + ] + }, + ) + # Explore the graph from every node and make sure we find the same results expected = { "config_entry": {wled_config_entry.entry_id}, @@ -139,6 +195,8 @@ async def test_search(hass): "entity": {wled_segment_1_entity.entity_id, wled_segment_2_entity.entity_id}, "scene": {"scene.scene_wled_seg_1", "scene.scene_wled_hue"}, "group": {"group.wled", "group.wled_hue"}, + "script": {"script.wled"}, + "automation": {"automation.wled_entity", "automation.wled_device"}, } for search_type, search_id in ( @@ -149,6 +207,9 @@ async def test_search(hass): ("entity", wled_segment_2_entity.entity_id), ("scene", "scene.scene_wled_seg_1"), ("group", "group.wled"), + ("script", "script.wled"), + ("automation", "automation.wled_entity"), + ("automation", "automation.wled_device"), ): searcher = search.Searcher(hass, device_reg, entity_reg) results = searcher.async_search(search_type, search_id) @@ -176,6 +237,8 @@ async def test_search(hass): "scene.scene_wled_hue", }, "group": {"group.wled", "group.hue", "group.wled_hue"}, + "script": {"script.wled", "script.hue"}, + "automation": {"automation.wled_entity", "automation.wled_device"}, } for search_type, search_id in ( ("scene", "scene.scene_wled_hue"), diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index b603f98bb04b..afa428805e98 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -176,3 +176,37 @@ async def test_if_numeric_state_not_raise_on_unavailable(hass): hass.states.async_set("sensor.temperature", "unknown") assert not test(hass) assert len(logwarn.mock_calls) == 0 + + +async def test_extract_entities(): + """Test extracting entities.""" + condition.async_extract_entities( + { + "condition": "and", + "conditions": [ + { + "condition": "state", + "entity_id": "sensor.temperature", + "state": "100", + }, + { + "condition": "numeric_state", + "entity_id": "sensor.temperature_2", + "below": 110, + }, + ], + } + ) == {"sensor.temperature", "sensor.temperature_2"} + + +async def test_extract_devices(): + """Test extracting devices.""" + condition.async_extract_devices( + { + "condition": "and", + "conditions": [ + {"condition": "device", "device_id": "abcd", "domain": "light"}, + {"condition": "device", "device_id": "qwer", "domain": "switch"}, + ], + } + ) == {"abcd", "qwer"} diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index a7fe2c252368..b226ed15720c 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -1022,3 +1022,58 @@ def test_log_exception(): assert p_error == "" else: assert p_error == str(exc) + + +async def test_referenced_entities(): + """Test referenced entities.""" + script_obj = script.Script( + None, + cv.SCRIPT_SCHEMA( + [ + { + "service": "test.script", + "data": {"entity_id": "light.service_not_list"}, + }, + { + "service": "test.script", + "data": {"entity_id": ["light.service_list"]}, + }, + { + "condition": "state", + "entity_id": "sensor.condition", + "state": "100", + }, + {"scene": "scene.hello"}, + {"event": "test_event"}, + {"delay": "{{ delay_period }}"}, + ] + ), + ) + assert script_obj.referenced_entities == { + "light.service_not_list", + "light.service_list", + "sensor.condition", + "scene.hello", + } + # Test we cache results. + assert script_obj.referenced_entities is script_obj.referenced_entities + + +async def test_referenced_devices(): + """Test referenced entities.""" + script_obj = script.Script( + None, + cv.SCRIPT_SCHEMA( + [ + {"domain": "light", "device_id": "script-dev-id"}, + { + "condition": "device", + "device_id": "condition-dev-id", + "domain": "switch", + }, + ] + ), + ) + assert script_obj.referenced_devices == {"script-dev-id", "condition-dev-id"} + # Test we cache results. + assert script_obj.referenced_devices is script_obj.referenced_devices