1
mirror of https://github.com/home-assistant/core synced 2024-10-01 05:30:36 +02:00

Find related items scripts/automations (#31293)

* Find related items scripts/automations

* Update manifest
This commit is contained in:
Paulus Schoutsen 2020-01-29 16:19:13 -08:00 committed by GitHub
parent 881437c085
commit 424e15c7a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 586 additions and 49 deletions

View File

@ -2,7 +2,7 @@
from functools import partial from functools import partial
import importlib import importlib
import logging import logging
from typing import Any, Awaitable, Callable from typing import Any, Awaitable, Callable, List
import voluptuous as vol import voluptuous as vol
@ -19,7 +19,7 @@ from homeassistant.const import (
SERVICE_TURN_ON, SERVICE_TURN_ON,
STATE_ON, STATE_ON,
) )
from homeassistant.core import Context, CoreState, HomeAssistant from homeassistant.core import Context, CoreState, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import condition, extract_domain_configs, script from homeassistant.helpers import condition, extract_domain_configs, script
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -119,9 +119,75 @@ def is_on(hass, entity_id):
return hass.states.is_state(entity_id, STATE_ON) return hass.states.is_state(entity_id, STATE_ON)
@callback
def automations_with_entity(hass: HomeAssistant, entity_id: str) -> List[str]:
"""Return all automations that reference the entity."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
results = []
for automation_entity in component.entities:
if entity_id in automation_entity.action_script.referenced_entities:
results.append(automation_entity.entity_id)
return results
@callback
def entities_in_automation(hass: HomeAssistant, entity_id: str) -> List[str]:
"""Return all entities in a scene."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
automation_entity = component.get_entity(entity_id)
if automation_entity is None:
return []
return list(automation_entity.action_script.referenced_entities)
@callback
def automations_with_device(hass: HomeAssistant, device_id: str) -> List[str]:
"""Return all automations that reference the device."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
results = []
for automation_entity in component.entities:
if device_id in automation_entity.action_script.referenced_devices:
results.append(automation_entity.entity_id)
return results
@callback
def devices_in_automation(hass: HomeAssistant, entity_id: str) -> List[str]:
"""Return all devices in a scene."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
automation_entity = component.get_entity(entity_id)
if automation_entity is None:
return []
return list(automation_entity.action_script.referenced_devices)
async def async_setup(hass, config): async def async_setup(hass, config):
"""Set up the automation.""" """Set up the automation."""
component = EntityComponent(_LOGGER, DOMAIN, hass) hass.data[DOMAIN] = component = EntityComponent(_LOGGER, DOMAIN, hass)
await _async_process_config(hass, config, component) await _async_process_config(hass, config, component)
@ -168,7 +234,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
name, name,
async_attach_triggers, async_attach_triggers,
cond_func, cond_func,
async_action, action_script,
hidden, hidden,
initial_state, initial_state,
): ):
@ -178,7 +244,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._async_attach_triggers = async_attach_triggers self._async_attach_triggers = async_attach_triggers
self._async_detach_triggers = None self._async_detach_triggers = None
self._cond_func = cond_func self._cond_func = cond_func
self._async_action = async_action self.action_script = action_script
self._last_triggered = None self._last_triggered = None
self._hidden = hidden self._hidden = hidden
self._initial_state = initial_state self._initial_state = initial_state
@ -277,7 +343,16 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
{ATTR_NAME: self._name, ATTR_ENTITY_ID: self.entity_id}, {ATTR_NAME: self._name, ATTR_ENTITY_ID: self.entity_id},
context=trigger_context, context=trigger_context,
) )
await self._async_action(self.entity_id, variables, trigger_context)
_LOGGER.info("Executing %s", self._name)
try:
await self.action_script.async_run(variables, trigger_context)
except Exception as err: # pylint: disable=broad-except
self.action_script.async_log_exception(
_LOGGER, f"Error while executing automation {self.entity_id}", err
)
self._last_triggered = utcnow() self._last_triggered = utcnow()
await self.async_update_ha_state() await self.async_update_ha_state()
@ -358,7 +433,7 @@ async def _async_process_config(hass, config, component):
hidden = config_block[CONF_HIDE_ENTITY] hidden = config_block[CONF_HIDE_ENTITY]
initial_state = config_block.get(CONF_INITIAL_STATE) initial_state = config_block.get(CONF_INITIAL_STATE)
action = _async_get_action(hass, config_block.get(CONF_ACTION, {}), name) action_script = script.Script(hass, config_block.get(CONF_ACTION, {}), name)
if CONF_CONDITION in config_block: if CONF_CONDITION in config_block:
cond_func = await _async_process_if(hass, config, config_block) cond_func = await _async_process_if(hass, config, config_block)
@ -383,7 +458,7 @@ async def _async_process_config(hass, config, component):
name, name,
async_attach_triggers, async_attach_triggers,
cond_func, cond_func,
action, action_script,
hidden, hidden,
initial_state, initial_state,
) )
@ -394,24 +469,6 @@ async def _async_process_config(hass, config, component):
await component.async_add_entities(entities) await component.async_add_entities(entities)
def _async_get_action(hass, config, name):
"""Return an action based on a configuration."""
script_obj = script.Script(hass, config, name)
async def action(entity_id, variables, context):
"""Execute an action."""
_LOGGER.info("Executing %s", name)
try:
await script_obj.async_run(variables, context)
except Exception as err: # pylint: disable=broad-except
script_obj.async_log_exception(
_LOGGER, f"Error while executing automation {entity_id}", err
)
return action
async def _async_process_if(hass, config, p_config): async def _async_process_if(hass, config, p_config):
"""Process if checks.""" """Process if checks."""
if_configs = p_config.get(CONF_CONDITION) if_configs = p_config.get(CONF_CONDITION)

View File

@ -1,6 +1,7 @@
"""Support for scripts.""" """Support for scripts."""
import asyncio import asyncio
import logging import logging
from typing import List
import voluptuous as vol import voluptuous as vol
@ -15,6 +16,7 @@ from homeassistant.const import (
SERVICE_TURN_ON, SERVICE_TURN_ON,
STATE_ON, STATE_ON,
) )
from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.config_validation import make_entity_service_schema from homeassistant.helpers.config_validation import make_entity_service_schema
from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity import ToggleEntity
@ -69,9 +71,75 @@ def is_on(hass, entity_id):
return hass.states.is_state(entity_id, STATE_ON) return hass.states.is_state(entity_id, STATE_ON)
@callback
def scripts_with_entity(hass: HomeAssistant, entity_id: str) -> List[str]:
"""Return all scripts that reference the entity."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
results = []
for script_entity in component.entities:
if entity_id in script_entity.script.referenced_entities:
results.append(script_entity.entity_id)
return results
@callback
def entities_in_script(hass: HomeAssistant, entity_id: str) -> List[str]:
"""Return all entities in a scene."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
script_entity = component.get_entity(entity_id)
if script_entity is None:
return []
return list(script_entity.script.referenced_entities)
@callback
def scripts_with_device(hass: HomeAssistant, device_id: str) -> List[str]:
"""Return all scripts that reference the device."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
results = []
for script_entity in component.entities:
if device_id in script_entity.script.referenced_devices:
results.append(script_entity.entity_id)
return results
@callback
def devices_in_script(hass: HomeAssistant, entity_id: str) -> List[str]:
"""Return all devices in a scene."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
script_entity = component.get_entity(entity_id)
if script_entity is None:
return []
return list(script_entity.script.referenced_devices)
async def async_setup(hass, config): async def async_setup(hass, config):
"""Load the scripts from the configuration.""" """Load the scripts from the configuration."""
component = EntityComponent(_LOGGER, DOMAIN, hass) hass.data[DOMAIN] = component = EntityComponent(_LOGGER, DOMAIN, hass)
await _async_process_config(hass, config, component) await _async_process_config(hass, config, component)

View File

@ -1,14 +1,16 @@
"""The Search integration.""" """The Search integration."""
from collections import defaultdict from collections import defaultdict, deque
import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components import group, websocket_api from homeassistant.components import automation, group, script, websocket_api
from homeassistant.components.homeassistant import scene from homeassistant.components.homeassistant import scene
from homeassistant.core import HomeAssistant, callback, split_entity_id from homeassistant.core import HomeAssistant, callback, split_entity_id
from homeassistant.helpers import device_registry, entity_registry from homeassistant.helpers import device_registry, entity_registry
DOMAIN = "search" DOMAIN = "search"
_LOGGER = logging.getLogger(__name__)
async def async_setup(hass: HomeAssistant, config: dict): async def async_setup(hass: HomeAssistant, config: dict):
@ -73,16 +75,17 @@ class Searcher:
self._device_reg = device_reg self._device_reg = device_reg
self._entity_reg = entity_reg self._entity_reg = entity_reg
self.results = defaultdict(set) self.results = defaultdict(set)
self._to_resolve = set() self._to_resolve = deque()
@callback @callback
def async_search(self, item_type, item_id): def async_search(self, item_type, item_id):
"""Find results.""" """Find results."""
_LOGGER.debug("Searching for %s/%s", item_type, item_id)
self.results[item_type].add(item_id) self.results[item_type].add(item_id)
self._to_resolve.add((item_type, item_id)) self._to_resolve.append((item_type, item_id))
while self._to_resolve: while self._to_resolve:
search_type, search_id = self._to_resolve.pop() search_type, search_id = self._to_resolve.popleft()
getattr(self, f"_resolve_{search_type}")(search_id) getattr(self, f"_resolve_{search_type}")(search_id)
# Clean up entity_id items, from the general "entity" type result, # Clean up entity_id items, from the general "entity" type result,
@ -112,7 +115,7 @@ class Searcher:
self.results[item_type].add(item_id) self.results[item_type].add(item_id)
if item_type not in self.DONT_RESOLVE: if item_type not in self.DONT_RESOLVE:
self._to_resolve.add((item_type, item_id)) self._to_resolve.append((item_type, item_id))
@callback @callback
def _resolve_area(self, area_id) -> None: def _resolve_area(self, area_id) -> None:
@ -140,7 +143,11 @@ class Searcher:
): ):
self._add_or_resolve("entity", entity_entry.entity_id) self._add_or_resolve("entity", entity_entry.entity_id)
# Extra: Find automations that reference this device for entity_id in script.scripts_with_device(self.hass, device_id):
self._add_or_resolve("entity", entity_id)
for entity_id in automation.automations_with_device(self.hass, device_id):
self._add_or_resolve("entity", entity_id)
@callback @callback
def _resolve_entity(self, entity_id) -> None: def _resolve_entity(self, entity_id) -> None:
@ -153,6 +160,12 @@ class Searcher:
for entity in group.groups_with_entity(self.hass, entity_id): for entity in group.groups_with_entity(self.hass, entity_id):
self._add_or_resolve("entity", entity) self._add_or_resolve("entity", entity)
for entity in automation.automations_with_entity(self.hass, entity_id):
self._add_or_resolve("entity", entity)
for entity in script.scripts_with_entity(self.hass, entity_id):
self._add_or_resolve("entity", entity)
# Find devices # Find devices
entity_entry = self._entity_reg.async_get(entity_id) entity_entry = self._entity_reg.async_get(entity_id)
if entity_entry is not None: if entity_entry is not None:
@ -164,7 +177,7 @@ class Searcher:
domain = split_entity_id(entity_id)[0] domain = split_entity_id(entity_id)[0]
if domain in ("scene", "automation", "script", "group"): if domain in self.EXIST_AS_ENTITY:
self._add_or_resolve(domain, entity_id) self._add_or_resolve(domain, entity_id)
@callback @callback
@ -173,7 +186,13 @@ class Searcher:
Will only be called if automation is an entry point. Will only be called if automation is an entry point.
""" """
# Extra: Check with automation integration what entities/devices they reference for entity in automation.entities_in_automation(
self.hass, automation_entity_id
):
self._add_or_resolve("entity", entity)
for device in automation.devices_in_automation(self.hass, automation_entity_id):
self._add_or_resolve("device", device)
@callback @callback
def _resolve_script(self, script_entity_id) -> None: def _resolve_script(self, script_entity_id) -> None:
@ -181,7 +200,11 @@ class Searcher:
Will only be called if script is an entry point. Will only be called if script is an entry point.
""" """
# Extra: Check with script integration what entities/devices they reference for entity in script.entities_in_script(self.hass, script_entity_id):
self._add_or_resolve("entity", entity)
for device in script.devices_in_script(self.hass, script_entity_id):
self._add_or_resolve("device", device)
@callback @callback
def _resolve_group(self, group_entity_id) -> None: def _resolve_group(self, group_entity_id) -> None:

View File

@ -7,6 +7,6 @@
"zeroconf": [], "zeroconf": [],
"homekit": {}, "homekit": {},
"dependencies": ["websocket_api"], "dependencies": ["websocket_api"],
"after_dependencies": ["scene", "group"], "after_dependencies": ["scene", "group", "automation", "script"],
"codeowners": ["@home-assistant/core"] "codeowners": ["@home-assistant/core"]
} }

View File

@ -1,10 +1,11 @@
"""Offer reusable conditions.""" """Offer reusable conditions."""
import asyncio import asyncio
from collections import deque
from datetime import datetime, timedelta from datetime import datetime, timedelta
import functools as ft import functools as ft
import logging import logging
import sys import sys
from typing import Callable, Container, Optional, Union, cast from typing import Callable, Container, Optional, Set, Union, cast
from homeassistant.components import zone as zone_cmp from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import ( from homeassistant.components.device_automation import (
@ -19,6 +20,7 @@ from homeassistant.const import (
CONF_BEFORE, CONF_BEFORE,
CONF_BELOW, CONF_BELOW,
CONF_CONDITION, CONF_CONDITION,
CONF_DEVICE_ID,
CONF_DOMAIN, CONF_DOMAIN,
CONF_ENTITY_ID, CONF_ENTITY_ID,
CONF_STATE, CONF_STATE,
@ -31,7 +33,7 @@ from homeassistant.const import (
SUN_EVENT_SUNSET, SUN_EVENT_SUNSET,
WEEKDAYS, WEEKDAYS,
) )
from homeassistant.core import HomeAssistant, State from homeassistant.core import HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.exceptions import HomeAssistantError, TemplateError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.sun import get_astral_event_date from homeassistant.helpers.sun import get_astral_event_date
@ -529,3 +531,50 @@ async def async_validate_condition_config(
return cast(ConfigType, platform.CONDITION_SCHEMA(config)) # type: ignore return cast(ConfigType, platform.CONDITION_SCHEMA(config)) # type: ignore
return config return config
@callback
def async_extract_entities(config: ConfigType) -> Set[str]:
"""Extract entities from a condition."""
referenced = set()
to_process = deque([config])
while to_process:
config = to_process.popleft()
condition = config[CONF_CONDITION]
if condition in ("and", "or"):
to_process.extend(config["conditions"])
continue
entity_id = config.get(CONF_ENTITY_ID)
if entity_id is not None:
referenced.add(entity_id)
return referenced
@callback
def async_extract_devices(config: ConfigType) -> Set[str]:
"""Extract devices from a condition."""
referenced = set()
to_process = deque([config])
while to_process:
config = to_process.popleft()
condition = config[CONF_CONDITION]
if condition in ("and", "or"):
to_process.extend(config["conditions"])
continue
if condition != "device":
continue
device_id = config.get(CONF_DEVICE_ID)
if device_id is not None:
referenced.add(device_id)
return referenced

View File

@ -156,12 +156,66 @@ class Script:
ACTION_DEVICE_AUTOMATION: self._async_device_automation, ACTION_DEVICE_AUTOMATION: self._async_device_automation,
ACTION_ACTIVATE_SCENE: self._async_activate_scene, ACTION_ACTIVATE_SCENE: self._async_activate_scene,
} }
self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
"""Return true if script is on.""" """Return true if script is on."""
return self._cur != -1 return self._cur != -1
@property
def referenced_devices(self):
"""Return a set of referenced devices."""
if self._referenced_devices is not None:
return self._referenced_devices
referenced = set()
for step in self.sequence:
action = _determine_action(step)
if action == ACTION_CHECK_CONDITION:
referenced |= condition.async_extract_devices(step)
elif action == ACTION_DEVICE_AUTOMATION:
referenced.add(step[CONF_DEVICE_ID])
self._referenced_devices = referenced
return referenced
@property
def referenced_entities(self):
"""Return a set of referenced entities."""
if self._referenced_entities is not None:
return self._referenced_entities
referenced = set()
for step in self.sequence:
action = _determine_action(step)
if action == ACTION_CALL_SERVICE:
data = step.get(service.CONF_SERVICE_DATA)
if not data:
continue
entity_ids = data.get(ATTR_ENTITY_ID)
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
for entity_id in entity_ids:
referenced.add(entity_id)
elif action == ACTION_CHECK_CONDITION:
referenced |= condition.async_extract_entities(step)
elif action == ACTION_ACTIVATE_SCENE:
referenced.add(step[CONF_SCENE])
self._referenced_entities = referenced
return referenced
def run(self, variables=None, context=None): def run(self, variables=None, context=None):
"""Run script.""" """Run script."""
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(

View File

@ -5,6 +5,7 @@ from unittest.mock import Mock, patch
import pytest import pytest
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
from homeassistant.components.automation import DOMAIN
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_NAME, ATTR_NAME,
@ -922,3 +923,80 @@ async def test_automation_restore_last_triggered_with_initial_state(hass):
assert state assert state
assert state.state == STATE_ON assert state.state == STATE_ON
assert state.attributes["last_triggered"] == time assert state.attributes["last_triggered"] == time
async def test_extraction_functions(hass):
"""Test extraction functions."""
assert await async_setup_component(
hass,
DOMAIN,
{
DOMAIN: [
{
"alias": "test1",
"trigger": {"platform": "state", "entity_id": "sensor.trigger_1"},
"action": [
{
"service": "test.script",
"data": {"entity_id": "light.in_both"},
},
{
"service": "test.script",
"data": {"entity_id": "light.in_first"},
},
{
"domain": "light",
"device_id": "device-in-both",
"entity_id": "light.bla",
"type": "turn_on",
},
],
},
{
"alias": "test2",
"trigger": {"platform": "state", "entity_id": "sensor.trigger_2"},
"action": [
{
"service": "test.script",
"data": {"entity_id": "light.in_both"},
},
{
"condition": "state",
"entity_id": "sensor.condition",
"state": "100",
},
{"scene": "scene.hello"},
{
"domain": "light",
"device_id": "device-in-both",
"entity_id": "light.bla",
"type": "turn_on",
},
{
"domain": "light",
"device_id": "device-in-last",
"entity_id": "light.bla",
"type": "turn_on",
},
],
},
]
},
)
assert set(automation.automations_with_entity(hass, "light.in_both")) == {
"automation.test1",
"automation.test2",
}
assert set(automation.entities_in_automation(hass, "automation.test1")) == {
"light.in_both",
"light.in_first",
}
assert set(automation.automations_with_device(hass, "device-in-both")) == {
"automation.test1",
"automation.test2",
}
assert set(automation.devices_in_automation(hass, "automation.test2")) == {
"device-in-both",
"device-in-last",
}

View File

@ -358,9 +358,8 @@ async def test_turning_no_scripts_off(hass):
async def test_async_get_descriptions_script(hass): async def test_async_get_descriptions_script(hass):
"""Test async_set_service_schema for the script integration.""" """Test async_set_service_schema for the script integration."""
script = hass.components.script
script_config = { script_config = {
script.DOMAIN: { DOMAIN: {
"test1": {"sequence": [{"service": "homeassistant.restart"}]}, "test1": {"sequence": [{"service": "homeassistant.restart"}]},
"test2": { "test2": {
"description": "test2", "description": "test2",
@ -375,18 +374,75 @@ async def test_async_get_descriptions_script(hass):
} }
} }
await async_setup_component(hass, script.DOMAIN, script_config) await async_setup_component(hass, DOMAIN, script_config)
descriptions = await hass.helpers.service.async_get_all_descriptions() descriptions = await hass.helpers.service.async_get_all_descriptions()
assert descriptions[script.DOMAIN]["test1"]["description"] == "" assert descriptions[DOMAIN]["test1"]["description"] == ""
assert not descriptions[script.DOMAIN]["test1"]["fields"] assert not descriptions[DOMAIN]["test1"]["fields"]
assert descriptions[script.DOMAIN]["test2"]["description"] == "test2" assert descriptions[DOMAIN]["test2"]["description"] == "test2"
assert ( assert (
descriptions[script.DOMAIN]["test2"]["fields"]["param"]["description"] descriptions[DOMAIN]["test2"]["fields"]["param"]["description"]
== "param_description" == "param_description"
) )
assert ( assert (
descriptions[script.DOMAIN]["test2"]["fields"]["param"]["example"] descriptions[DOMAIN]["test2"]["fields"]["param"]["example"] == "param_example"
== "param_example"
) )
async def test_extraction_functions(hass):
"""Test extraction functions."""
assert await async_setup_component(
hass,
DOMAIN,
{
DOMAIN: {
"test1": {
"sequence": [
{
"service": "test.script",
"data": {"entity_id": "light.in_both"},
},
{
"service": "test.script",
"data": {"entity_id": "light.in_first"},
},
{"domain": "light", "device_id": "device-in-both"},
]
},
"test2": {
"sequence": [
{
"service": "test.script",
"data": {"entity_id": "light.in_both"},
},
{
"condition": "state",
"entity_id": "sensor.condition",
"state": "100",
},
{"scene": "scene.hello"},
{"domain": "light", "device_id": "device-in-both"},
{"domain": "light", "device_id": "device-in-last"},
],
},
}
},
)
assert set(script.scripts_with_entity(hass, "light.in_both")) == {
"script.test1",
"script.test2",
}
assert set(script.entities_in_script(hass, "script.test1")) == {
"light.in_both",
"light.in_first",
}
assert set(script.scripts_with_device(hass, "device-in-both")) == {
"script.test1",
"script.test2",
}
assert set(script.devices_in_script(hass, "script.test2")) == {
"device-in-both",
"device-in-last",
}

View File

@ -131,6 +131,62 @@ async def test_search(hass):
}, },
) )
await async_setup_component(
hass,
"script",
{
"script": {
"wled": {
"sequence": [
{
"service": "test.script",
"data": {"entity_id": wled_segment_1_entity.entity_id},
},
]
},
"hue": {
"sequence": [
{
"service": "test.script",
"data": {"entity_id": hue_segment_1_entity.entity_id},
},
]
},
}
},
)
assert await async_setup_component(
hass,
"automation",
{
"automation": [
{
"alias": "wled_entity",
"trigger": {"platform": "state", "entity_id": "sensor.trigger_1"},
"action": [
{
"service": "test.script",
"data": {"entity_id": wled_segment_1_entity.entity_id},
},
],
},
{
"alias": "wled_device",
"trigger": {"platform": "state", "entity_id": "sensor.trigger_1"},
"action": [
{
"domain": "light",
"device_id": wled_device.id,
"entity_id": wled_segment_1_entity.entity_id,
"type": "turn_on",
},
],
},
]
},
)
# Explore the graph from every node and make sure we find the same results # Explore the graph from every node and make sure we find the same results
expected = { expected = {
"config_entry": {wled_config_entry.entry_id}, "config_entry": {wled_config_entry.entry_id},
@ -139,6 +195,8 @@ async def test_search(hass):
"entity": {wled_segment_1_entity.entity_id, wled_segment_2_entity.entity_id}, "entity": {wled_segment_1_entity.entity_id, wled_segment_2_entity.entity_id},
"scene": {"scene.scene_wled_seg_1", "scene.scene_wled_hue"}, "scene": {"scene.scene_wled_seg_1", "scene.scene_wled_hue"},
"group": {"group.wled", "group.wled_hue"}, "group": {"group.wled", "group.wled_hue"},
"script": {"script.wled"},
"automation": {"automation.wled_entity", "automation.wled_device"},
} }
for search_type, search_id in ( for search_type, search_id in (
@ -149,6 +207,9 @@ async def test_search(hass):
("entity", wled_segment_2_entity.entity_id), ("entity", wled_segment_2_entity.entity_id),
("scene", "scene.scene_wled_seg_1"), ("scene", "scene.scene_wled_seg_1"),
("group", "group.wled"), ("group", "group.wled"),
("script", "script.wled"),
("automation", "automation.wled_entity"),
("automation", "automation.wled_device"),
): ):
searcher = search.Searcher(hass, device_reg, entity_reg) searcher = search.Searcher(hass, device_reg, entity_reg)
results = searcher.async_search(search_type, search_id) results = searcher.async_search(search_type, search_id)
@ -176,6 +237,8 @@ async def test_search(hass):
"scene.scene_wled_hue", "scene.scene_wled_hue",
}, },
"group": {"group.wled", "group.hue", "group.wled_hue"}, "group": {"group.wled", "group.hue", "group.wled_hue"},
"script": {"script.wled", "script.hue"},
"automation": {"automation.wled_entity", "automation.wled_device"},
} }
for search_type, search_id in ( for search_type, search_id in (
("scene", "scene.scene_wled_hue"), ("scene", "scene.scene_wled_hue"),

View File

@ -176,3 +176,37 @@ async def test_if_numeric_state_not_raise_on_unavailable(hass):
hass.states.async_set("sensor.temperature", "unknown") hass.states.async_set("sensor.temperature", "unknown")
assert not test(hass) assert not test(hass)
assert len(logwarn.mock_calls) == 0 assert len(logwarn.mock_calls) == 0
async def test_extract_entities():
"""Test extracting entities."""
condition.async_extract_entities(
{
"condition": "and",
"conditions": [
{
"condition": "state",
"entity_id": "sensor.temperature",
"state": "100",
},
{
"condition": "numeric_state",
"entity_id": "sensor.temperature_2",
"below": 110,
},
],
}
) == {"sensor.temperature", "sensor.temperature_2"}
async def test_extract_devices():
"""Test extracting devices."""
condition.async_extract_devices(
{
"condition": "and",
"conditions": [
{"condition": "device", "device_id": "abcd", "domain": "light"},
{"condition": "device", "device_id": "qwer", "domain": "switch"},
],
}
) == {"abcd", "qwer"}

View File

@ -1022,3 +1022,58 @@ def test_log_exception():
assert p_error == "" assert p_error == ""
else: else:
assert p_error == str(exc) assert p_error == str(exc)
async def test_referenced_entities():
"""Test referenced entities."""
script_obj = script.Script(
None,
cv.SCRIPT_SCHEMA(
[
{
"service": "test.script",
"data": {"entity_id": "light.service_not_list"},
},
{
"service": "test.script",
"data": {"entity_id": ["light.service_list"]},
},
{
"condition": "state",
"entity_id": "sensor.condition",
"state": "100",
},
{"scene": "scene.hello"},
{"event": "test_event"},
{"delay": "{{ delay_period }}"},
]
),
)
assert script_obj.referenced_entities == {
"light.service_not_list",
"light.service_list",
"sensor.condition",
"scene.hello",
}
# Test we cache results.
assert script_obj.referenced_entities is script_obj.referenced_entities
async def test_referenced_devices():
"""Test referenced entities."""
script_obj = script.Script(
None,
cv.SCRIPT_SCHEMA(
[
{"domain": "light", "device_id": "script-dev-id"},
{
"condition": "device",
"device_id": "condition-dev-id",
"domain": "switch",
},
]
),
)
assert script_obj.referenced_devices == {"script-dev-id", "condition-dev-id"}
# Test we cache results.
assert script_obj.referenced_devices is script_obj.referenced_devices